import os
import shutil
import torch
import numpy as np

class WILDSDataset:
    """
    Shared dataset class for all WILDS datasets.
    Each data point in the dataset is an (x, y, metadata) tuple, where:
    - x is the input features
    - y is the target
    - metadata is a vector of relevant information, e.g., domain.
      For convenience, metadata also contains y.
    """
    # Fallback mapping from split identifier to the integer used in split_array.
    # Returned by the split_dict property when a dataset defines no _split_dict.
    DEFAULT_SPLITS = {'train': 0, 'val': 1, 'test': 2}
    # Fallback mapping from split identifier to its pretty-printed name.
    # Returned by the split_names property when a dataset defines no _split_names.
    DEFAULT_SPLIT_NAMES = {'train': 'Train', 'val': 'Validation', 'test': 'Test'}

    def __init__(self, root_dir, download, split_scheme):
        """
        Finish construction once the dataset-specific attributes have been set.
        This reads self._metadata_array, so that attribute must already exist
        when this method runs.
        A 1-D metadata tensor is turned into a single-column 2-D tensor, and then
        check_init() validates the overall configuration.
        Args:
            - root_dir, download, split_scheme: Not used by this base implementation.
        """
        metadata = self._metadata_array
        if len(metadata.shape) == 1:
            # Normalize to (n_points, 1) so metadata always has one column per field
            self._metadata_array = metadata.unsqueeze(1)
        self.check_init()

    def __len__(self):
        """
        Number of data points, taken as the length of y_array.
        """
        targets = self.y_array
        return len(targets)

    def __getitem__(self, idx):
        """
        Return the idx-th data point as an (x, y, metadata) tuple.
        No transform is applied here: transforms live on the WILDSSubset,
        because different subsets (e.g., train vs test) might use different ones.
        """
        return self.get_input(idx), self.y_array[idx], self.metadata_array[idx]

    def get_input(self, idx):
        """
        Load the input features of a single data point.
        The base class provides no implementation; __getitem__ calls this method,
        so each concrete dataset has to override it.
        Args:
            - idx (int): Index of a data point
        Output:
            - x (Tensor): Input features of the idx-th data point
        Raises:
            - NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def eval(self, y_pred, y_true, metadata):
        """
        Evaluate predictions against the true targets.
        The base class provides no implementation; each concrete dataset has to override it.
        Args:
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
            - metadata (Tensor): Metadata
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        Raises:
            - NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def get_subset(self, split, frac=1.0, transform=None):
        """
        Build a WILDSSubset holding the data points of one split.
        Args:
            - split (str): Split identifier, e.g., 'train', 'val', 'test'.
                           Must be a key of self.split_dict.
            - frac (float): Fraction of the split to keep, sampled at random with
                            np.random. Only takes effect when smaller than 1.0.
                            Handy for fast development on a small dataset.
            - transform (function): Any data transformations to be applied to the input x.
        Output:
            - subset (WILDSSubset): A (potentially subsampled) subset of the WILDSDataset.
        Raises:
            - ValueError: split is not a key of self.split_dict.
        """
        if split not in self.split_dict:
            raise ValueError(f"Split {split} not found in dataset's split_dict.")

        split_id = self.split_dict[split]
        indices = np.where(self.split_array == split_id)[0]

        if frac < 1.0:
            n_keep = int(np.round(float(len(indices)) * frac))
            shuffled = np.random.permutation(indices)
            # Re-sort so the retained points keep their original relative order
            indices = np.sort(shuffled[:n_keep])

        return WILDSSubset(self, indices, transform)

    def check_init(self):
        """
        Sanity-check that the WILDSDataset has been fully and consistently configured.
        Raises:
            - AssertionError: A required attribute is missing, the split dictionaries are
                              inconsistent, or an array has the wrong type or shape.
            - ValueError: The data directory does not exist.
        """
        # Every dataset must define these private attributes
        for required in ('_dataset_name', '_data_dir',
                         '_split_scheme', '_split_array',
                         '_y_array', '_y_size',
                         '_metadata_fields', '_metadata_array'):
            assert hasattr(self, required), f'WILDSDataset is missing {required}.'

        # The data must already be on disk
        data_dir = self.data_dir
        if not os.path.exists(data_dir):
            raise ValueError(
                f'{data_dir} does not exist yet. Please generate the dataset first.')

        # Split identifiers and their pretty names must line up
        split_dict = self.split_dict
        assert split_dict.keys()==self.split_names.keys()
        assert 'train' in split_dict
        assert 'val' in split_dict

        # Targets and metadata have to be Tensors
        targets = self.y_array
        metadata = self.metadata_array
        assert isinstance(targets, torch.Tensor), 'y_array must be a torch.Tensor'
        assert isinstance(metadata, torch.Tensor), 'metadata_array must be a torch.Tensor'

        # One row per data point in every array
        assert len(targets) == len(metadata)
        assert len(self.split_array) == len(metadata)

        # Metadata is a 2-D table with one column per named field
        assert len(metadata.shape) == 2
        assert len(self.metadata_fields) == metadata.shape[1]
        # For convenience, y must be one of the metadata fields when y_size == 1
        if self.y_size == 1:
            assert 'y' in self.metadata_fields

    @property
    def dataset_name(self):
        """
        A string that identifies the dataset, e.g., 'amazon', 'camelyon17'.
        Read from self._dataset_name, one of the attributes check_init requires.
        initialize_data_dir uses it to build the '<dataset_name>_v<version>' folder name.
        """
        return self._dataset_name

    @property
    def version(self):
        """
        A string that identifies the dataset version, e.g., '1.0'.
        initialize_data_dir parses it as '<major>.<minor>' with integer parts.
        Read from self._version, which is not among the attributes verified by
        check_init, so accessing this raises AttributeError if it was never set.
        """
        return self._version

    @property
    def download_url(self):
        """
        URL of the dataset archive, or None when self._download_url is not defined.
        None means the dataset cannot be downloaded automatically
        (e.g., because it first requires accepting a usage agreement).
        """
        try:
            return self._download_url
        except AttributeError:
            return None

    @property
    def data_dir(self):
        """
        The full path to the folder in which the dataset is stored.
        Read from self._data_dir; check_init raises a ValueError if this path does not exist.
        """
        return self._data_dir

    @property
    def collate(self):
        """
        Torch function used to collate items into a batch.
        Returns None when self._collate is not defined -> uses default torch collate.
        """
        try:
            return self._collate
        except AttributeError:
            return None

    @property
    def split_scheme(self):
        """
        A string identifier of how the split is constructed,
        e.g., 'standard', 'in-dist', 'user', etc.
        Read from self._split_scheme, one of the attributes check_init requires.
        """
        return self._split_scheme

    @property
    def split_dict(self):
        """
        Dictionary mapping each split to the integer identifier used in split_array,
        e.g., {'train': 0, 'val': 1, 'test': 2}.
        Falls back to WILDSDataset.DEFAULT_SPLITS when self._split_dict is not defined.
        Keys should match up with split_names.
        """
        try:
            return self._split_dict
        except AttributeError:
            return WILDSDataset.DEFAULT_SPLITS

    @property
    def split_names(self):
        """
        Dictionary mapping each split to its pretty name,
        e.g., {'train': 'Train', 'val': 'Validation', 'test': 'Test'}.
        Falls back to WILDSDataset.DEFAULT_SPLIT_NAMES when self._split_names is not defined.
        Keys should match up with split_dict.
        """
        try:
            return self._split_names
        except AttributeError:
            return WILDSDataset.DEFAULT_SPLIT_NAMES

    @property
    def split_array(self):
        """
        An array of integers, with split_array[i] representing what split the i-th data point
        belongs to. The integers are the values of split_dict.
        Read from self._split_array; check_init requires it to have one entry per
        row of metadata_array.
        """
        return self._split_array

    @property
    def y_array(self):
        """
        A Tensor of targets (e.g., labels for classification tasks),
        with y_array[i] representing the target of the i-th data point.
        y_array[i] can contain multiple elements.
        Read from self._y_array; check_init requires it to be a torch.Tensor with
        one entry per row of metadata_array.
        """
        return self._y_array

    @property
    def y_size(self):
        """
        The number of dimensions/elements in the target, i.e., len(y_array[i]).
        For standard classification/regression tasks, y_size = 1.
        For multi-task or structured prediction settings, y_size > 1.
        Used for logging and to configure models to produce appropriately-sized output.
        When y_size == 1, check_init additionally requires 'y' to be in metadata_fields.
        """
        return self._y_size

    @property
    def n_classes(self):
        """
        Number of classes for single-task classification datasets.
        Used for logging and to configure models to produce appropriately-sized output.
        Returns None when self._n_classes is not defined; leave it undefined if not
        applicable (e.g., regression or multi-task classification).
        """
        try:
            return self._n_classes
        except AttributeError:
            return None

    @property
    def is_classification(self):
        """
        Boolean. True exactly when n_classes is not None, i.e., the task is
        treated as classification. Used for logging purposes.
        """
        n_classes = self.n_classes
        return n_classes is not None

    @property
    def metadata_fields(self):
        """
        A list of strings naming each column of the metadata table, e.g., ['hospital', 'y'].
        check_init requires one name per column of metadata_array, and requires
        'y' to be included when y_size == 1.
        """
        return self._metadata_fields

    @property
    def metadata_array(self):
        """
        A Tensor of metadata, with the i-th row representing the metadata associated with
        the i-th data point. The columns correspond to the metadata_fields.
        check_init requires it to be a 2-D torch.Tensor.
        """
        return self._metadata_array

    @property
    def metadata_map(self):
        """
        Optional dictionary giving, for a metadata field, a list that translates the
        integers stored in metadata_array into human-readable strings.
        Returns None when self._metadata_map is not defined.
        It is only used for logging, so that printed metadata values are more intelligible.
        Each key must be in metadata_fields.
        For example, with
            metadata_fields = ['hospital', 'y']
            metadata_map = {'hospital': ['East', 'West']}
        metadata_array[i, 0] == 0 means the i-th data point belongs to the 'East' hospital,
        and metadata_array[i, 0] == 1 means it belongs to the 'West' hospital.
        """
        try:
            return self._metadata_map
        except AttributeError:
            return None

    @property
    def original_resolution(self):
        """
        Original image resolution for image datasets.
        Returns None when self._original_resolution is not defined.
        """
        try:
            return self._original_resolution
        except AttributeError:
            return None

    def initialize_data_dir(self, root_dir, download):
        """
        Helper function for downloading/updating the dataset if required.
        Note that we only do a version check for datasets where the download_url is set.
        Currently, this includes all datasets except Yelp.
        Datasets for which we don't control the download, like Yelp,
        might not handle versions similarly.

        Args:
            - root_dir (str): Folder holding one '<dataset_name>_v<version>' sub-folder per
                              dataset version. It is created if it does not exist.
            - download (bool): Whether the dataset may be downloaded/updated automatically.
        Output:
            - data_dir (str): Path of the dataset folder to use: the folder of the current
                              version, or the folder of the latest version found on disk
                              when no download takes place.
        Raises:
            - FileNotFoundError: No copy of the dataset was found and download is False.
            - ValueError: A download is needed but download_url is None, or self.version
                          is not of the form '<major>.<minor>' with integer parts.
        """
        os.makedirs(root_dir, exist_ok=True)

        prefix = f'{self.dataset_name}_v'
        data_dir = os.path.join(root_dir, f'{prefix}{self.version}')
        version_file = os.path.join(data_dir, f'RELEASE_v{self.version}.txt')
        current_major_version, current_minor_version = map(int, self.version.split('.'))

        # If the data_dir contains the right RELEASE file,
        # we assume the dataset is correctly set up
        if os.path.exists(version_file):
            return data_dir

        # If the data_dir exists and is not empty, and the download_url is NOT set
        # (i.e., the dataset has to be obtained manually, so there is no RELEASE file
        # to check), we also assume the dataset is correctly set up
        if (os.path.exists(data_dir) and
                os.listdir(data_dir) and
                self.download_url is None):
            return data_dir

        # Otherwise, look for the latest other version of the dataset in root_dir.
        # Only folders named exactly '<dataset_name>_v<major>.<minor>' that contain the
        # matching RELEASE file count; anything else is skipped.
        latest_version = None
        latest_existing_data_dir = None
        for folder in os.listdir(root_dir):
            folder_path = os.path.join(root_dir, folder)
            if not (folder.startswith(prefix) and os.path.isdir(folder_path)):
                continue
            version = folder[len(prefix):]
            if not os.path.exists(os.path.join(folder_path, f'RELEASE_v{version}.txt')):
                continue
            try:
                major_version, minor_version = map(int, version.split('.'))
            except ValueError:
                # Not a '<major>.<minor>' version string
                continue
            if latest_version is None or latest_version < (major_version, minor_version):
                latest_version = (major_version, minor_version)
                latest_existing_data_dir = folder_path

        do_download = False

        # No existing dataset
        if latest_version is None:
            if not download:
                if self.download_url is None:
                    raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. {self.dataset_name} cannot be automatically downloaded. Please download it manually.')
                raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. Run with --download to download the dataset. This might take some time for large datasets.')
            do_download = True

        # Older major version:
        # Prompt for update, ignore --download
        elif latest_version[0] < current_major_version:
            print(
                '***********\n'
                f'{self.dataset_name} has been updated to a new major version.\n'
                f'We recommend updating the dataset.\n')
            confirm = input('Will you update the dataset now? This might take some time for large datasets. (y/n)\n').lower()
            if confirm == 'y':
                do_download = True

        # Same major version, older minor version:
        # Notify user but do not prompt unless --download is set
        elif (latest_version[0] == current_major_version and
              latest_version[1] < current_minor_version):
            print(
                '***********\n'
                f'{self.dataset_name} has been updated to a new minor version.\n')
            if not download:
                print(
                    'Run with --download to update the dataset. This might take some time for large datasets.\n'
                    '***********\n')
            else:
                do_download = True

        # Keep using the latest copy already on disk
        if not do_download:
            return latest_existing_data_dir

        # Download if necessary
        if self.download_url is None:
            raise ValueError(f'Sorry, {self.dataset_name} cannot be automatically downloaded. Please download it manually.')

        from wilds.datasets.download_utils import download_and_extract_archive
        print(f'Downloading dataset to {data_dir}...')
        try:
            download_and_extract_archive(
                url=self.download_url,
                download_root=data_dir,
                filename='archive.tar.gz',
                remove_finished=True)
        except Exception as e:
            # Best effort: report the failure (including its cause) and still return
            # data_dir. KeyboardInterrupt/SystemExit are deliberately not caught.
            print(f"\n{os.path.join(data_dir, 'archive.tar.gz')} appears to be corrupted. Please try deleting it and rerunning this command.\n")
            print(f'Exception: {e}')

        return data_dir

    @staticmethod
    def standard_group_eval(metric, grouper, y_pred, y_true, metadata):
        """
        Args:
            - metric (Metric): Metric to use for eval
            - grouper (Grouper): Grouper object that converts metadata into groups
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
            - metadata (Tensor): Metadata
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        """
        g = grouper.metadata_to_group(metadata)
        results = {
            **metric.compute(y_pred, y_true),
            **metric.compute_group_wise(y_pred, y_true, g, grouper.n_groups)
        }
        results_str = (
            f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"
        )
        for group_idx in range(grouper.n_groups):
            if results[metric.group_count_field(group_idx)] == 0:
                continue
            results_str += (
                f'  {grouper.group_str(group_idx)}  '
                f"[n = {results[metric.group_count_field(group_idx)]:6.0f}]:\t"
                f"{metric.name} = {results[metric.group_metric_field(group_idx)]:5.3f}\n")
        results_str += f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n"
        return results, results_str


class WILDSSubset(WILDSDataset):
    """
    A view onto selected data points of a WILDSDataset (or of another WILDSSubset),
    with its own optional input transform.
    """

    def __init__(self, dataset, indices, transform):
        """
        This acts like torch.utils.data.Subset, but on WILDSDatasets.
        We pass in transform explicitly because it can potentially vary at
        training vs. test time, if we're using data augmentation.
        Args:
            - dataset (WILDSDataset): The dataset (or subset) to draw data points from.
            - indices: Positions, within dataset, of the data points in this subset.
            - transform (function): Transformation applied to each input x, or None.
        """
        # WILDSDataset.__init__ is intentionally not called: it expects a
        # _metadata_array attribute, whereas a subset exposes its arrays
        # through the properties below.
        self.dataset = dataset
        self.indices = indices
        # Copy over the configuration attributes the wrapped dataset defines, so
        # that the inherited properties (dataset_name, split_dict, ...) keep working.
        inherited_attrs = ['_dataset_name', '_data_dir', '_collate',
                           '_split_scheme', '_split_dict', '_split_names',
                           '_y_size', '_n_classes',
                           '_metadata_fields', '_metadata_map']
        for attr_name in inherited_attrs:
            if hasattr(dataset, attr_name):
                setattr(self, attr_name, getattr(dataset, attr_name))
        self.transform = transform

    def __getitem__(self, idx):
        """
        Return the idx-th data point of the subset as (x, y, metadata),
        with this subset's transform (if any) applied to x.
        """
        x, y, metadata = self.dataset[self.indices[idx]]
        if self.transform is not None:
            x = self.transform(x)
        return x, y, metadata

    def __len__(self):
        """
        Number of data points in the subset.
        """
        return len(self.indices)

    # The array properties go through the wrapped dataset's public properties
    # (not its private attributes), so that wrapping another WILDSSubset works too.

    @property
    def split_array(self):
        """
        Split identifiers of the data points in this subset.
        """
        return self.dataset.split_array[self.indices]

    @property
    def y_array(self):
        """
        Targets of the data points in this subset.
        """
        return self.dataset.y_array[self.indices]

    @property
    def metadata_array(self):
        """
        Metadata rows of the data points in this subset.
        """
        return self.dataset.metadata_array[self.indices]

    def eval(self, y_pred, y_true, metadata):
        """
        Delegate evaluation to the wrapped dataset's eval.
        """
        return self.dataset.eval(y_pred, y_true, metadata)