import glob
import os.path as osp
from typing import List, Sequence

import cv2
from torch.utils.data import Dataset

from ...tools.logger import Logger


class ImageFolderDataset(Dataset):
    """Dataset over the image files located directly inside one folder.

    The folder is scanned once at construction time (non-recursively). Regular
    files whose extension matches ``extensions`` (case-insensitively) are
    collected and sorted by path, so indexing is deterministic. Images are only
    decoded lazily, in ``__getitem__``.
    """

    def __init__(
        self,
        image_folder: str,
        extensions: Sequence[str] = (".jpg", ".jpeg", ".png"),
    ) -> None:
        """Index the images found in ``image_folder``.

        Args:
            image_folder: Directory to scan; sub-directories are not searched.
            extensions: File extensions to accept. Matching ignores case and an
                optional leading dot, so ``"JPG"`` and ``".jpg"`` are equivalent.
        """
        # Folder validation is delegated to the project Logger helper
        # (presumably reports/raises when the path is missing — TODO confirm).
        Logger.exists(image_folder)

        self.imgpaths: List[str] = self._collect_image_paths(image_folder, extensions)

    @staticmethod
    def _collect_image_paths(image_folder: str, extensions: Sequence[str]) -> List[str]:
        """Return sorted paths of regular files in ``image_folder`` with a wanted extension."""
        wanted = {"." + ext.lower().lstrip(".") for ext in extensions}
        # Escape the folder so names such as "shots [2020]" are taken literally
        # rather than as glob character classes.
        pattern = osp.join(glob.escape(image_folder), "*")
        # One "*" glob plus a lower-cased suffix test also catches upper-case
        # extensions (e.g. "IMG_0001.JPG"), which per-extension patterns miss on
        # case-sensitive filesystems. As before, "*" does not match dot-files.
        return sorted(
            path
            for path in glob.glob(pattern)
            if osp.splitext(path)[1].lower() in wanted and osp.isfile(path)
        )

    def info(self) -> str:
        """Return a short human-readable description of this dataset."""
        return "image folder"

    def keys(self) -> List[str]:
        """Return the keys present in every sample dict produced by ``__getitem__``."""
        return ["imgpath", "img"]

    def meta(self, index: int) -> dict:
        """Return the metadata of sample ``index`` without decoding the image."""
        return {"imgpath": self.imgpaths[index]}

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.imgpaths)

    def __getitem__(self, index: int) -> dict:
        """Load sample ``index``.

        Returns:
            ``{"imgpath": path, "img": image}``, where ``image`` is whatever
            ``cv2.imread`` returns with its default flags (per the OpenCV docs,
            a 3-channel array in BGR order).

        Raises:
            OSError: if OpenCV cannot read or decode the file. ``cv2.imread``
                signals that failure by returning ``None``, not by raising.
        """
        imgpath = self.imgpaths[index]
        img = cv2.imread(imgpath)
        if img is None:
            # Fail here, naming the offending file, instead of handing None to
            # downstream transforms / collate, which fail far less clearly.
            raise OSError(f"failed to read image: {imgpath}")
        return {
            "imgpath": imgpath,
            "img": img,
        }
