import numpy as np
from torch.utils.data import Dataset
import torch
from .base import *


class ZJUMocapDataset(Dataset):
    """Map-style dataset yielding one sample per (frame, camera) pair.

    Samples are ordered frame-major: ``index = frame * cam_num + cam``. Each
    sample merges the mappings returned by ``get_frame_data`` for the keys
    ``"new_vertices"`` and ``"img"`` with the mapping returned by
    ``get_camera_params`` for the selected camera.

    Args:
        dataset_path: Dataset location, passed through to the ``.base`` helpers.
        view: Identifier passed through to the ``.base`` helpers
            (default ``"CoreView_313"``).
        frame_id: If non-negative, only this frame is served and the dataset
            has one sample per camera; any negative value serves every frame.

    Raises:
        ValueError: If ``frame_id`` is not smaller than the number of frames.
    """

    def __init__(self, dataset_path: str, view: str = "CoreView_313", frame_id: int = -1) -> None:
        super().__init__()

        self.dataset_path = dataset_path
        self.view = view

        self.frames_num = get_frames_num(dataset_path, view)
        self.cam_list = get_cam_list(view)
        self.cam_num = len(self.cam_list)

        if frame_id >= self.frames_num:
            # Explicit exception rather than `assert` so the check survives `python -O`.
            raise ValueError(
                f"frame_id={frame_id} out of range: frames_num={self.frames_num}"
            )

        self.frame_id = frame_id

    def __len__(self):
        """Number of cameras for a fixed frame, otherwise frames * cameras."""
        if self.frame_id >= 0:
            return self.cam_num
        return self.frames_num * self.cam_num

    def getitem(self, index: int) -> dict:
        """Load the raw (unconverted) data for sample ``index``.

        With a fixed ``frame_id`` the index selects only the camera; otherwise
        it encodes ``frame * cam_num + cam``. Negative indices count from the
        end, as for Python sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            # IndexError lets plain `for sample in dataset` iteration terminate;
            # without this check the modulo below would wrap around forever
            # when frame_id is fixed (and divide by zero if there are no cameras).
            raise IndexError(f"index out of range for dataset of length {length}")

        if self.frame_id >= 0:
            frame_id = self.frame_id
        else:
            frame_id = index // self.cam_num
        cam = self.cam_list[index % self.cam_num]

        sample = dict()
        for key in ("new_vertices", "img"):
            sample.update(get_frame_data(self.dataset_path, self.view, cam, frame_id, key))
        sample.update(get_camera_params(self.dataset_path, self.view, cam))
        return sample

    def __getitem__(self, index: int) -> dict:
        """Return sample ``index`` with every value converted to a float32 tensor.

        ``"img"`` is additionally divided by 255.0 (presumably mapping 8-bit
        pixel values into [0, 1] — TODO confirm the stored image range).
        """
        # np.asarray also accepts scalars/lists; ndarrays pass through as-is,
        # so torch.from_numpy still shares their memory as before.
        sample = {
            key: torch.from_numpy(np.asarray(value)).float()
            for key, value in self.getitem(index).items()
        }
        sample["img"] = sample["img"] / 255.0
        return sample
