# 3DPW dataset
# https://virtualhumans.mpi-inf.mpg.de/3DPW/

# @inproceedings{vonMarcard2018,
#     title = {Recovering Accurate 3D Human Pose in The Wild Using IMUs and a Moving Camera},
#     author = {von Marcard, Timo and Henschel, Roberto and Black, Michael and Rosenhahn, Bodo and Pons-Moll, Gerard},
#     booktitle = {European Conference on Computer Vision (ECCV)},
#     year = {2018},
#     month = {sep}
#     }

import os
import os.path as osp
from typing import List

from ...files.base import load_pkl


def get_mode_list() -> List[str]:
    """Return the names of the 3DPW data splits.

    A new list is built on every call, so callers may mutate it freely.

    Returns:
        List[str]: ``["train", "test", "validation"]``.
    """
    modes = "train test validation".split()
    return modes


def get_sequence_list(dataset_path: str, mode: str) -> List[str]:
    """List the sequence names available in one 3DPW split.

    Scans ``<dataset_path>/sequenceFiles/<mode>`` and returns the base names
    of the ``.pkl`` files found there, sorted alphabetically.

    Only entries ending in ``.pkl`` are kept. Stray files (e.g. ``.DS_Store``
    or a README) would otherwise be reported as sequences and then fail when
    ``get_sequence_info`` tries to open ``<name>.pkl``.

    Args:
        dataset_path (str): Root directory of the 3DPW dataset.
        mode (str): Split name, e.g. "train", "test" or "validation".

    Returns:
        List[str]: Sorted sequence names, without the ``.pkl`` extension.
    """
    split_dir = osp.join(dataset_path, "sequenceFiles", mode)
    # splitext strips only the final extension, so a name like "a.pkl.pkl"
    # maps back to "a.pkl" (which re-appending ".pkl" round-trips correctly).
    sequence_list = [
        osp.splitext(name)[0]
        for name in os.listdir(split_dir)
        if name.endswith(".pkl")
    ]
    sequence_list.sort()
    return sequence_list


def get_sequence_info(dataset_path: str, mode: str, sequence: str, pid: int = None) -> List[dict]:
    """Load one 3DPW sequence and flatten it into per-person, per-frame records.

    Reads ``<dataset_path>/sequenceFiles/<mode>/<sequence>.pkl`` and emits one
    dict per (person, frame) pair whose ``campose_valid`` flag equals 1.
    Records are ordered by person index first, then by frame index.

    Args:
        dataset_path (str): Root directory of the 3DPW dataset.
        mode (str): Split name: "train", "test" or "validation".
        sequence (str): Sequence name, i.e. the pickle file name without
            the ``.pkl`` extension.
        pid (int, optional): Index of the single person to extract.
            Defaults to None, which extracts every person in the sequence.

    Returns:
        List[dict]: One record per valid (person, frame) with keys
        "imgname" (image path relative to the dataset root), "sequence",
        "K", "w2c", "gender", "betas" and "betas_clothed" (first 10
        coefficients), "pose", "trans", "jointPositions" (reshaped to
        (24, 3)) and "poses2d" (transposed and reshaped to (18, 3)).

    Raises:
        AssertionError: If the pickle's "sequence" field differs from
            ``sequence``, or if ``pid`` is outside ``[0, people_num)``.
            Both checks use ``assert`` and are skipped under ``python -O``.
    """
    d = load_pkl(osp.join(dataset_path, "sequenceFiles", mode, sequence + ".pkl"))

    # keys = ['betas', 'betas_clothed', 'cam_intrinsics', 'cam_poses', 'campose_valid', 'genders', 'img_frame_ids', 'jointPositions', 'poses', 'poses2d', 'poses_60Hz', 'sequence', 'texture_maps', 'trans', 'trans_60Hz', 'v_template_clothed']

    people_num = len(d["genders"])
    # The frame count is taken from the per-frame camera poses.
    frames_num = d["cam_poses"].shape[0]
    assert d["sequence"] == sequence

    if pid is None:
        pid_list = range(people_num)
    else:
        assert 0 <= pid < people_num
        pid_list = [pid]

    datalist = []
    # `person` (not `pid`) so the loop does not shadow the function argument.
    for person in pid_list:
        for fid in range(frames_num):
            # Skip frames whose camera pose is flagged invalid for this person.
            if d["campose_valid"][person][fid] != 1:
                continue

            datalist.append(
                {
                    # image path relative to the dataset root
                    "imgname": f"imageFiles/{sequence}/image_{fid:05d}.jpg",
                    "sequence": sequence,
                    # camera
                    "K": d["cam_intrinsics"],  # (3, 3)
                    "w2c": d["cam_poses"][fid],  # (4, 4)
                    # smpl
                    "gender": d["genders"][person],  # 'm' or 'f'
                    "betas": d["betas"][person][:10],
                    "betas_clothed": d["betas_clothed"][person][:10],
                    "pose": d["poses"][person][fid],
                    "trans": d["trans"][person][fid],
                    # joints
                    "jointPositions": d["jointPositions"][person][fid].reshape([24, 3]),
                    "poses2d": d["poses2d"][person][fid].T.reshape([18, 3]),
                }
            )
    return datalist


def get_mode_info(dataset_path: str, mode: str) -> List[dict]:
    """Collect the per-frame records of every sequence in one 3DPW split.

    Sequences are visited in the order returned by ``get_sequence_list``
    (alphabetical), and all people in each sequence are included.

    Args:
        dataset_path (str): Root directory of the 3DPW dataset.
        mode (str): Split name, e.g. "train", "test" or "validation".

    Returns:
        List[dict]: Concatenation of ``get_sequence_info`` for each sequence.
    """
    return [
        record
        for seq_name in get_sequence_list(dataset_path, mode)
        for record in get_sequence_info(dataset_path, mode, seq_name)
    ]
