import numpy as np
from torch.utils.data import Dataset

from ...visualization.plot_joints import plot_kps_2d
from .base import *


class ZJUMocapRender(Dataset):
    """Dataset over every (frame, camera) pair of one capture sequence.

    The flat sample index is frame-major: ``index // cam_num`` selects the
    frame and ``index % cam_num`` selects the camera within ``cam_list``.

    NOTE(review): the frame count, camera list, per-frame data and camera
    parameters all come from helpers star-imported from ``.base``; their exact
    return schemas are not visible here and should be confirmed there.
    """

    def __init__(self, dataset_path: str, view: str = "CoreView_313") -> None:
        """Record the dataset location and look up frame/camera counts.

        Args:
            dataset_path: Root path handed to the ``.base`` helpers.
            view: Sequence identifier handed to the ``.base`` helpers.
        """
        super().__init__()

        self.dataset_path = dataset_path
        self.view = view

        self.frames_num = get_frames_num(dataset_path, view)
        self.cam_list = get_cam_list(view)
        self.cam_num = len(self.cam_list)

    def __len__(self) -> int:
        """Return the number of (frame, camera) samples."""
        return self.frames_num * self.cam_num

    def __getitem__(self, index: int) -> dict:
        """Return the sample at ``index`` with sequence-style bounds checking.

        ``torch.utils.data.Dataset`` expects subclasses to provide
        ``__getitem__``; without it ``dataset[i]`` and ``DataLoader`` fail.
        Negative indices count from the end, and an out-of-range index raises
        ``IndexError`` so that plain ``for sample in dataset`` terminates.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        original_index = index
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(
                f"index {original_index} is out of range for dataset of size {size}"
            )
        return self.getitem(index)

    def getitem(self, index: int) -> dict:
        """Load frame data and camera parameters for the flat ``index``.

        Returns:
            A dict merging the results of ``get_frame_data`` for the keys
            ``"new_vertices"``, ``"vertices"``, ``"img"`` and ``"openpose"``
            with the result of ``get_camera_params``.
        """
        frame_id, cam_id = divmod(index, self.cam_num)
        cam = self.cam_list[cam_id]

        sample = {}
        for key in ("new_vertices", "vertices", "img", "openpose"):
            sample.update(
                get_frame_data(self.dataset_path, self.view, cam, frame_id, key)
            )
        sample.update(get_camera_params(self.dataset_path, self.view, cam))
        return sample

    @staticmethod
    def _project_to_pixels(
        vertices: np.ndarray,
        w2c: np.ndarray,
        K: np.ndarray,
        height: int,
        width: int,
    ) -> tuple:
        """Project 3D points and return the pixel indices that land in the image.

        Args:
            vertices: ``(N, 3)`` array of points.
            w2c: Matrix whose top-left 3x3 block is applied as a rotation and
                whose last column (first three rows) is added as a translation.
            K: 3x3 matrix applied to the depth-normalized points.
            height: Image height in pixels.
            width: Image width in pixels.

        Returns:
            ``(rows, cols)`` integer index arrays of the visible projections.
        """
        # Same transform as (R @ v.T + t).T, written row-wise.
        cam_pts = vertices @ w2c[:3, :3].T + w2c[:3, -1]

        # Only points strictly in front of the camera have a valid projection;
        # a non-positive depth would mirror the point into the image or divide
        # by zero.
        depth = cam_pts[:, -1]
        cam_pts = cam_pts[np.isfinite(depth) & (depth > 0)]

        normalized = cam_pts / cam_pts[:, -1:]
        pixels = normalized @ K.T

        # Round to the nearest pixel with floor(x + 0.5). A plain integer cast
        # truncates toward zero and would pull slightly-negative coordinates
        # onto pixel 0.
        cols = np.floor(pixels[:, 0] + 0.5)
        rows = np.floor(pixels[:, 1] + 0.5)

        # Bounds are tested in float space so NaN/inf never reach the int cast.
        inside = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
        return rows[inside].astype(np.intp), cols[inside].astype(np.intp)

    def render(self, index: int) -> np.ndarray:
        """Overlay projected ``new_vertices`` and OpenPose keypoints on the image.

        The loaded image is modified in place and returned. Each visible
        vertex sets channels 0, 1, 2 of its pixel to ``(255, 0, 0)``.

        NOTE(review): ``plot_kps_2d`` is assumed to draw into ``img`` in place,
        since its return value is not used — confirm in ``plot_joints``.
        """
        sample = self.getitem(index)

        img = sample["img"]
        height, width = img.shape[:2]
        rows, cols = self._project_to_pixels(
            sample["new_vertices"], sample["w2c"], sample["K"], height, width
        )
        # Touch only the first three channels, leaving any extra channel as is.
        img[rows, cols, :3] = (255, 0, 0)

        plot_kps_2d(img, sample["openpose"], color=(0, 0, 255))
        return img
