from typing import Dict, List, Optional

import cv2
from torch.utils.data import Dataset

from ...tools.logger import Logger


class VideoDataset(Dataset):
    """Random-access dataset over the frames of a single video file.

    Frames are decoded on demand through ``cv2.VideoCapture``. Requesting
    indices in increasing order (0, 1, 2, ...) reads frames back to back
    without seeking. Any other access pattern first seeks the capture to
    the requested frame.
    """

    # Sentinel for ``_curr_frame`` meaning "decoder position unknown". It is
    # never one less than a valid (non-negative) index, so the next
    # ``__getitem__`` call always seeks explicitly.
    _POS_UNKNOWN = -2

    def __init__(self, video_filename: str) -> None:
        """Open ``video_filename`` and cache its basic stream properties.

        Args:
            video_filename: Path of the video file to read.
        """
        super().__init__()
        self.video_filename = video_filename
        # Project helper; presumably checks/logs that the file exists —
        # see Logger for its exact behavior.
        Logger.exists(video_filename)

        self.cap = cv2.VideoCapture(video_filename)
        self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # NOTE(review): some OpenCV backends/containers report only an
        # estimated frame count here — confirm for the formats used.
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        # Index of the frame most recently read. -1 means "the next read
        # yields frame 0", so the first sequential access needs no seek.
        self._curr_frame = -1

    def info(self) -> str:
        """Return the dataset type tag, ``"video"``."""
        return "video"

    def keys(self) -> List[str]:
        """Return the keys of the dicts produced by ``__getitem__``."""
        return ["img"]

    def meta(self) -> dict:
        """Return the stream properties cached at construction time."""
        return {
            "frame_width": self.frame_width,
            "frame_height": self.frame_height,
            "frame_count": self.frame_count,
            "fps": self.fps,
            "fname": self.video_filename,
        }

    def __len__(self) -> int:
        """Return the frame count reported by the capture at construction."""
        return self.frame_count

    def _release(self) -> None:
        """Release the capture, if ``__init__`` got far enough to create it."""
        # ``__del__`` also runs on objects whose ``__init__`` raised before
        # ``self.cap`` was assigned, so the attribute may be missing.
        cap = getattr(self, "cap", None)
        if cap is not None:
            cap.release()

    def __del__(self) -> None:
        self._release()

    def __getitem__(self, index) -> Optional[dict]:
        """Decode and return the frame at ``index``.

        Negative indices count back from ``len(self)``, as for Python
        sequences.

        Args:
            index: Frame index.

        Returns:
            ``{"img": frame}``, where ``frame`` is the image returned by
            ``cv2.VideoCapture.read()``. Returns ``None`` if the frame
            could not be read; an error is logged in that case.

        NOTE(review): this never raises ``IndexError``. Unless a base class
        defines ``__iter__``, a plain ``for`` loop over the dataset would
        therefore not stop at the end of the video. Iterate over
        ``range(len(dataset))`` instead.
        """
        if index < 0:
            index += len(self)

        if self._curr_frame + 1 != index:
            # Not the frame right after the last one read: reposition.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, index)
        opened, frame = self.cap.read()
        if not opened:
            # After a failed read the decoder position cannot be trusted,
            # so force the next access to seek rather than read on blindly.
            self._curr_frame = self._POS_UNKNOWN
            Logger.error(f"[can not open]: {self.video_filename}")
            return None

        self._curr_frame = index
        return {
            "img": frame,
        }

    def render(self, index) -> Optional[dict]:
        """Alias of ``__getitem__``."""
        return self.__getitem__(index)

    def close(self) -> None:
        """Release the underlying capture. Calling it repeatedly is safe."""
        self._release()
