#!/usr/bin/env python
import argparse
from netCDF4 import Dataset
from hagelslag.util.Config import Config
from hagelslag.processing.TrackModeler import TrackModeler
from multiprocessing import Pool, Manager
from hagelslag.processing.EnsembleProducts import *
from hagelslag.util.make_proj_grids import read_ncar_map_file
from scipy.ndimage import gaussian_filter
import pandas as pd
import numpy as np
import traceback
from datetime import timedelta
from glob import glob
import os


def main():
    """
    Parse the command line and run the requested processing stages.

    The positional ``config`` argument names the configuration file. The
    ``-t`` flag trains the models and writes the training size distributions,
    ``-f`` generates forecasts and writes them to csv, and ``-g`` generates the
    gridded forecasts. Any combination of the flags may be supplied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("config", help="Config filename.")
    cli.add_argument("-t", "--train", action="store_true",
                     help="Train machine learning models.")
    cli.add_argument("-f", "--fore", action="store_true",
                     help="Generate forecasts from machine learning models.")
    cli.add_argument("-g", "--grid", action="store_true",
                     help="Generate forecast grids for machine learning models.")
    options = cli.parse_args()

    # Attributes the config file must define (checked by Config).
    required_keys = [
        "ensemble_name",
        "train_data_path",
        "forecast_data_path",
        "member_files",
        "data_format",
        "condition_model_names",
        "condition_model_objs",
        "condition_input_columns",
        "condition_output_column",
        "group_col",
        "model_path",
        "metadata_columns",
        "data_json_path",
        "forecast_json_path",
        "load_models",
        "ensemble_members",
        "ml_grid_method",
        "neighbor_radius",
        "neighbor_sigma",
        "ensemble_consensus_path",
        "ensemble_variables",
        "ensemble_variable_thresholds",
        "ensemble_data_path",
        "size_dis_training_path",
        "netcdf_path",
        "watershed_variable",
        "model_map_file",
    ]
    settings = Config(options.config, required_keys)

    # Optional setting: fall back to the default run date format.
    if not hasattr(settings, "run_date_format"):
        settings.run_date_format = "%Y%m%d-%H%M"

    # Training and forecasting share a single TrackModeler instance.
    if options.train or options.fore:
        if not hasattr(settings, "weighting_function"):
            settings.weighting_function = None
        modeler = TrackModeler(settings.ensemble_name,
                               settings.train_data_path,
                               settings.forecast_data_path,
                               settings.member_files,
                               settings.start_dates,
                               settings.end_dates,
                               settings.weighting_function,
                               settings.model_map_file,
                               settings.group_col)
        if options.train:
            train_models(modeler, settings)
            training_data_percentiles(settings.ensemble_members,
                                      settings.ensemble_name,
                                      settings.watershed_variable,
                                      settings.train_data_path,
                                      settings.size_dis_training_path,
                                      settings.weighting_function)
        if options.fore:
            predictions = make_forecasts(modeler, settings)
            output_forecasts_csv(predictions, modeler, settings)

    if options.grid:
        generate_ml_grids(settings, mode="forecast")


def train_models(track_modeler, config):
    """
    Load the training data, fit the condition and size distribution models, and save them.

    Args:
        track_modeler (hagelslag.TrackModeler): an initialized TrackModeler object
        config: Config object supplying the model names, model objects, input/output
            columns, condition threshold and the path the fitted models are saved to
    """
    track_modeler.load_data(mode="train", format=config.data_format)

    # Condition models first.
    fit_condition = track_modeler.fit_condition_threshold_models
    condition_args = (config.condition_model_names,
                      config.condition_model_objs,
                      config.condition_input_columns,
                      config.condition_output_column,
                      config.condition_threshold)
    fit_condition(*condition_args)

    # Then the size distribution models.
    fit_size_dist = track_modeler.fit_size_distribution_models
    size_dist_args = (config.size_distribution_model_names,
                      config.size_distribution_model_objs,
                      config.size_distribution_input_columns)
    fit_size_dist(*size_dist_args, output_columns=config.size_distribution_output_columns)

    track_modeler.save_models(config.model_path)


def make_forecasts(track_modeler, config):
    """
    Run the condition and size distribution models on the forecast data.

    Args:
        track_modeler (hagelslag.processing.TrackModeler object): TrackModeler used to load
            the data (and, when config.load_models is set, the saved models) and to predict
        config (hagelslag.util.Config object): Configuration information
    Returns:
        dict with the condition model predictions under "condition" and the size
        distribution model predictions under "dist".
    """
    print("Load data")
    track_modeler.load_data(mode="forecast", format=config.data_format)

    if config.load_models:
        print("Load models")
        track_modeler.load_models(config.model_path)

    print("Condition forecasts")
    condition_forecasts = track_modeler.predict_condition_models(
        config.condition_model_names,
        config.condition_input_columns,
        config.metadata_columns)

    print("Size Distribution Forecasts")
    dist_forecasts = track_modeler.predict_size_distribution_models(
        config.size_distribution_model_names,
        config.size_distribution_input_columns,
        config.metadata_columns,
        location=config.size_distribution_loc)

    return {"condition": condition_forecasts, "dist": dist_forecasts}


def output_forecasts(forecasts, track_modeler, config):
    """
    Hand the forecasts to the TrackModeler's parallel JSON writer.

    Args:
        forecasts: dict
            dictionary containing forecast values organized by type
        track_modeler: hagelslag.processing.TrackModeler
            TrackModeler object whose output_forecasts_json_parallel method is called
        config:
            Config object providing the model name lists, the data and forecast
            JSON paths, and the number of processes
    Returns:
        None
    """
    write_json = track_modeler.output_forecasts_json_parallel
    model_name_groups = (config.condition_model_names,
                         config.size_model_names,
                         config.size_distribution_model_names,
                         config.track_model_names)
    json_paths = (config.data_json_path, config.forecast_json_path)
    write_json(forecasts, *model_name_groups, *json_paths, config.num_procs)


def output_forecasts_csv(forecasts, track_modeler, config):
    """
    Write the forecasts to csv through the TrackModeler.

    Args:
        forecasts: dictionary of forecast values
        track_modeler: TrackModeler object whose output_forecasts_csv method is called
        config: Config object providing forecast_csv_path and run_date_format
    """
    write_csv = track_modeler.output_forecasts_csv
    csv_dir = config.forecast_csv_path
    date_format = config.run_date_format
    write_csv(forecasts, "forecast", csv_dir, run_date_format=date_format)


def generate_ml_grids(config, mode="forecast"):
    """
    Creates gridded machine learning model forecasts and writes them to GRIB2 files.

    One task per (run date, ensemble member) pair is submitted to a process pool of
    config.num_procs workers; each task runs generate_ml_member_grid. Run dates are
    daily from config.start_dates[mode] to config.end_dates[mode].

    Args:
        config: hsforecast Config object with relevant info
        mode: key used to select the start and end dates from config.start_dates
            and config.end_dates (e.g. "train" or "forecast")

    Returns:
        None
    """
    run_dates = pd.date_range(start=config.start_dates[mode],
                              end=config.end_dates[mode],
                              freq='1D')
    ml_model_list = config.size_distribution_model_names
    print(ml_model_list)
    print()
    print('Size Distribution coming from: {0}'.format(config.size_dis_training_path))
    print()

    ml_var = "hail"
    submitted = []
    # The context manager terminates the workers if submitting tasks raises,
    # so a configuration error does not leave worker processes behind.
    with Pool(config.num_procs) as pool:
        for run_date in run_dates:
            start_date = run_date + timedelta(hours=config.start_hour)
            end_date = run_date + timedelta(hours=config.end_hour)
            for member in config.ensemble_members:
                args = (config.ensemble_name, ml_model_list, member, run_date, ml_var, start_date, end_date,
                        config.single_step, config.neighbor_condition_model, config.forecast_csv_path,
                        config.netcdf_path, config.grib_path, config.model_map_file,
                        config.size_dis_training_path, config.watershed_variable)
                submitted.append((run_date, member, pool.apply_async(generate_ml_member_grid, args)))
        pool.close()
        pool.join()

    # apply_async keeps worker exceptions inside the result object; report which
    # tasks failed instead of dropping that information. All results are ready
    # here because the pool has been joined.
    for run_date, member, result in submitted:
        if not result.successful():
            print("Gridding failed for {0} {1}".format(run_date.strftime("%Y%m%d"), member))
    return


def generate_ml_member_grid(ensemble_name, model_names, member, run_date, variable, start_date, end_date,
                            single_step, neighbor_condition_model, forecast_csv_path, netcdf_path,
                            grib_path, map_file, size_distribution_training_path, watershed_obj):
    """
    Convert the machine learning model object probabilities and size distributions to gridded fields.

    Nothing is gridded (a message is printed instead) when the member's forecast csv
    file for the run date does not exist. Otherwise every model in model_names is
    processed in turn and its GRIB2 output written to a per-run-date directory
    under grib_path, which is created if needed.

    Args:
        ensemble_name: Name of the ensemble
        model_names: Names of the machine learning models
        member: name the ensemble member
        run_date: initial date of the model run
        variable: Name of the machine learning output field
        start_date: first date of the forecast period
        end_date: last date of the forecast period
        single_step: whether the model output is in single netCDF files per timestep or aggregated
        neighbor_condition_model: Model for whether or not hail occurs
        forecast_csv_path: Path to forecast csv files
        netcdf_path: Path to netCDF patches
        grib_path: Path for saving grib2 files
        map_file: Path to map projection file or None
        size_distribution_training_path: Path to size distribution percentile files.
        watershed_obj: Name of the variable used for watershed object extraction.

    Returns:
        None

    Raises:
        Exception: any error raised while gridding is printed with its traceback
            and then re-raised.
    """
    try:
        forecast_csv_file = forecast_csv_path + "hail_forecasts_{0}_{1}_{2}.csv".format(
            ensemble_name, member, run_date.strftime("%Y%m%d-%H%M"))
        if not os.path.exists(forecast_csv_file):
            print("No model runs on " + run_date.strftime("%Y%m%d") + " " + member)
            return

        ep = EnsembleMemberProduct(ensemble_name, model_names[0], member, run_date, variable,
                                   start_date, end_date, None, single_step,
                                   size_distribution_training_path,
                                   watershed_obj, map_file=map_file,
                                   condition_model_name=neighbor_condition_model)
        # Use one path for both creating the directory and writing into it, so the
        # result does not depend on grib_path ending with a separator. The empty
        # last component keeps the trailing separator on the returned path.
        grib_out_path = os.path.join(grib_path, run_date.strftime("%Y%m%d"), "")
        for model_name in model_names:
            ep.model_name = model_name
            ep.load_forecast_csv_data(forecast_csv_path)
            ep.load_forecast_netcdf_data(netcdf_path)
            ep.quantile_match()
            grib_objects = ep.encode_grib2_data()
            if grib_objects is None:
                # NOTE(review): this stops processing the remaining models as well;
                # kept as in the original - confirm that is intended.
                return
            # exist_ok tolerates other worker processes creating the same directory.
            os.makedirs(grib_out_path, exist_ok=True)
            ep.write_grib2_files(grib_objects, grib_out_path)
    except Exception:
        # Print here because this function runs in a worker process.
        print(traceback.format_exc())
        raise
    return


def training_data_percentiles(member_list, ensemble_name, watershed_obj, csv_data_path,
                              csv_outpath, weighting_function=None,
                              percentiles=np.linspace(0.1, 99.9, 100)):
    """
    Creates percentile tables of matched watershed object values from the training csv files.

    For every member, the csv files matching '*step*<ensemble_name>*<member>*.csv'
    under csv_data_path are read, rows with Matched == 1 are kept, and the requested
    percentiles of the '<watershed_obj>_max' column are computed. A second table is
    computed from the values of all members combined. Members without csv files or
    without matched values are skipped with a message.

    Args:
        member_list (list of str): List of ensemble members
        ensemble_name (str): Name of the ensemble, used in the input and output file names
        watershed_obj (str): Watershed object used in config files
        csv_data_path (str): Path prefix of the input csv files
        csv_outpath (str): Path prefix for the output csv files; nothing is written if empty
        weighting_function (callable or None): If given, it is called with the matched
            rows and only rows whose returned weight is >= 1.0 are used
        percentiles (array-like): Percentiles (0-100) to evaluate
    Returns:
        None. The tables are written to
        '<csv_outpath><ensemble_name>_<watershed_obj>_<member>_Size_Distribution.csv'
        and '<csv_outpath><ensemble_name>_<watershed_obj>_Size_Distribution.csv'.
    """

    print('Creating Distribution of Watershed Object Sizes')
    obj_column = "{0}_max".format(watershed_obj)
    ensemble_obj_data = []

    #########################
    # Individual Member Distributions
    #########################
    for member in member_list:
        member_files = sorted(glob(csv_data_path + '*step*{0}*{1}*.csv'.format(ensemble_name, member)))
        if not member_files:
            # pd.concat raises on an empty list; skip members with no training files.
            print('No training csv files found for {0} {1}'.format(ensemble_name, member))
            continue
        dataset = pd.concat([pd.read_csv(csv_file) for csv_file in member_files],
                            ignore_index=True, sort=True)
        match_dataset = dataset.loc[dataset["Matched"] == 1]
        if weighting_function is not None:
            # The weights line up with match_dataset by position, so select with a
            # boolean mask rather than treating positions as index labels.
            training_weights = np.ravel(np.asarray(weighting_function(match_dataset)))
            match_dataset = match_dataset.loc[training_weights >= 1.0]
        obj_values = match_dataset[obj_column].dropna().values
        if obj_values.size == 0:
            print('No matched {0} objects for {1}'.format(watershed_obj, member))
            continue
        print('{0} {1} model objects: {2}'.format(watershed_obj, member, obj_values.shape[0]))
        print()
        obj_percent_val = np.percentile(obj_values, percentiles)
        obj_value_percentile = pd.DataFrame(data={'Obj_Values': obj_percent_val,
                                                  'Percentiles': percentiles})
        if csv_outpath:
            csv_path = csv_outpath + '{0}_{1}_{2}_Size_Distribution.csv'.format(ensemble_name,
                                                                                watershed_obj, member)
            obj_value_percentile.to_csv(csv_path, index=False)
            print('Output csv: {0}'.format(csv_path))
        ensemble_obj_data.append(obj_values)

    #########################
    # Ensemble Distribution
    #########################
    if ensemble_obj_data:
        ens_obj_values = np.concatenate(ensemble_obj_data)
        print()
        print('Total ensemble {0} model objects: {1}'.format(watershed_obj, ens_obj_values.shape[0]))
        ens_obj_percent_val = np.percentile(ens_obj_values, percentiles)
        ens_obj_value_percentile = pd.DataFrame(data={'Obj_Values': ens_obj_percent_val,
                                                      'Percentiles': percentiles})
        if csv_outpath:
            csv_path = csv_outpath + '{0}_{1}_Size_Distribution.csv'.format(ensemble_name,
                                                                            watershed_obj)
            ens_obj_value_percentile.to_csv(csv_path, index=False)
            print('Output csv: {0}'.format(csv_path))
    return

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
