from typing import List

from torch.utils.data import Dataset
from ..base_dataset import BaseDataset
from ...tools.logger import Logger


class MixedDataset(Dataset):
    """Present several datasets as one dataset laid end to end.

    Global index ``0`` is the first sample of ``dataset_list[0]``; indices
    then continue through each following dataset in list order. Negative
    indices count back from the end of the combined dataset, and indices
    outside the combined range raise ``IndexError``.
    """

    def __init__(self, dataset_list: List["BaseDataset"]):
        """Store the datasets and record the length of each one.

        Args:
            dataset_list: Datasets to combine, in index order. Their lengths
                are read once here, so later size changes are not picked up.
        """
        self.dataset_list = dataset_list
        self.lens = [len(dataset) for dataset in dataset_list]
        total = sum(self.lens)
        Logger.info(f"[mixed dataset]: total={total}, lens: {self.lens}")

    def info(self) -> str:
        """Return ``"Mixed:"`` followed by the list of each dataset's ``info()``."""
        info_list = [x.info() for x in self.dataset_list]
        return "Mixed:" + str(info_list)

    def keys(self) -> List[str]:
        """Return the keys reported by the first dataset.

        Returns an empty list when there are no datasets.
        NOTE(review): presumably every dataset exposes the same keys -- confirm.
        """
        if not self.dataset_list:
            return []
        return self.dataset_list[0].keys()

    def meta(self, index) -> dict:
        """Return the ``meta`` entry of the sample at global ``index``.

        Raises:
            IndexError: If ``index`` is outside the combined dataset.
        """
        i, k = self._get_real_idx(index)
        return self.dataset_list[i].meta(k)

    def _get_real_idx(self, index):
        """Translate a global index into ``(dataset position, local index)``.

        Args:
            index: Global index; negative values count from the end.

        Returns:
            A tuple ``(i, k)`` such that the sample is ``dataset_list[i][k]``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Plain (non-augmented) arithmetic throughout: an in-place ``+=`` or
        # ``-=`` would mutate the caller's object if the index is a tensor or
        # array scalar.
        offset = index + total if index < 0 else index
        if not 0 <= offset < total:
            raise IndexError(
                f"index {index} is out of range for MixedDataset of length {total}"
            )
        for i, size in enumerate(self.lens):
            if offset < size:
                return i, offset
            offset = offset - size
        # Only reachable if ``self.lens`` was changed during the scan.
        raise IndexError(
            f"index {index} is out of range for MixedDataset of length {total}"
        )

    def __getitem__(self, index):
        """Return the sample at global ``index`` from the dataset that owns it.

        Raises:
            IndexError: If ``index`` is outside the combined dataset. This is
                also what ends plain ``for`` iteration over the dataset.
        """
        i, k = self._get_real_idx(index)
        return self.dataset_list[i][k]

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return sum(self.lens)
