from typing import List, Optional

import numpy as np
from torch.utils.data import Dataset

from ..base_dataset import BaseDataset


class RandomSplitDataset(Dataset):
    """Expose a random subset of another dataset.

    A random permutation of the wrapped dataset's indices is drawn once, at
    construction time, and the first ``int(len(dataset) * random_split)``
    entries are kept. Index ``i`` of this wrapper refers to sample
    ``index_list[i]`` of the wrapped dataset, for both ``__getitem__`` and
    ``meta``.
    """

    def __init__(
        self,
        dataset: "BaseDataset",
        random_split: float = 1.0,
        seed: Optional[int] = None,
    ) -> None:
        """Draw the random subset of ``dataset`` to expose.

        Args:
            dataset: Dataset to sample from. This wrapper calls ``len()``,
                ``__getitem__``, ``info``, ``meta`` and ``keys`` on it.
            random_split: Fraction of samples to keep, in ``[0.0, 1.0]``.
                The kept count is ``int(len(dataset) * random_split)``.
            seed: Optional seed for the permutation. ``None`` (default) draws
                from NumPy's global random state. An int gives a reproducible
                split that is independent of the global state.

        Raises:
            ValueError: If ``random_split`` is outside ``[0.0, 1.0]`` or NaN.
                Without this check a negative value would silently slice from
                the end of the permutation and keep samples.
        """
        super().__init__()

        fraction = float(random_split)
        if not 0.0 <= fraction <= 1.0:  # also rejects NaN
            raise ValueError(
                f"random_split must be in [0.0, 1.0], got {random_split!r}"
            )

        self.dataset = dataset
        self.random_split = random_split

        total = len(self.dataset)
        keep = int(total * fraction)
        if seed is None:
            # Same call as before, so globally-seeded callers get the same split.
            permutation = np.random.permutation(total)
        else:
            permutation = np.random.default_rng(seed).permutation(total)
        self.__index_list = permutation[:keep].tolist()

    def info(self) -> str:
        """Return the wrapped dataset's ``info()`` unchanged."""
        return self.dataset.info()

    def meta(self, index: int) -> dict:
        """Return metadata for sample ``index`` of this split.

        The index is mapped through the split's index list, so ``meta(i)``
        refers to the same underlying sample as ``self[i]``.
        """
        return self.dataset.meta(self.__index_list[index])

    def keys(self) -> List[str]:
        """Return the wrapped dataset's ``keys()`` unchanged."""
        return self.dataset.keys()

    def __len__(self) -> int:
        """Number of samples kept in this split."""
        return len(self.__index_list)

    def __getitem__(self, index: int) -> dict:
        """Return sample ``index`` of this split from the wrapped dataset."""
        real_index = self.__index_list[index]
        return self.dataset.__getitem__(real_index)
