#!/usr/bin/env python
from hagelslag.util.make_proj_grids import * 
from hagelslag.util.Config import Config
from scipy.ndimage import gaussian_filter
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from copy import deepcopy
from glob import glob
import pandas as pd
import numpy as np
import matplotlib
import argparse
import random
import pickle
import os

# ncepgrib2 is an optional dependency that is only needed for grib2 output.
try:
    from ncepgrib2 import Grib2Encode, dump
    grib_support = True
except ImportError:
    # The except clause must name an exception class. Naming an instance
    # (e.g. ImportError("...")) raises TypeError as soon as the import fails,
    # so the fallback flag would never be set.
    grib_support = False

def main():
    """
    Command line entry point for training and applying NMEP calibration models.

    Reads the configuration file named on the command line. With -t the
    calibration models are trained and saved. With -f the saved models are
    loaded and calibrated forecasts are created for the configured forecast
    period (and, with -y, for 4 hour windows starting at hours 12 to 33).
    Forecasts are written as plots (-p), grib2 files (-g) and/or
    netCDF files (-n).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="Configuration file")
    parser.add_argument("-t", "--train", action="store_true", help="Train calibration models.")
    parser.add_argument("-f", "--forecast", action="store_true", help="Generate forecasts from calibration models.")
    parser.add_argument("-y", "--hourly", action="store_true", help="If False will not output hourly forecasts.")
    parser.add_argument("-g", "--grib_out", action="store_true", help="Generate grib2 files.")
    parser.add_argument("-n", "--netcdf_out", action="store_true", help="Generate netcdf files.")
    parser.add_argument("-p", "--plot_out", action="store_true", help="Plot calibrated forecasts.")

    args = parser.parse_args()

    # Fail before any expensive work when grib2 output cannot be produced.
    if args.forecast and args.grib_out and not grib_support:
        parser.error("--grib_out requires the ncepgrib2 package, which could not be imported.")

    required = ["calibration_model_names", "calibration_model_objs", "ensemble_name",
                "forecast_model_names", "train_data_path", "forecast_data_path",
                "target_data_path", "target_data_names", "model_path",
                "run_date_format", "size_threshold", "forecast_out_path", "map_file"]
    config = Config(args.config, required)

    stride = 1
    smoothing = 14

    # Defaults for settings the rest of the script reads unconditionally.
    if not hasattr(config, "run_date_format"):
        config.run_date_format = "%y%m%d"
    if not hasattr(config, "sector"):
        config.sector = None

    # The map file extension selects which reader is used.
    if config.map_file.endswith("map"):
        proj_dict, grid_dict = read_arps_map_file(config.map_file)
    else:
        proj_dict, grid_dict = read_ncar_map_file(config.map_file)

    if args.train:
        trained_models = train_calibration(config)
        saving_cali_models(trained_models, config)

    if args.forecast:
        run_datetime = pd.date_range(start=config.start_dates["forecast"],
                                     end=config.end_dates["forecast"],
                                     freq='1D')

        target_dataset_models = load_cali_models(config)

        # Keep only the run dates that have a forecast directory on disk.
        run_date = []
        for date in run_datetime:
            forecast_dir = config.forecast_data_path + \
                "20{0}/netcdf/".format(date.strftime(config.run_date_format))
            if os.path.exists(forecast_dir):
                run_date.append(date)

        run_date_str = [r.strftime(config.run_date_format) for r in run_date]
        print()
        print('Forecast run date(s): {0}'.format(run_date_str))
        print('Creating calibrated forecasts')
        print()
        print('Full 24 hour forecast')

        _forecast_and_output(args, target_dataset_models,
                             config.start_hour, config.end_hour,
                             run_date, run_date_str, proj_dict, grid_dict,
                             stride, smoothing, config)

        print()
        if args.hourly:
            print('Hourly forecasts')
            # Overlapping 4 hour windows starting at forecast hours 12..33.
            for hour in range(12, 34):
                _forecast_and_output(args, target_dataset_models,
                                     hour, hour + 4,
                                     run_date, run_date_str, proj_dict, grid_dict,
                                     stride, smoothing, config)

    return


def _forecast_and_output(args, calibration_models, start_time, end_time,
                         run_date, run_date_str, proj_dict, grid_dict,
                         stride, smoothing, config):
    """Create calibrated forecasts for one period and write every requested output."""
    target_NMEP = create_calibration_forecasts(calibration_models,
                                               start_time, end_time,
                                               run_date_str, config)

    for size in config.size_threshold:
        for cali_model_name in config.calibration_model_names:
            forecast = target_NMEP[size][cali_model_name]

            if args.plot_out:
                forecast_plotting(forecast, proj_dict, grid_dict,
                                  start_time, end_time, stride, size,
                                  smoothing, run_date_str, config)

            if args.grib_out:
                # output_grib2 formats the dates itself, so it receives the
                # datetime objects rather than the pre-formatted strings.
                output_grib2(forecast, proj_dict, grid_dict,
                             start_time, end_time, stride, size, run_date,
                             config.target_data_names, smoothing, config)

            if args.netcdf_out:
                output_netcdf(forecast, proj_dict, grid_dict,
                              start_time, end_time, stride, size, run_date_str,
                              config.target_data_names, smoothing, config)


def train_calibration(config):
    """
    Train calibration models for NMEP forecasts using random cross-validation.

    Loads Neighborhood Maximum Ensemble Probability (NMEP) forecasts and the
    matching target fields for every training run date, then for each size
    threshold and calibration model fits 10 copies of the model on random
    70% subsets of the dates and keeps the copy with the lowest Brier score
    on the held-out dates.

    Args:
        config: Configuration object. Reads start_dates, end_dates,
            run_date_format, size_threshold, train_data_path,
            target_data_path, forecast_model_names, start_hour, end_hour,
            sector, calibration_model_names and calibration_model_objs.

    Returns:
        dict: {size: {model_name: fitted model}}. A size threshold without
            enough matched files to form a training split maps to an
            empty dict.
    """
    run_dates = pd.date_range(start=config.start_dates["train"],
                              end=config.end_dates["train"],
                              freq='1D').strftime(config.run_date_format)

    target_calib_models = {}
    print()
    print('Loading Data')

    for size in config.size_threshold:
        target_calib_models[size] = {}
        train_files, target_files = [], []
        for date in run_dates:
            train_data_files = glob(config.train_data_path +
                    "20{2}/netcdf/*{0}*unsmoothed*_{1}_*{2}*{3}*{4}.nc".format(
                        config.forecast_model_names, size, date,
                        config.start_hour, config.end_hour))
            if not train_data_files:
                continue
            if config.sector:
                target_data_files = glob(config.target_data_path + '{0}*{1}*{2}*.nc'.format(
                            date, size, config.sector))
            else:
                target_data_files = glob(config.target_data_path + '{0}*{1}*.nc'.format(
                            date, size))
            if not target_data_files:
                continue
            # Only dates with both a forecast and a target file are used.
            train_files.append(train_data_files[0])
            target_files.append(target_data_files[0])

        num_files = len(train_files)
        num_train = int(num_files * 0.70)
        if num_train < 1:
            # Fitting on an empty split would fail with an obscure error.
            print('Not enough matched forecast/target files to train '
                  '>{0}mm models, skipping'.format(size))
            continue

        date_indices = list(range(num_files))
        # Stack once; every cross-validation split indexes these arrays.
        t_data = np.array([_read_netcdf_variable(x, "Data")
                           for x in train_files])
        tar_data = np.array([_read_netcdf_variable(x, "24_Hour_All_12z_12z")
                             for x in target_files])
        print()
        print('Number of files:')
        print('Train (70%): {0}'.format(num_train))
        print('Validate (30%): {0}'.format(num_files - num_train))
        print()
        for model_index, model_name in enumerate(config.calibration_model_names):
            scores = []
            random_models = []
            print('Random Cross-Validation, {0} >{1}mm'.format(model_name, size))
            random_seed = random.sample(range(1, 100), 10)
            for s, seed in enumerate(random_seed):
                np.random.seed(seed)
                print('Index', s, 'Random Seed', seed)
                train_indices = np.random.choice(date_indices, num_train, replace=False)
                chosen = set(train_indices.tolist())
                test_indices = [index for index in date_indices if index not in chosen]

                train_data = t_data[train_indices].ravel()
                target_train_data = tar_data[train_indices].ravel()

                val_data = t_data[test_indices].ravel()
                target_val_data = tar_data[test_indices].ravel()

                # Fit a fresh copy so the configured model object stays untouched.
                model = deepcopy(config.calibration_model_objs[model_index])
                model.fit(train_data, target_train_data)
                random_models.append(model)

                score = brier_score(model.transform(val_data), target_val_data)
                print(score)
                scores.append(score)

            best = int(np.argmin(scores))
            target_calib_models[size][model_name] = random_models[best]
            print('Lowest Brier Score: {0}'.format(scores[best]))
            print()
    print()
    return target_calib_models


def _read_netcdf_variable(path, name):
    """Read one variable from a netCDF file fully into memory and close the file."""
    with Dataset(path) as nc_file:
        return nc_file.variables[name][:]

def brier_score(pred,obs):
    """
    Mean squared difference between predictions and observations,
    ignoring NaN entries.

    Args: 
        pred (array-like): model predictions
        obs (array-like): observations, broadcastable against pred
    Returns:
        bs (float): brier score
    """
    # asanyarray accepts plain sequences while passing ndarray subclasses
    # (e.g. masked arrays) through unchanged.
    pred = np.asanyarray(pred)
    obs = np.asanyarray(obs)
    bs = np.nanmean((pred - obs)**2.0)
    return bs

def saving_cali_models(target_calib_model, config):
    """
    Save calibration machine learning models to pickle files.

    One file is written per size threshold and model name, named
    <model_path><model-name>_<target_data_names>_<size>mm_calibration_random_cv.pkl
    with spaces in the model name replaced by dashes.
    
    Args:
        target_calib_model (dict): Nested dictionary of trained calibration
            models, {size: {model_name: model}}
        config: Configuration object providing model_path and
            target_data_names
    """

    print('Saving Models')
    for size, calibration_model in target_calib_model.items():
        for model_name, model_objs in calibration_model.items():
            out_cali_filename = config.model_path + \
                            '{0}_{1}_{2}mm_calibration_random_cv.pkl'.format(
                            model_name.replace(" ", "-"),
                            config.target_data_names,size)
            print('Writing out: {0}'.format(out_cali_filename)) 
            # The context manager guarantees the file is flushed and closed
            # even if pickling fails.
            with open(out_cali_filename, 'wb') as out_file:
                pickle.dump(model_objs, out_file)
    return

def load_cali_models(config):
    """
    Load calibration models from pickle files.

    Args:
        config: Configuration object providing size_threshold,
            calibration_model_names, model_path and target_data_names

    Returns:
        dict: Loaded models as {size: {model_name: model}}
    """

    print()
    print("Load models")
    target_calib_model = {}
    
    for size in config.size_threshold:
        target_calib_model[size] = {}
        for model_name in config.calibration_model_names:
            target_file = config.model_path + \
                        '{0}_{1}_{2}mm_calibration_random_cv.pkl'.format(
                        model_name.replace(" ", "-"),
                        config.target_data_names,size)
            # Announce the file first so a failing load shows which file it was.
            print("Opening {0}".format(target_file)) 
            # NOTE: unpickling can execute arbitrary code; model_path must
            # only contain files from a trusted source.
            with open(target_file, 'rb') as model_file:
                target_calib_model[size][model_name] = pickle.load(model_file)
    
    return target_calib_model 


def create_calibration_forecasts(target_calib_model,start_hour,end_hour,
                                run_dates,config):
    """
    Generate calibrated Neighborhood Maximum Ensemble Probability (NMEP) predictions. 
        
    Args:
        target_calib_model (dict): Trained calibration models as
            {size: {model_name: model}}; each model must provide transform()
        start_hour: Start hour used to select the forecast files
        end_hour: End hour used to select the forecast files
        run_dates (list of strings): Formatted run dates of the forecast files
        config: Configuration object providing size_threshold,
            forecast_data_path, ensemble_name, forecast_model_names and
            calibration_model_names

    Returns:
        A dictionary {size: {model_name: array}} of calibrated NMEP values,
        where the first array axis follows the order of run_dates.

    Raises:
        ValueError: If no forecast file is found for one or more run dates,
            since the output could then no longer be aligned with run_dates.
    """
    target_cali_forecast = {}

    if not config.size_threshold:
        return target_cali_forecast

    for size in config.size_threshold:
        target_cali_forecast[size] = {}
        test_files = []
        missing_dates = []
        for date in run_dates: 
            matches = glob(config.forecast_data_path+\
                "20{3}/netcdf/{0}*Hail*{1}*NMEP*unsmoothed*{2}*{3}*{4}*{5}.nc".format(
                config.ensemble_name,config.forecast_model_names,size,date,
                start_hour,end_hour))
            if matches:
                test_files.append(matches[0])
            else:
                missing_dates.append(date)

        if missing_dates or not test_files:
            raise ValueError(
                "No forecast files found for >{0}mm, hours {1}-{2}, "
                "run date(s): {3}".format(size, start_hour, end_hour,
                                          missing_dates))

        all_test_files = []
        for path in test_files:
            # Read the field into memory, then release the file handle.
            with Dataset(path) as nc_file:
                all_test_files.append(nc_file.variables["Data"][:])

        test_data = np.array(all_test_files).flatten()
        zero_inds = np.where(test_data == 0.0)[0]
        data_shape = (len(test_files),) + tuple(np.shape(all_test_files[0])[:2])

        for model_name in config.calibration_model_names:
            predict = target_calib_model[size][model_name].transform(test_data)
            # Grid points with a raw probability of exactly zero stay zero.
            predict[zero_inds] = 0.0
            target_cali_forecast[size][model_name] = predict.reshape(data_shape)
                
    return target_cali_forecast

def forecast_plotting(forecast,proj_dict,grid_dict,start_hour,
                    end_hour,stride,size,smoothing,run_date,config):
    """
    Plot calibrated predictions, one png file per run date. 
    
    Args: 
        forecast: generated calibrated NMEP forecasts, indexed by the
            position of the date in run_date
        proj_dict (dict): projection information of forecasts
        grid_dict (dict): gridded information of forecasts
        start_hour (int): Beginning hour of chosen forecast period
        end_hour (int): Ending hour of chosen forecast period
        stride (int): Step used to subsample the map coordinates
            (NOTE(review): the forecast itself is not subsampled here, so it
            presumably must already be on the strided grid - confirm)
        size (int): hail size threshold
        smoothing (int): sigma value for Gaussian smoother
        run_date (list of strings): Formatted run dates, used as given in
            the output directory, title and file name
        config: Configuration object providing forecast_out_path,
            ensemble_name, target_data_names and forecast_model_names
    """
    map_data = make_proj_grids(proj_dict,grid_dict)
    lons = map_data["lon"]
    lats = map_data["lat"]

    m = Basemap(projection='lcc',
                area_thresh=10000.,
                resolution="l",
                lon_0=proj_dict["lon_0"],
                lat_0=proj_dict["lat_0"],
                lat_1=proj_dict["lat_1"],
                # Second standard parallel; previously lat_0 was passed here.
                lat_2=proj_dict["lat_2"],
                llcrnrlon=lons[0,0],
                llcrnrlat=lats[0,0],
                urcrnrlon=lons[-1,-1],
                urcrnrlat=lats[-1,-1])
    
    x1,y1 = m(lons,lats)
    x1, y1 = x1[::stride,::stride], y1[::stride,::stride]
    cmap = matplotlib.colors.ListedColormap(['#DBC6BD','#AD8877','#FCEA8D', 'gold','#F76E67','#F2372E',
                                            '#F984F9','#F740F7','#AE7ADD','#964ADB',
                                            '#99CCFF', '#99CCFF','#3399FF'])                    
    levels = [0.02, 0.05, 0.15, 0.225, 0.30, 0.375, 0.45, 0.525, 0.60, 0.70, 0.8, 0.9, 1.0]
    
    for d,date in enumerate(run_date):
        date_outpath = config.forecast_out_path+'20{0}/png/'.format(
                    date)
        os.makedirs(date_outpath, exist_ok=True)
    
        filtered_forecast =  gaussian_filter(forecast[d],smoothing,mode='constant')

        # A new figure per date: the figure is closed after saving, so a
        # single figure created before the loop would only size the first plot.
        plt.figure(figsize=(9, 6))
        
        m.drawcoastlines()
        m.drawstates()
        m.drawcountries()
        m.fillcontinents(color='gray',alpha=0.2)
                
        plt.contourf(x1,y1,filtered_forecast,cmap=cmap,levels=levels)
        cbar = plt.colorbar(orientation="horizontal", shrink=0.7, fraction=0.05, pad=0.02)
        cbar.set_ticks([0.02, 0.05, 0.15, 0.30, 0.45, 0.60, 0.80, 1.0])
        # Tick labels are the same levels expressed in percent.
        cbar.set_ticklabels([2,5,15,30,45,60,80,100])
    
        plt.title("{0} {5} Calibrated >{1}mm NMEP \n Valid {2}, Hours: {3}-{4}UTC".format(
            config.ensemble_name,
            size,
            date,
            start_hour%24,end_hour%24,
            config.target_data_names.upper()),
            fontweight="bold",
            fontsize=14)
        # The dates arrive already formatted, and the target dataset name
        # comes from the configuration (it is not a parameter here).
        filename = date_outpath + "{0}_{6}_Hail_{1}_Cali_NMEP_{2}mm_{3}_Hours_{4}-{5}.png".format(
                                                        config.ensemble_name,
                                                        config.target_data_names,
                                                        size,
                                                        date,
                                                        start_hour,end_hour,config.forecast_model_names)
    
        plt.savefig(filename,bbox_inches="tight", dpi=300)
        print("Writing to " + filename)
        plt.close()
    
    return
    
def output_grib2(forecast,proj_dict,grid_dict,start_hour,end_hour,
                stride,size,run_date,target_dataset,smoothing,config):
    """
    Writes out grib2 files for given probability forecast, one per run date. 

    Args: 
        forecast: Generated calibrated NMEP forecasts, indexed by the
            position of the date in run_date
        proj_dict (dict): Projection information of forecasts
        grid_dict (dict): Grid information of forecasts
        start_hour (int): Beginning hour of chosen forecast period
        end_hour (int): Ending hour of chosen forecast period
        stride (int): Multiplier applied to the grid spacing written to
            the grid definition
        size (int): Hail size threshold
        run_date (list of datetimes): Run dates; each must provide
            strftime() and utctimetuple()
        target_dataset (str): Name of the dataset being calibrated towards
        smoothing (int): Sigma value for Gaussian smoother
        config: Configuration object providing forecast_out_path,
            run_date_format, ensemble_name and forecast_model_names
    """
    lscale = 1e6
    grib_id_start = [7, 0, 14, 14, 2]

    # Longitudes are written in the 0-360 degree range.
    lon_0 = proj_dict["lon_0"]
    sw_lon = grid_dict["sw_lon"]
    if lon_0 < 0:
        lon_0 += 360
    if sw_lon < 0:
        sw_lon += 360

    for d,date in enumerate(run_date):
        formatted_date = date.strftime(config.run_date_format)
        date_outpath = config.forecast_out_path+'20{0}/grib/'.format(
                    formatted_date)
        os.makedirs(date_outpath, exist_ok=True)

        filtered_forecast =  gaussian_filter(forecast[d],smoothing,mode='constant')
        
        # np.prod replaces the np.product alias, which NumPy 2.0 removed.
        gdsinfo = np.array([0, np.prod(filtered_forecast.shape[-2:]), 0, 0, 30], dtype=np.int32)

        gdtmp1 = [1, 0, proj_dict['a'], 0, float(proj_dict['a']), 0, float(proj_dict['b']),
            filtered_forecast.shape[-1], filtered_forecast.shape[-2], grid_dict["sw_lat"] * lscale,
            sw_lon * lscale, 0, proj_dict["lat_0"] * lscale,
            lon_0 * lscale,
            grid_dict["dx"] * 1e3 * stride, grid_dict["dy"] * 1e3 * stride, 0b00000000, 0b01000000,
            proj_dict["lat_1"] * lscale,
            proj_dict["lat_2"] * lscale, -90 * lscale, 0]
        pdtmp1 = np.array([1, 31, 4, 0, 31, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1], dtype=np.int32)
        drtmp1 = np.array([0, 0, 4, 8, 0], dtype=np.int32)
        time_list = list(date.utctimetuple()[0:6])
        grib_objects = Grib2Encode(0, np.array(grib_id_start + time_list + [2, 1], dtype=np.int32))
        grib_objects.addgrid(gdsinfo, gdtmp1)
        pdtmp1[8] = end_hour
        pdtmp1[-2] = 0
        grib_objects.addfield(1, pdtmp1, 0, drtmp1, filtered_forecast)
        grib_objects.end()
        filename = date_outpath + "{0}_{6}_Hail_{1}_Cali_NMEP_{2}mm_{3}_Hours_{4}-{5}.grib2".format(
                                                        config.ensemble_name,
                                                        target_dataset,
                                                        size,
                                                        formatted_date,
                                                        start_hour,end_hour,config.forecast_model_names)
        print("Writing to " + filename  )
        
        # The context manager closes the file even if the write fails.
        with open(filename, "wb") as grib_file:
            grib_file.write(grib_objects.msg)

    return

    
def output_netcdf(forecast,proj_dict,grid_dict,start_hour,end_hour,
                stride,size,run_date,target_dataset,smoothing,config):
    """
    Writes out netCDF4 files for given probability forecast, one per run date. 

    Each file holds the Longitude, Latitude and smoothed Data fields plus
    the projection parameters as global attributes.

    Args: 
        forecast: Generated calibrated NMEP forecasts, indexed by the
            position of the date in run_date
        proj_dict (dict): Projection information of forecasts
        grid_dict (dict): Grid information of forecasts
        start_hour (int): Beginning hour of chosen forecast period
        end_hour (int): Ending hour of chosen forecast period
        stride (int): Not used by this function
        size (int): Hail size threshold
        run_date (list of strings): Formatted run dates, used as given in
            the output directory and file name
        target_dataset(str): Name of the dataset being calibrated towards
        smoothing (int): Sigma value for Gaussian smoother
        config: Configuration object providing forecast_out_path,
            ensemble_name and forecast_model_names
    """
    # The coordinate grids are the same for every date, so build them once.
    map_data = make_proj_grids(proj_dict,grid_dict)
    lons = map_data["lon"]
    lats = map_data["lat"]

    for d,date in enumerate(run_date):
        date_outpath = config.forecast_out_path+'20{0}/netcdf/'.format(
                    date)
        os.makedirs(date_outpath, exist_ok=True)
        
        filtered_forecast =  gaussian_filter(forecast[d],smoothing,mode='constant')
    
        filename = date_outpath + "{0}_{6}_Hail_{1}_Cali_NMEP_{2}mm_{3}_Hours_{4}-{5}.nc".format(
                                                        config.ensemble_name,
                                                        target_dataset,
                                                        size,
                                                        date,
                                                        start_hour,end_hour,config.forecast_model_names)

        # The context manager closes the file even if a write fails.
        with Dataset(filename, "w") as out_file:
            # "x" is the first array axis and "y" the second.
            out_file.createDimension("x", filtered_forecast.shape[0])
            out_file.createDimension("y", filtered_forecast.shape[1])
            out_file.createVariable("Longitude", "f4", ("x", "y"))
            out_file.createVariable("Latitude", "f4",("x", "y"))
            out_file.createVariable("Data", "f4", ("x", "y"))
            out_file.variables["Longitude"][:,:] = lons
            out_file.variables["Latitude"][:,:] = lats
            out_file.variables["Data"][:,:] = filtered_forecast
            out_file.projection = proj_dict["proj"]
            out_file.lon_0 = proj_dict["lon_0"]
            out_file.lat_0 = proj_dict["lat_0"]
            out_file.lat_1 = proj_dict["lat_1"]
            out_file.lat_2 = proj_dict["lat_2"]
       
        print("Writing to " + filename)
    return


# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
