from pathlib import Path
import shutil
import pandas as pd
import torch
from torch.utils.data import Dataset
import pickle
import numpy as np
import torchvision.transforms.functional as F
from torchvision import transforms
import tarfile
import datetime
import pytz
from PIL import Image
from tqdm import tqdm
from wilds.common.utils import subsample_idxs
from wilds.common.metrics.all_metrics import Accuracy
from wilds.common.grouper import CombinatorialGrouper
from wilds.datasets.wilds_dataset import WILDSDataset

# Raise Pillow's pixel-count limit (its decompression-bomb guard threshold, per
# the Pillow docs) to 1e10 pixels.
# NOTE(review): this mutates a module attribute of PIL.Image, so it applies
# process-wide to every image opened with Pillow, not only to this dataset.
Image.MAX_IMAGE_PIXELS = 10000000000


# The 62 fMoW land-use / building class names, in alphabetical order.
# A class's integer label is its position in this list.
categories = [
    "airport", "airport_hangar", "airport_terminal", "amusement_park",
    "aquaculture", "archaeological_site",
    "barn", "border_checkpoint", "burial_site",
    "car_dealership", "construction_site", "crop_field",
    "dam", "debris_or_rubble",
    "educational_institution", "electric_substation",
    "factory_or_powerplant", "fire_station", "flooded_road", "fountain",
    "gas_station", "golf_course", "ground_transportation_station",
    "helipad", "hospital",
    "impoverished_settlement", "interchange",
    "lake_or_pond", "lighthouse",
    "military_facility", "multi-unit_residential",
    "nuclear_powerplant",
    "office_building", "oil_or_gas_facility",
    "park", "parking_lot_or_garage", "place_of_worship", "police_station",
    "port", "prison",
    "race_track", "railway_bridge", "recreational_facility", "road_bridge",
    "runway",
    "shipyard", "shopping_mall", "single-unit_residential", "smokestack",
    "solar_farm", "space_facility", "stadium", "storage_tank",
    "surface_mine", "swimming_pool",
    "toll_booth", "tower", "tunnel_opening",
    "waste_disposal", "water_treatment_facility", "wind_farm",
    "zoo",
]


class FMoWDataset(WILDSDataset):
    """
    The Functional Map of the World land use / building classification dataset.
    This is a processed version of the Functional Map of the World dataset originally sourced from https://github.com/fMoW/dataset.

    Supported `split_scheme` values:
        'official': official split, which is equivalent to 'time_after_2016'
        `time_after_{YEAR}`: images timestamped on or after Jan 1 of YEAR (UTC) form the
            OOD test period and the three preceding years form the OOD validation period
            (documented YEAR range: 2002--2018).

    Input (x):
        224 x 224 x 3 RGB satellite image.

    Label (y):
        y is one of 62 land use / building classes

    Metadata:
        Each image is annotated with a location coordinate, timestamp, and country code.
        This dataset derives `region` from the country code and `year` from the timestamp;
        the exposed metadata fields are ['region', 'year', 'y'].

    Splits:
        'train' / 'val' / 'test': rows of the corresponding original split that fall
            before both OOD periods.
        'ood_val': original 'val' rows inside the OOD validation period.
        'ood_test': original 'test' rows inside the OOD test period.
        With use_ood_val=True, 'val' refers to the OOD validation split and the
        in-distribution validation split is renamed 'id_val'.
        Rows of the original 'seq' (sequestered) split are removed from the dataset.

    Website: https://github.com/fMoW/dataset

    Original publication:
    @inproceedings{fmow2018,
      title={Functional Map of the World},
      author={Christie, Gordon and Fendley, Neil and Wilson, James and Mukherjee, Ryan},
      booktitle={CVPR},
      year={2018}
    }

    License:
        Distributed under FMoW Challenge Public License.

    """
    _dataset_name = 'fmow'
    _download_url = 'https://worksheets.codalab.org/rest/bundles/0xc59ea8261dfe4d2baa3820866e33d781/contents/blob/'
    _version = '1.0'

    def __init__(self, root_dir='data', download=False, split_scheme='official',
                 oracle_training_set=False, seed=111, use_ood_val=False):
        """
        Args:
            root_dir: root directory, passed to `initialize_data_dir` and the base class.
            download: download flag, passed to `initialize_data_dir` and the base class.
            split_scheme: 'official' or 'time_after_{YEAR}'.
            oracle_training_set: if True, the training split becomes len(train)//2
                subsampled in-distribution training rows plus len(train)//2 subsampled
                OOD-period rows that no evaluation split uses (via `subsample_idxs`).
            seed: base seed for the oracle subsampling.
            use_ood_val: if True, rename splits so that 'val' is the OOD validation
                split and 'id_val' is the in-distribution validation split.

        Raises:
            ValueError: if `split_scheme` is not 'official' or 'time_after_{YEAR}'.
        """
        self._data_dir = self.initialize_data_dir(root_dir, download)

        self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'ood_val': 3, 'ood_test': 4}
        self._split_names = {'train': 'Train', 'val': 'Val', 'test': 'Test', 'ood_val': 'OOD Val', 'ood_test': 'OOD Test'}
        if split_scheme == 'official':
            split_scheme = 'time_after_2016'
        self._split_scheme = split_scheme
        self.oracle_training_set = oracle_training_set

        self.root = Path(self._data_dir)
        self.seed = int(seed)
        self._original_resolution = (224, 224)

        self.category_to_idx = {cat: i for i, cat in enumerate(categories)}

        self.metadata = pd.read_csv(self.root / 'rgb_metadata.csv')
        country_codes_df = pd.read_csv(self.root / 'country_code_mapping.csv')
        countrycode_to_region = dict(zip(country_codes_df['alpha-3'], country_codes_df['region']))
        # Country codes missing from the mapping fall back to the 'Other' region.
        self.metadata['region'] = [
            countrycode_to_region.get(code, 'Other')
            for code in self.metadata['country_code'].to_list()
        ]

        # Images are stored across chunk files rgb_all_imgs_{k}.npy (see get_input).
        self.num_chunks = 101
        self.chunk_size = len(self.metadata) // (self.num_chunks - 1)

        # ---- OOD period masks -------------------------------------------------
        if not self._split_scheme.startswith('time_after'):
            raise ValueError(f"Not supported: self._split_scheme = {self._split_scheme}")
        try:
            ood_year = int(self._split_scheme.split('_')[2])
        except (IndexError, ValueError) as err:
            # e.g. 'time_after' with no year, or a non-integer year.
            raise ValueError(f"Not supported: self._split_scheme = {self._split_scheme}") from err

        # Parse timestamps once; reused for the OOD masks and the year column.
        timestamps = pd.to_datetime(self.metadata['timestamp'])
        year_dt = datetime.datetime(ood_year, 1, 1, tzinfo=pytz.UTC)
        self.test_ood_mask = np.asarray(timestamps >= year_dt)
        # Use the 3 years preceding the OOD test period as the OOD validation period.
        year_minus_3_dt = datetime.datetime(ood_year - 3, 1, 1, tzinfo=pytz.UTC)
        self.val_ood_mask = np.asarray(timestamps >= year_minus_3_dt) & ~self.test_ood_mask
        self.ood_mask = self.test_ood_mask | self.val_ood_mask

        # ---- Split assignment -------------------------------------------------
        split_col = self.metadata['split']
        val_mask = np.asarray(split_col == 'val')
        test_mask = np.asarray(split_col == 'test')
        seq_mask = np.asarray(split_col == 'seq')
        all_idxs = np.arange(len(self.metadata))

        # -1 marks rows that belong to no split.
        self._split_array = -1 * np.ones(len(self.metadata))
        for split in self._split_dict:
            if split == 'ood_test':
                split_mask = self.test_ood_mask & test_mask
            elif split == 'ood_val':
                split_mask = self.val_ood_mask & val_mask
            else:
                # In-distribution splits only keep rows from before both OOD periods.
                split_mask = ~self.ood_mask & np.asarray(split_col == split)
            idxs = all_idxs[split_mask]

            if self.oracle_training_set and split == 'train':
                # OOD-period rows that no other split claims. OOD-val rows and
                # sequestered rows are excluded explicitly: otherwise they would be
                # reassigned to 'ood_val' below or dropped with the 'seq' rows,
                # silently shrinking the oracle training set.
                ood_val_rows = self.val_ood_mask & val_mask
                unused_ood_mask = self.ood_mask & ~test_mask & ~ood_val_rows & ~seq_mask
                unused_ood_idxs = all_idxs[unused_ood_mask]
                subsample_unused_ood_idxs = subsample_idxs(unused_ood_idxs, num=len(idxs) // 2, seed=self.seed + 2)
                subsample_train_idxs = subsample_idxs(idxs.copy(), num=len(idxs) // 2, seed=self.seed + 3)
                idxs = np.concatenate([subsample_unused_ood_idxs, subsample_train_idxs])
            self._split_array[idxs] = self._split_dict[split]

        if use_ood_val:
            # Same split ids, renamed so that 'val' selects the OOD validation rows.
            self._split_dict = {'train': 0, 'id_val': 1, 'test': 2, 'val': 3, 'ood_test': 4}
            self._split_names = {'train': 'Train', 'id_val': 'Val', 'test': 'Test', 'val': 'OOD Val', 'ood_test': 'OOD Test'}

        # Remove sequestered rows. full_idxs maps a dataset index back to its row
        # in the full metadata table (and hence to its image chunk).
        self._split_array = self._split_array[~seq_mask]
        self.full_idxs = all_idxs[~seq_mask]

        # ---- Labels -----------------------------------------------------------
        y_all = np.asarray([self.category_to_idx[y] for y in self.metadata['category']])
        self.metadata['y'] = y_all
        self._y_array = torch.from_numpy(y_all).long()[~seq_mask]
        self._y_size = 1
        self._n_classes = len(categories)

        # ---- Region -> integer index (in order of first appearance) -----------
        all_regions = list(self.metadata['region'].unique())
        region_to_region_idx = {region: i for i, region in enumerate(all_regions)}
        self._metadata_map = {'region': all_regions}
        self.metadata['region'] = [region_to_region_idx[region] for region in self.metadata['region'].tolist()]

        # ---- Year column: index relative to 2002 ------------------------------
        # Rows timestamped outside [2002-01-01, 2018-01-01) UTC keep the sentinel -1.
        year_array = -1 * np.ones(len(self.metadata))
        for yr in range(2002, 2018):
            year_mask = np.asarray(timestamps >= datetime.datetime(yr, 1, 1, tzinfo=pytz.UTC)) \
                        & np.asarray(timestamps < datetime.datetime(yr + 1, 1, 1, tzinfo=pytz.UTC))
            year_array[year_mask] = yr - 2002
        self.metadata['year'] = year_array
        self._metadata_map['year'] = list(range(2002, 2018))

        self._metadata_fields = ['region', 'year', 'y']
        self._metadata_array = torch.from_numpy(self.metadata[self._metadata_fields].astype(int).to_numpy()).long()[~seq_mask]

        self._eval_groupers = [
            CombinatorialGrouper(dataset=self, groupby_fields=['region']),
            CombinatorialGrouper(dataset=self, groupby_fields=['year'])]

        self._metric = Accuracy()
        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        """
        Returns x for a given idx.

        `idx` indexes the dataset after sequestered rows were removed; it is mapped
        back to the full metadata row via `full_idxs`, which determines the chunk
        file rgb_all_imgs_{chunk}.npy and the position within that chunk. The chunk
        is opened with mmap_mode='r', so the returned value is a read-only
        memory-mapped view rather than a fully loaded array.
        """
        idx = self.full_idxs[idx]
        batch_idx = idx // self.chunk_size
        within_batch_idx = idx % self.chunk_size
        img_batch = np.load(self.root / f'rgb_all_imgs_{batch_idx}.npy', mmap_mode='r')
        return img_batch[within_batch_idx]

    def eval(self, y_pred, y_true, metadata):
        """
        Compute accuracy broken down by each evaluation grouper (region, year).

        Args:
            y_pred: predicted labels.
            y_true: ground-truth labels.
            metadata: metadata rows aligned with the predictions.

        Returns:
            (results, results_str): the merged result dicts from every grouper and
            their concatenated human-readable summaries.
        """
        all_results = {}
        all_results_str = ''
        for grouper in self._eval_groupers:
            results, results_str = self.standard_group_eval(
                self._metric,
                grouper,
                y_pred, y_true, metadata)
            all_results.update(results)
            all_results_str += results_str
        return all_results, all_results_str
