from typing import Optional, Sequence

import torch
import torchvision.transforms as transforms

# Per-channel statistics consumed by ``normalize_img`` and ``un_normalize_img``.
# They match the mean/std values torchvision documents for its ImageNet-pretrained
# models; one entry per channel, in channel-index order.
IMG_MEAN = [
    0.485,
    0.456,
    0.406,
]
IMG_STD = [
    0.229,
    0.224,
    0.225,
]


# Channel-wise normalization transform: output[c] = (input[c] - IMG_MEAN[c]) / IMG_STD[c].
# Arguments are passed positionally in Normalize's (mean, std) order.
normalize_img = transforms.Normalize(IMG_MEAN, IMG_STD)


def un_normalize_img(
    img: torch.Tensor,
    *,
    mean: Optional[Sequence[float]] = None,
    std: Optional[Sequence[float]] = None,
) -> torch.Tensor:
    """Invert ``normalize_img``: map a normalized image back to its original value range.

    Computes ``img * std + mean`` channel by channel, which undoes
    ``(x - mean) / std`` as applied by ``transforms.Normalize``.

    Args:
        img: Normalized image tensor. The statistics are reshaped to ``(C, 1, 1)``
            and broadcast against ``img``, so the channel axis must be dim -3
            (e.g. ``(C, H, W)`` or ``(N, C, H, W)``).
        mean: Optional per-channel mean override; defaults to ``IMG_MEAN``.
        std: Optional per-channel standard-deviation override; defaults to ``IMG_STD``.

    Returns:
        The un-normalized tensor, on the same device as ``img``.
    """
    if mean is None:
        mean = IMG_MEAN
    if std is None:
        std = IMG_STD
    # Create the statistics on img's device: torch.FloatTensor(...) always lived on
    # the CPU, which made this fail with a device mismatch for CUDA inputs.
    # Keep them float32 (as before) so dtype promotion of the result is unchanged.
    std_t = torch.as_tensor(std, dtype=torch.float32, device=img.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t
