import os.path as osp
from typing import List

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from ...files.obj_file import save_obj
from .base import get_mode_info
from ...models.smpl.smpl import SMPL


class PW3DRender(Dataset):
    """Debug helper that overlays posed SMPL mesh vertices on dataset frames.

    Records come from ``get_mode_info(dataset_path, mode)`` (presumably the
    3DPW split selected by ``mode`` -- confirm against ``.base``). ``render``
    reads the keys ``imgname``, ``w2c``, ``trans``, ``betas``, ``pose``,
    ``gender`` and ``K`` from each record.

    Args:
        dataset_path: Root directory; each record's ``imgname`` is joined onto it.
        mode: Split name forwarded to ``get_mode_info``.
        smpl_path: Directory handed to the project ``SMPL`` model. Defaults to
            the previously hard-coded location for backward compatibility.
    """

    # Integer codes passed to ``SMPL.forward`` for each gender letter found in
    # the records; what the codes select is defined by the project SMPL class.
    _GENDER_CODES = {"m": 1, "f": 2}

    def __init__(
        self, dataset_path: str, mode: str, smpl_path: str = "/home/lwk/smpl"
    ) -> None:
        super().__init__()
        self.dataset_path = dataset_path
        self.datalist = get_mode_info(dataset_path, mode)
        # Kept on the CPU: ``render`` calls ``.numpy()`` on the model output,
        # which does not work on CUDA tensors.
        self.smpl = SMPL(smpl_path)

    def __len__(self) -> int:
        return len(self.datalist)

    def __getitem__(self, index):
        # NOTE(review): unfinished stub -- it only performs the same
        # ``datalist`` lookup as before (so a bad index still raises) and
        # returns None. Use ``render(index)`` to obtain an overlay image.
        # TODO: implement a real sample.
        self.datalist[index]

    def render(self, index: int) -> np.ndarray:
        """Draw the posed SMPL vertices of record ``index`` onto its image.

        Args:
            index: Position in ``self.datalist``.

        Returns:
            The image returned by ``cv2.imread`` (BGR channel order), modified
            in place: every vertex that lies in front of the camera and
            projects inside the frame is painted (255, 0, 0), i.e. blue.

        Raises:
            FileNotFoundError: If the record's image cannot be read.
            ValueError: If the record's gender is neither "m" nor "f".
        """
        d = self.datalist[index]
        img_path = osp.join(self.dataset_path, d["imgname"])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {img_path}")

        betas = torch.from_numpy(d["betas"]).float().unsqueeze(0)
        pose = torch.from_numpy(d["pose"]).float().unsqueeze(0)
        gender = self._gender_tensor(d["gender"])

        # Inference only: no autograd graph is needed for visualisation.
        with torch.no_grad():
            vertices = self.smpl.forward(betas, pose, gender)[0].numpy()

        points_cam = self._world_to_camera(vertices, d["trans"], d["w2c"])
        pixels = self._project(points_cam, d["K"])
        return self._splat(img, pixels)

    @classmethod
    def _gender_tensor(cls, gender: str) -> torch.Tensor:
        """Map a gender letter ("m"/"f") to a 1-element int64 tensor code."""
        try:
            code = cls._GENDER_CODES[gender]
        except KeyError:
            # An explicit error (instead of ``assert False``) still fires under
            # ``python -O``, where asserts are stripped.
            raise ValueError(
                f"unsupported gender {gender!r}; expected 'm' or 'f'"
            ) from None
        return torch.LongTensor([code])

    @staticmethod
    def _world_to_camera(
        vertices: np.ndarray, trans: np.ndarray, w2c: np.ndarray
    ) -> np.ndarray:
        """Translate (N, 3) vertices by ``trans``, then apply ``w2c``'s rotation and translation.

        Only the top 3x4 part of ``w2c`` is used (rotation ``w2c[:3, :3]``,
        translation ``w2c[:3, 3]``). ``trans`` may be shaped (3,), (3, 1) or
        (1, 3).
        """
        offset = np.reshape(trans, (3, 1))
        return w2c[:3, :3].dot(vertices.T + offset).T + w2c[:3, 3][None, :]

    @staticmethod
    def _project(points_cam: np.ndarray, K: np.ndarray) -> np.ndarray:
        """Pinhole-project camera-space points to integer (u, v) pixel coordinates.

        Points with non-positive depth (column 2) lie behind or on the camera
        plane. They are dropped, because dividing by their depth would mirror
        them into the image or produce inf/nan.

        Returns:
            (M, 2) int32 array of (column, row) pixel coordinates, M <= N.
        """
        in_front = points_cam[:, 2] > 0
        pts = points_cam[in_front]
        normalized = pts / pts[:, 2:3]
        uv = K.dot(normalized.T).T[:, :2]
        # Round to the nearest pixel. floor() is used instead of int-casting
        # because the cast truncates toward zero, which would pull points in
        # (-1.5, -0.5] onto column/row 0.
        return np.floor(uv + 0.5).astype(np.int32)

    @staticmethod
    def _splat(
        img: np.ndarray, pixels: np.ndarray, color=(255, 0, 0)
    ) -> np.ndarray:
        """Paint every in-bounds (u, v) in ``pixels`` with ``color``, in place.

        ``img`` is expected to be (H, W, 3). Coordinates outside the frame are
        ignored. Returns ``img`` itself.
        """
        height, width = img.shape[:2]
        cols, rows = pixels[:, 0], pixels[:, 1]
        inside = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
        # Sets all three channels at once. The original per-pixel loop wrote
        # channel 1 twice and never cleared channel 2.
        img[rows[inside], cols[inside]] = color
        return img
