from yqn_config.base_config import BaseConfig
from yqn_pytorch_framework.train.base_dataset import BaseDataset


class ${class_name}Dataset(BaseDataset):

    def __init__(self, config: BaseConfig, data_flag):
        super(${class_name}Dataset, self).__init__(config)
        if data_flag == 'train':
            self.input_features, self.label_list = self.get_files(self.config.train_dir, '')
        else:
            self.input_features, self.label_list = self.get_files(self.config.val_dir, '')

    def get_item_size(self):
        return len(self.input_features)

    def get_item(self, index):
        # TODO fill ${module_name} Dataset get_item
        pass
