import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import cv2
from PIL import Image
import numpy as np
import random
import math
import numbers
import collections
from torch.autograd import Variable
try:
    import accimage
except ImportError:
    accimage = None
from vhh_sbd.Configuration import Configuration
from vhh_sbd.PreProcessing import PreProcessing


class deepSBD(nn.Module):
    """
    This class represents the pytorch model architecture used for sbd candidate selection.

    It is a 3D-convolutional network (five Conv3d layers followed by three fully connected
    layers) that assigns every input clip to one of 3 classes (``fc8`` has 3 outputs).
    The layer sizes are fixed: ``fc6`` expects 100352 = 256 * 8 * 7 * 7 features, which is
    what the convolutional stack produces for an input of shape (N, 3, 16, 128, 128),
    i.e. clips of 16 three-channel frames at 128x128 pixels.
    """

    def __init__(self):
        """
        Constructor.

        Only the conv/fc layers carry parameters, so the state_dict keys are
        conv1..conv5 and fc6..fc8 (weight and bias each).
        """
        super(deepSBD, self).__init__()
        self.conv1 = nn.Conv3d(3, 96, kernel_size=3, stride=(1, 2, 2), padding=(0, 0, 0), bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=0)
        self.conv2 = nn.Conv3d(96, 256, kernel_size=3, stride=(1, 2, 2), padding=(0, 0, 0), bias=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=0)
        # conv3..conv5 use padding=1 with stride 1, so they keep the (8, 7, 7) volume
        self.conv3 = nn.Conv3d(256, 384, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv3d(384, 384, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu4 = nn.ReLU(inplace=True)
        self.conv5 = nn.Conv3d(384, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu5 = nn.ReLU(inplace=True)
        # No pooling is applied after conv5: fc6's input size (256 * 8 * 7 * 7 = 100352)
        # relies on the un-pooled conv5 output.
        self.fc6 = nn.Linear(100352, 2048)
        self.relu6 = nn.ReLU(inplace=True)
        self.fc7 = nn.Linear(2048, 2048)
        self.relu7 = nn.ReLU(inplace=True)
        self.fc8 = nn.Linear(2048, 3)

    def forward(self, x):
        """
        This method calculates the forward pass of the model.

        :param x: 5-D tensor (N, C=3, T, H, W) whose conv5 output flattens to 100352
                  features per sample, e.g. (N, 3, 16, 128, 128).
        :return:  tensor of shape (N, 3) with the raw fc8 outputs (no softmax is applied).
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.conv5(x)
        x = self.relu5(x)
        # flatten (N, 256, T', H', W') -> (N, 256 * T' * H' * W')
        x = x.view(x.size(0), -1)
        x = self.fc6(x)
        x = self.relu6(x)
        x = self.fc7(x)
        x = self.relu7(x)
        x = self.fc8(x)
        return x


class CandidateSelection(object):
    """
    This class is used for sbd candidate selection. It classifies overlapping clips of 16
    frames of a video with the pre-trained ``deepSBD`` network and returns the frame ranges
    of the clips whose predicted label is 2 (handled as abrupt cuts by this class).
    """

    def __init__(self, config_instance: Configuration):
        """
        Constructor.

        :param config_instance: object instance of type Configuration
        """
        print("create instance of candidate selection module ... ")
        self.config_instance = config_instance
        self.preprocessing = PreProcessing(self.config_instance)

    def run(self, video_path):
        """
        This method is used to run the candidate selection process.

        The video is split into clips of 16 frames with a stride of 8 frames (half a clip);
        the final clip is padded by repeating its last frame. Every clip is classified by
        the network, runs of consecutive clips with the same non-zero label are merged, and
        the merged ranges with label 2 are returned.

        Side effects: sets ``self.model`` (an ``nn.DataParallel``-wrapped deepSBD in eval
        mode) and ``self.spatial_transform``.

        :param video_path: This parameter must hold a valid path to a video file.
        :return: numpy int array with one ``(begin, end)`` row per detected range, where
                 ``begin = first_clip * 8 + 1`` and ``end = last_clip * 8 + 17``; an empty
                 array of shape ``(0,)`` if nothing was detected or the video has no frames.
        :raises IOError: if OpenCV cannot open ``video_path``.
        """
        temporal_length = 16
        batch_size = 32

        # use the GPU when one is available, otherwise fall back to the CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model(self.config_instance.pretrained_model, device)

        # resize to 128x128, keep the 0..255 value range (norm 1) and subtract the mean
        self.spatial_transform = Compose([Scale((128, 128)),
                                          ToTensor(1),
                                          Normalize(self.get_mean(1), [1, 1, 1])])

        labels = self._predict_clip_labels(video_path, temporal_length, batch_size, device)
        ranges = self._merge_labels(labels, temporal_length)

        abrupt_l = [(begin, end) for begin, end, label in ranges if label == 2]
        return np.array(abrupt_l).astype('int')

    def _load_model(self, model_path, device):
        """Build deepSBD on ``device``, wrap it in nn.DataParallel, load the checkpoint and switch to eval mode."""
        self.model = nn.DataParallel(deepSBD().to(device))
        # The checkpoint's 'state_dict' is loaded into the DataParallel wrapper, so its keys
        # must carry the 'module.' prefix (i.e. it was saved from a DataParallel model).
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

    def _prepare_frame(self, frame):
        """Turn one frame returned by cv2.VideoCapture.read into a (3, 128, 128) float tensor."""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = self.preprocessing.applyTransformOnImg(frame)
        # NOTE(review): the frame was already converted BGR->RGB above and is converted with
        # COLOR_BGR2RGB again here, which swaps the R and B channels back. Unless
        # applyTransformOnImg returns BGR data, the network receives BGR although get_mean
        # documents RGB means — confirm against PreProcessing before changing this.
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert('RGB')
        return self.spatial_transform(frame)

    def _predict_clip_labels(self, video_path, temporal_length, batch_size, device):
        """
        Classify every clip of the video and return one label (int) per clip.

        Clips hold ``temporal_length`` frames and overlap by half a clip; the final clip is
        padded by repeating its last frame. Clips are classified in batches of ``batch_size``.

        :raises IOError: if OpenCV cannot open ``video_path``.
        """
        videocap = cv2.VideoCapture(video_path)
        try:
            if not videocap.isOpened():
                raise IOError("unable to open video file: " + str(video_path))

            labels = []
            clip_batch = []
            image_clip = []
            status = True
            with torch.no_grad():  # inference only: no autograd graph needed
                while status:
                    # top the window up to temporal_length frames
                    for _ in range(temporal_length - len(image_clip)):
                        status, frame = videocap.read()
                        if not status:
                            break
                        image_clip.append(self._prepare_frame(frame))

                    if not image_clip:
                        # the video has no readable frames at all -> nothing to classify
                        break

                    # pad the final, incomplete window by repeating its last frame
                    image_clip += [image_clip[-1]] * (temporal_length - len(image_clip))

                    # (T, C, H, W) -> (C, T, H, W), the per-sample layout Conv3d expects
                    clip_batch.append(torch.stack(image_clip, 0).permute(1, 0, 2, 3))
                    # keep the second half so consecutive clips overlap by half a clip
                    image_clip = image_clip[temporal_length // 2:]

                    if len(clip_batch) == batch_size or not status:
                        clip_tensor = torch.stack(clip_batch, 0).to(device)
                        labels += self.get_label(self.model(clip_tensor))
                        clip_batch = []
            return labels
        finally:
            videocap.release()

    @staticmethod
    def _merge_labels(labels, temporal_length):
        """Merge runs of consecutive clips with the same non-zero label into (begin, end, label) frame ranges."""
        half = temporal_length // 2
        ranges = []
        i = 0
        while i < len(labels):
            label = labels[i]
            if label <= 0:
                i += 1
                continue
            begin = i
            while i + 1 < len(labels) and labels[i + 1] == label:
                i += 1
            end = i
            # clip k starts at frame k * half (0-based); the +1 presumably yields 1-based frame numbers
            ranges.append((begin * half + 1, end * half + temporal_length + 1, label))
            i += 1
        return ranges

    def get_mean(self, norm_value=255):
        """
        Helper method to calculate the normalized mean values.

        :param norm_value: Base for normalization (default: 255).
        :return: array with the normalized mean values for each color channel (RGB)
        """
        return [114.7748 / norm_value, 107.7354 / norm_value, 99.4750 / norm_value]

    def get_test_spatial_transform(self, opt):
        """
        Build the test-time spatial transform from an options object.

        :param opt: object providing ``spatial_size`` and ``norm_value`` attributes.
        :return: Compose of Scale, ToTensor and mean-subtracting Normalize.
        """
        return Compose([Scale((opt.spatial_size, opt.spatial_size)),
                        ToTensor(opt.norm_value),
                        Normalize(self.get_mean(opt.norm_value), [1, 1, 1])])

    def get_label(self, res_tensor):
        """
        Convert a batch of network outputs into class labels.

        :param res_tensor: 2-D tensor (N, num_classes) of scores.
        :return: list of N python ints, the arg-max class index of every row.
        """
        scores = res_tensor.detach().cpu().numpy()
        return np.argmax(scores, axis=1).tolist()



class Compose(object):
    """Chain several transforms into one callable.

    Each transform receives the output of the previous one, in list order.
    ``randomize_parameters`` forwards to every transform, so each of them must
    implement that method.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def randomize_parameters(self):
        for transform in self.transforms:
            transform.randomize_parameters()


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C, or H x W for a single channel)
    to a torch.FloatTensor of shape (C x H x W) divided by ``norm_value``; with the
    default ``norm_value=255`` 8-bit input in [0, 255] ends up in [0.0, 1.0].
    PIL images of mode 'I' / 'I;16' are returned as integer tensors without division.
    """

    def __init__(self, norm_value=255):
        self.norm_value = norm_value

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            if pic.ndim == 2:
                # a single-channel array has no channel axis; add one -> (H, W, 1)
                pic = pic[:, :, None]
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img.float().div(self.norm_value)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image.
        # np.array(..., copy=False) raises under NumPy 2 whenever a copy is required
        # (e.g. the uint16 -> int16 cast of 'I;16'), so a (cheap) copy is made instead.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        else:
            # raw bytes as uint8; copy() makes the buffer writable for torch.from_numpy
            img = torch.from_numpy(np.frombuffer(pic.tobytes(), dtype=np.uint8).copy())
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if img.dtype == torch.uint8:
            return img.float().div(self.norm_value)
        return img

    def randomize_parameters(self):
        pass


class Normalize(object):
    """Normalize a tensor image channel by channel, in place.

    For every channel ``c`` this computes ``c = (c - mean[c]) / std[c]``.
    Channels are paired with ``mean``/``std`` entries by position; pairing stops
    at the shortest of the three sequences.

    Args:
        mean (sequence): per-channel means (R, G, B).
        std (sequence): per-channel standard deviations (R, G, B).
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): (C, H, W) image; it is modified in place.
        Returns:
            Tensor: the same tensor object, normalized.
        """
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean)
            channel.div_(channel_std)
        return tensor

    def randomize_parameters(self):
        pass


class Scale(object):
    """Rescale the input PIL.Image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number and the
            aspect ratio kept, i.e. if height > width the image is resized to
            (width, height) = (size, size * height / width).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        if not isinstance(self.size, int):
            return img.resize(self.size, self.interpolation)

        w, h = img.size
        # already the requested size on the smaller edge -> return unchanged
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            ow, oh = self.size, int(self.size * h / w)
        else:
            ow, oh = int(self.size * w / h), self.size
        return img.resize((ow, oh), self.interpolation)

    def randomize_parameters(self):
        pass


class CenterCrop(object):
    """Crop the given PIL.Image around its center.

    Args:
        size (sequence or int): target crop size as (h, w). A number gives a
            square (size, size) crop.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        width, height = img.size
        crop_h, crop_w = self.size
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)

    def randomize_parameters(self):
        pass


class CornerCrop(object):
    """Crop a square of side ``size`` from the center or a corner of a PIL.Image.

    ``crop_position`` is one of 'c' (center), 'tl', 'tr', 'bl', 'br' (corners).
    If it is None, the position is drawn at random by ``randomize_parameters``,
    which must then be called before the first ``__call__``.
    """

    def __init__(self, size, crop_position=None):
        self.size = size
        self.randomize = crop_position is None
        self.crop_position = crop_position
        self.crop_positions = ['c', 'tl', 'tr', 'bl', 'br']

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        Raises:
            ValueError: if ``crop_position`` is unset or not a known position.
        """
        image_width = img.size[0]
        image_height = img.size[1]
        size = self.size

        if self.crop_position == 'c':
            x1 = int(round((image_width - size) / 2.))
            y1 = int(round((image_height - size) / 2.))
            box = (x1, y1, x1 + size, y1 + size)
        elif self.crop_position == 'tl':
            box = (0, 0, size, size)
        elif self.crop_position == 'tr':
            box = (image_width - size, 0, image_width, size)
        elif self.crop_position == 'bl':
            box = (0, image_height - size, size, image_height)
        elif self.crop_position == 'br':
            box = (image_width - size, image_height - size, image_width, image_height)
        else:
            # previously this fell through to an UnboundLocalError on x1
            raise ValueError(
                "invalid crop position %r (expected one of %s; call randomize_parameters() "
                "first when no position was given)" % (self.crop_position, self.crop_positions))

        return img.crop(box)

    def randomize_parameters(self):
        if self.randomize:
            self.crop_position = random.choice(self.crop_positions)


class RandomHorizontalFlip(object):
    """Mirror a PIL.Image left-to-right with probability 0.5.

    The random draw is stored in ``self.p`` by ``randomize_parameters``, so every
    call between two randomizations makes the same flip decision.
    ``randomize_parameters`` must be called before the first ``__call__``.
    """

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: the mirrored image when ``p < 0.5``, otherwise ``img`` itself.
        """
        should_flip = self.p < 0.5
        return img.transpose(Image.FLIP_LEFT_RIGHT) if should_flip else img

    def randomize_parameters(self):
        self.p = random.random()


class MultiScaleCornerCrop(object):
    """Crop the given PIL.Image to randomly selected size.
    A crop of size is selected from scales of the original size.
    A position of cropping is randomly selected from 4 corners and 1 center.
    This crop is finally resized to given size.
    ``randomize_parameters`` must be called before the first ``__call__``.
    Args:
        scales: cropping scales of the original size
        size: side length of the square output image
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, scales, size, interpolation=Image.BILINEAR):
        self.scales = scales
        self.size = size
        self.interpolation = interpolation

        self.crop_positions = ['c', 'tl', 'tr', 'bl', 'br']

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and resized.
        Returns:
            PIL.Image: (size x size) image.
        Raises:
            ValueError: if ``crop_position`` is not a known position.
        """
        image_width = img.size[0]
        image_height = img.size[1]
        crop_size = int(min(image_width, image_height) * self.scale)

        if self.crop_position == 'c':
            center_x = image_width // 2
            center_y = image_height // 2
            box_half = crop_size // 2
            box = (center_x - box_half, center_y - box_half,
                   center_x + box_half, center_y + box_half)
        elif self.crop_position == 'tl':
            box = (0, 0, crop_size, crop_size)
        elif self.crop_position == 'tr':
            box = (image_width - crop_size, 0, image_width, crop_size)
        elif self.crop_position == 'bl':
            box = (0, image_height - crop_size, crop_size, image_height)
        elif self.crop_position == 'br':
            box = (image_width - crop_size, image_height - crop_size, image_width, image_height)
        else:
            raise ValueError("invalid crop position %r (expected one of %s)"
                             % (self.crop_position, self.crop_positions))

        img = img.crop(box)

        return img.resize((self.size, self.size), self.interpolation)

    def randomize_parameters(self):
        self.scale = random.choice(self.scales)
        # Pick from crop_positions by its own length; the old code indexed it with
        # len(self.scales), which raised IndexError for >5 scales and skipped
        # positions for <5 scales.
        self.crop_position = random.choice(self.crop_positions)




# NOTE(review): the triple-quoted string below is dead code: a legacy usage example kept
# as an unused string literal, so it is never executed. It refers to names that are not
# defined at module level here (``opt``, ``videoname``, ``model``, a free ``get_mean``),
# and it calls ``deepSBD`` with arguments that the ``deepSBD`` class constructor above
# does not accept.
'''

# spatial_transforms = get_test_spatial_transform(opt)
spatial_transforms = Compose([Scale((128, 128)),
                    ToTensor(1),
                    Normalize(get_mean(1), [1, 1, 1])])
					
labels = deepSBD(os.path.join(opt.root_dir, opt.test_subdir, videoname), 32, model, spatial_transforms, 128)

'''

