# ZJU-Mocap dataset loaders.
# Setup instructions: https://github.com/zju3dv/neuralbody/blob/master/INSTALL.md#zju-mocap-dataset
# Download: https://zjueducn-my.sharepoint.com/:f:/g/personal/3180105504_zju_edu_cn/EvaarKZmSO5FsvG8uEfNZo4BPbHsCYq7q7QpUQci3dnRPA?e=4U2LFJ

import os
import os.path as osp
from typing import List
from ...files import load_npy, load_json
import numpy as np
import cv2


def get_view_list() -> List[str]:
    """Return the directory names of all ZJU-Mocap sequences, e.g. ``"CoreView_313"``."""
    subject_ids = ("313", "315", "377", "386", "387", "390", "392", "393", "394")
    return ["CoreView_" + subject_id for subject_id in subject_ids]


def get_cam_list(view: str) -> List[str]:
    """Return the camera sub-directory names available for sequence ``view``.

    Only ``CoreView_313``, ``CoreView_377`` and ``CoreView_386`` are mapped so far.
    The list order is significant: callers use a camera's position in it as its
    index into per-camera calibration arrays.

    Raises:
        ValueError: if ``view`` has no camera list yet. An explicit raise is used
            instead of ``assert`` so the check is not stripped under ``python -O``,
            which would make this function silently return ``None``.
    """
    if view == "CoreView_313":
        # Cameras 20 and 21 are skipped for this sequence.
        skipped = {20, 21}
        return [f"Camera ({i})" for i in range(1, 24) if i not in skipped]
    if view in ("CoreView_377", "CoreView_386"):
        return [f"Camera_B{i}" for i in range(1, 24)]
    # TODO: add camera lists for the remaining sequences of get_view_list().
    raise ValueError(f"Camera list for view {view!r} is not implemented")


# Calibration cache filled lazily by get_camera_params from annots.npy; None until first use.
_CAMERA_PARAMS = None
# Cache of directory paths and sorted filename lists, filled lazily by get_filename.
_FILENAME_DICT = {}


def get_camera_params(dataset_path: str, view: str, cam: str) -> dict:
    """Return the calibration of camera ``cam`` in sequence ``view``.

    The ``"cams"`` entry of ``<dataset_path>/<view>/annots.npy`` is loaded on
    first use and cached in the module-level ``_CAMERA_PARAMS``. The cache is
    keyed by ``(dataset_path, view)`` because every sequence has its own
    annots.npy. A single shared entry would hand one sequence's calibration to
    every other sequence.

    Args:
        dataset_path: root directory of the dataset.
        view: sequence directory name, e.g. ``"CoreView_313"``.
        cam: camera name. Its position in ``get_cam_list(view)`` selects the
            calibration entry.

    Returns:
        dict of float32 arrays:
            "K":   intrinsics, (3, 3)
            "D":   distortion coefficients, (5, 1)
            "w2c": world-to-camera extrinsics ``[R | T / 1000]``, (3, 4).
                   The /1000 presumably converts millimetres to metres (confirm).

    Raises:
        ValueError: if ``cam`` is not an element of ``get_cam_list(view)``.
    """
    global _CAMERA_PARAMS
    if _CAMERA_PARAMS is None:
        _CAMERA_PARAMS = {}
    cache_key = (dataset_path, view)
    if cache_key not in _CAMERA_PARAMS:
        annots = load_npy(osp.join(dataset_path, view, "annots.npy")).item()
        _CAMERA_PARAMS[cache_key] = annots["cams"]
    cams = _CAMERA_PARAMS[cache_key]

    idx = get_cam_list(view).index(cam)
    R = np.array(cams["R"][idx], dtype=np.float32)  # (3, 3)
    T = np.array(cams["T"][idx], dtype=np.float32) / 1000.0  # (3, 1)
    return {
        "K": np.array(cams["K"][idx], dtype=np.float32),  # (3, 3)
        "D": np.array(cams["D"][idx], dtype=np.float32),  # (5, 1)
        "w2c": np.concatenate([R, T], axis=1),  # (3, 4)
    }


def get_frames_num(dataset_path: str, view: str) -> int:
    """Return the number of directory entries in the first camera's image folder of ``view``.

    This count is used as the number of frames in the sequence.
    """
    first_cam_dir = osp.join(dataset_path, view, get_cam_list(view)[0])
    entries = os.listdir(first_cam_dir)
    return len(entries)


def get_filename(dataset_path: str, view: str, cam: str, frame_id: int, key: str) -> str:
    """Return the path of the ``frame_id``-th file of kind ``key``.

    Files are looked up under ``<dataset_path>/<view>``:

        "img"                               -> <cam>/
        "mask", "mask_cihp"                 -> <key>/<cam>/
        "openpose"                          -> keypoints2d/<cam>/
        "params", "new_params",
        "vertices", "new_vertices"          -> <key>/
        "bweights"                          -> lbs/bweights/

    Per-camera directories are sorted lexicographically. For the ``.npy``
    directories (the last two rows), only ``*.npy`` files are kept. They are
    sorted by their integer stem, so ``10.npy`` follows ``9.npy``, and the names
    are returned exactly as they exist on disk.

    Each listing is cached in ``_FILENAME_DICT``, keyed by
    ``(dataset_path, view, cam, key)``, so different dataset roots never share
    a listing.

    Returns:
        ``<directory>/<filename>``. The directory already includes ``dataset_path``.

    Raises:
        ValueError: if ``key`` is unknown or an ``.npy`` stem is not an integer.
        IndexError: if ``frame_id`` is out of range for the directory.
    """
    cache_key = (dataset_path, view, cam, key)
    if cache_key not in _FILENAME_DICT:
        view_dir = osp.join(dataset_path, view)
        numeric_order = False
        if key == "img":
            path = osp.join(view_dir, cam)
        elif key in ("mask", "mask_cihp"):
            path = osp.join(view_dir, key, cam)
        elif key == "openpose":
            path = osp.join(view_dir, "keypoints2d", cam)
        elif key in ("params", "new_params", "vertices", "new_vertices"):
            path = osp.join(view_dir, key)
            numeric_order = True
        elif key == "bweights":
            path = osp.join(view_dir, "lbs", "bweights")
            numeric_order = True
        else:
            raise ValueError(f"Unknown key: {key!r}")

        names = os.listdir(path)  # os.listdir order is arbitrary; sort below.
        if numeric_order:
            names = sorted(
                (name for name in names if name.endswith(".npy")),
                key=lambda name: int(name[: -len(".npy")]),
            )
        else:
            names.sort()
        _FILENAME_DICT[cache_key] = (path, names)

    path, names = _FILENAME_DICT[cache_key]
    return osp.join(path, names[frame_id])


def get_frame_data(dataset_path: str, view: str, cam: str, frame_id: int, key: str) -> dict:
    """Load the ``frame_id``-th item of kind ``key`` for ``view``/``cam``.

    Shapes and value ranges below are the ones noted for ZJU-Mocap by the
    original author. They are not checked here.

        "img":        {"img": image as returned by cv2.imread}       (1024, 1024, 3)
        "mask":       {"mask": float32, first channel, 0 or 1}       (1024, 1024, 1)
        "mask_cihp":  {"mask_cihp": float32, first channel, 0 - 19}  (1024, 1024, 1)
        "params", "new_params":
                      {"Rh": (3,), "Th": (3,), "pose": (72,), "betas": (10,)}
        "vertices", "new_vertices":
                      {key: array loaded from .npy}                  (6890, 3)
        "bweights":   {"bweights": array loaded from .npy}           (37, 73, 24, 25)
        "openpose":   {"openpose": float32 keypoints}                (25, 3)

    Raises:
        FileNotFoundError: if an image/mask file cannot be read by OpenCV.
        ValueError: if an OpenPose file does not contain exactly one person,
            or ``key`` is not handled.
        get_filename also raises for unknown keys and out-of-range ``frame_id``.
    """
    # get_filename already prefixes dataset_path. Joining it again would
    # duplicate a relative dataset_path ("a/b" -> "a/b/a/b/...").
    filename = get_filename(dataset_path, view, cam, frame_id, key)

    if key in ("img", "mask", "mask_cihp"):
        image = cv2.imread(filename)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {filename}")
        if key == "img":
            return {key: image}
        # Masks are stored as 3-channel images; keep the first channel only.
        return {key: image[:, :, 0:1].astype(np.float32)}

    if key in ("params", "new_params"):
        smpl: dict = load_npy(filename).item()
        # Each entry carries a leading batch axis of size 1; drop it.
        return {
            "Rh": smpl["Rh"][0],  # (3, )
            "Th": smpl["Th"][0],  # (3, )
            "pose": smpl["poses"][0],  # (72, )
            "betas": smpl["shapes"][0],  # (10, )
        }

    if key in ("vertices", "new_vertices", "bweights"):
        return {key: load_npy(filename)}

    if key == "openpose":
        people = load_json(filename)["people"]
        if len(people) != 1:
            raise ValueError(f"Expected exactly one person in {filename}, found {len(people)}")
        keypoints = np.array(people[0]["pose_keypoints_2d"], dtype=np.float32).reshape([25, 3])
        return {key: keypoints}

    raise ValueError(f"Unknown key: {key!r}")
