import numpy as np

from .image_ops import crop, flip_img, flip_pose, flip_kp, transform, rot_aa


class Augmentation:
    """Per-sample augmentation of an image crop, 2D/3D keypoints and pose parameters.

    In training mode (``is_train=True``) every call to :meth:`do_augmentation`
    draws a random horizontal flip, per-channel pixel noise, an in-plane
    rotation and a scale jitter, and applies the same draw to all annotations
    of the sample. Otherwise the identity parameters are used and only the
    deterministic crop / normalization steps run.
    """

    def __init__(
        self, is_train: bool, img_res: int = 224, scale_factor: float = 0.25, rot_factor: float = 30, noise_factor: float = 0.4
    ):
        """
        Args:
            is_train: whether to draw random augmentation parameters.
            img_res: side length (pixels) of the square output crop.
            scale_factor: the box scale is multiplied by a value clipped to
                [1 - scale_factor, 1 + scale_factor].
            rot_factor: standard deviation (degrees) of the random rotation,
                which is clipped to [-2 * rot_factor, 2 * rot_factor].
            noise_factor: each color channel is multiplied by a value drawn
                uniformly from [1 - noise_factor, 1 + noise_factor].
        """
        self.is_train = is_train
        self.img_res = img_res
        self.scale_factor = scale_factor
        self.noise_factor = noise_factor
        self.rot_factor = rot_factor

    def do_augmentation(self, annotations: dict) -> dict:
        """Apply one augmentation draw consistently to every annotation of a sample.

        Required keys: ``center``, ``scale``, ``joints_3d``, ``joints_2d``, ``pose``.
        Optional keys: ``img``, ``hrnet``. The processed values are written back
        into ``annotations``, which is also returned.
        """
        flip, pn, rot, sc = self.augm_params()
        center = annotations["center"]
        # The jittered box scale is shared by the image crop and every 2D keypoint set.
        box_scale = sc * annotations["scale"]

        if "img" in annotations:
            annotations["img"] = self.rgb_processing(annotations["img"], center, box_scale, rot, flip, pn)

        annotations["joints_3d"] = self.j3d_processing(annotations["joints_3d"], rot, flip)
        annotations["joints_2d"] = self.j2d_processing(annotations["joints_2d"], center, box_scale, rot, flip)
        annotations["pose"] = self.pose_processing(annotations["pose"], rot, flip)

        if "hrnet" in annotations:
            annotations["hrnet"] = self.j2d_processing(annotations["hrnet"], center, box_scale, rot, flip)
        return annotations

    def augm_params(self):
        """Draw augmentation parameters.

        Returns:
            (flip, pn, rot, sc): flip flag (0/1), per-channel noise multipliers
            (shape (3,)), rotation in degrees, and scale multiplier. Outside
            training mode this is the identity ``(0, ones(3), 0, 1)``.
        """
        flip = 0  # flipping
        pn = np.ones(3)  # per-channel pixel noise
        rot = 0  # rotation (degrees)
        sc = 1  # scaling
        if self.is_train:
            # Flip with probability 1/2.
            if np.random.uniform() <= 0.5:
                flip = 1

            # Each channel is multiplied by a number in [1 - noise_factor, 1 + noise_factor].
            pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3)

            # Rotation ~ N(0, rot_factor^2), clipped to [-2*rot_factor, 2*rot_factor].
            rot = min(2 * self.rot_factor, max(-2 * self.rot_factor, np.random.randn() * self.rot_factor))

            # Scale multiplier ~ N(1, scale_factor^2), clipped to [1-scale_factor, 1+scale_factor].
            sc = min(1 + self.scale_factor, max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1))

            # The rotation is reset to zero with probability 3/5. This draw stays after the
            # scale draw so the random stream (and therefore seeded results) is unchanged.
            if np.random.uniform() <= 0.6:
                rot = 0

        return flip, pn, rot, sc

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
        """Crop, optionally flip, and color-jitter an image.

        Returns the result transposed HWC -> CHW as float32 scaled by 1/255
        (assumes ``crop`` returns an ``img_res x img_res x 3`` array — confirm
        against ``image_ops.crop``).
        """
        rgb_img = crop(rgb_img, center, scale, [self.img_res, self.img_res], rot=rot)
        if flip:
            rgb_img = flip_img(rgb_img)
        # Channel-wise pixel noise on the first three channels, clipped to [0, 255].
        # The result is assigned back into ``rgb_img``, so its dtype is preserved.
        rgb_img[:, :, :3] = np.clip(rgb_img[:, :, :3] * pn, 0.0, 255.0)
        # (3, img_res, img_res), float32, [0, 1]
        return np.transpose(rgb_img.astype("float32"), (2, 0, 1)) / 255.0

    def j2d_processing(self, kp, center, scale, r, f):
        """Map 2D keypoints into the augmented crop and normalize them.

        Columns 0..1 are transformed into crop pixel coordinates. Every column
        except the last is then mapped to ``2 * x / img_res - 1``. The last
        column is left as-is (presumably a confidence value — verify with the
        caller). The input array is not modified.

        Returns:
            float32 array with the same shape as ``kp``.
        """
        # Work on a floating-point copy. The caller's array (possibly a view into a
        # cached dataset) stays untouched, and integer inputs no longer truncate the
        # normalized coordinates.
        kp = kp.astype(np.result_type(kp.dtype, np.float32))
        for joint in kp:
            # ``joint`` is a row view, so this assignment writes straight into ``kp``.
            joint[0:2] = transform(joint[0:2] + 1, center, scale, [self.img_res, self.img_res], rot=r)
        # Convert to normalized coordinates.
        kp[:, :-1] = 2.0 * kp[:, :-1] / self.img_res - 1.0
        if f:
            kp = flip_kp(kp)
        return kp.astype("float32")

    def j3d_processing(self, S, r, f):
        """Rotate 3D keypoints in-plane by ``r`` degrees and optionally flip them.

        The first three columns are rotated about the z-axis by ``-r`` degrees,
        matching the image rotation. The last column is left unchanged.
        The input array is not modified.

        Returns:
            float32 array with the same shape as ``S``.
        """
        S = S.astype(np.result_type(S.dtype, np.float32))
        # Skip the identity rotation entirely. Multiplying by eye(3) would let a NaN
        # in one coordinate spread into the others through 0 * nan.
        if r != 0:
            rot_rad = -r * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat = np.eye(3)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            S[:, :-1] = np.einsum("ij,kj->ki", rot_mat, S[:, :-1])
        if f:
            S = flip_kp(S)
        return S.astype("float32")

    def pose_processing(self, pose, r, f):
        """Apply the rotation / flip to pose parameters (SMPL theta per the original docs).

        ``rot_aa`` is applied to ``pose[:3]`` with angle ``r``. Presumably these are
        the global-orientation axis-angle parameters of a 72-dim SMPL pose — confirm
        against ``image_ops.rot_aa``.

        Returns:
            float32 copy of ``pose``.
        """
        # astype returns a copy, so the caller's array is not modified.
        pose = pose.astype("float32")
        pose[:3] = rot_aa(pose[:3], r)
        if f:
            pose = flip_pose(pose)
        # Re-cast in case flip_pose changes the dtype.
        return pose.astype("float32")
