#!/usr/bin/env python
from hagelslag.data.HailForecastGrid import HailForecastGrid
from hagelslag.util.make_proj_grids import * 
import matplotlib
matplotlib.use('agg')
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
import os

def main():
    """
    Parse command-line options, load the hail forecasts, and write the requested plots.

    The forecasts are loaded with ``load_hail_forecasts``. If no data are found a
    message is printed and nothing is plotted. Otherwise the output flags select
    which products are generated:

    * ``--ens_max_out``: ensemble maximum plot for hours 0-24 and, with ``--hourly``,
      for every 4-hour window starting at hours 4 through 20.
    * ``--nep_out``: neighborhood probability plots for the same periods.
    * ``--subplot_out``: per-member subplot figures for hours 0-24.

    The ``--map_file`` option is accepted but is not used by this script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--run", help="Date of the model run.")
    parser.add_argument("-s", "--start", help="Start Date of the model run time steps.")
    parser.add_argument("-e", "--end", help="End Date of the model run time steps.")
    parser.add_argument("-n", "--ens", help="Name of the ensemble.")
    parser.add_argument("-a", "--map_file", help="Map file")
    parser.add_argument("-m", "--model", help="Name of the machine learning model")
    parser.add_argument("-b", "--members", help="Comma-separated list of members.")
    parser.add_argument("-v", "--var", default="hail", help="Variable being plotted.")
    parser.add_argument("-g", "--grib", type=int, default=1, help="GRIB message number.")
    parser.add_argument("-p", "--path", help="Path to GRIB files")
    parser.add_argument("-o", "--out", help="Path where figures are saved.")
    parser.add_argument("-y", "--hourly", action="store_true", help="If False will not output hourly forecasts.")
    parser.add_argument("-t", "--ens_max_out", action="store_true", help="Generate Ensemble Maximum image files.")
    parser.add_argument("-u", "--nep_out", action="store_true", help="Generate NEP image files.")
    parser.add_argument("-f", "--subplot_out", action="store_true", help="Generate subplot png files.")

    args = parser.parse_args()
    run_date = pd.Timestamp(args.run).to_pydatetime()
    start_date = pd.Timestamp(args.start).to_pydatetime()
    end_date = pd.Timestamp(args.end).to_pydatetime()
    ensemble_members = args.members.split(",")

    # Fixed settings for the neighborhood probability plots.
    neighbor_radius = 42.0
    neighbor_smoothing = 14  # number of gridpoints
    thresholds = np.array([25, 50])
    stride = 1

    print("Loading data")
    forecast_grid = load_hail_forecasts(run_date, start_date, end_date, args.ens, args.model,
                                        ensemble_members, args.var, args.grib, args.path)

    if forecast_grid.data is None:
        print('No grib hail forecasts found')
        return

    print("Get basemaps")
    bmap_sub = get_sub_basemap(forecast_grid, stride)
    bmap_full = get_full_basemap(forecast_grid)

    if args.ens_max_out:
        # This message previously said "neighborhood probability" (copy-paste slip).
        print("Daily plots: ensemble_max")
        plot_period_ensemble_max(forecast_grid, bmap_full, 0, 24, args.out)
        if args.hourly:
            print("Hourly plots: ensemble_max")
            for hour in range(4, 21):
                plot_period_ensemble_max(forecast_grid, bmap_full, hour, (hour + 4), args.out)
                print()

    if args.nep_out:
        print("Daily plots: neighborhood probability")
        plot_period_neighborhood_probability(forecast_grid, bmap_sub, 0, 24,
                                             neighbor_radius, neighbor_smoothing,
                                             thresholds, stride, args.out)
        if args.hourly:
            print("Hourly plots: neighborhood probability")
            for hour in range(4, 21):
                plot_period_neighborhood_probability(forecast_grid, bmap_sub, hour, (hour + 4),
                                                     neighbor_radius, neighbor_smoothing,
                                                     thresholds, stride, args.out)
                print()
        print()

    if args.subplot_out:
        plot_subplots(forecast_grid, bmap_full, 0, 24, args.out, "ens_max", ensemble_members)
    return


def load_hail_forecasts(run_date, start_date, end_date, ensemble_name, ml_model, 
                        members, variable,message_number, path):
    """
    Build a HailForecastGrid for one model run and read its data.

    Args:
        run_date (datetime.datetime): Date of the initial time of the model run
        start_date (datetime.datetime): Date of the initial forecast time being loaded
        end_date (datetime.datetime): Date of the final forecast time being loaded
        ensemble_name (str): Name of the NWP ensemble being used
        ml_model (str): Name of the machine learning model being loaded
        members (list of str): Name of ensemble members
        variable (str): Name of the machine learning model variable being forecast
        message_number (int): Field in the GRIB2 file to load. The first field in the file has message number 1.
        path (str): Path to top-level GRIB2 directory. Assumes files are stored in directories by run_date

    Returns:
        HailForecastGrid: The grid object after its ``load_data()`` method has been called.
    """
    # Arguments are passed positionally, in the order the constructor is called with here.
    grid = HailForecastGrid(
        run_date,
        start_date,
        end_date,
        ensemble_name,
        ml_model,
        members,
        variable,
        message_number,
        path,
    )
    grid.load_data()
    return grid

def get_sub_basemap(forecast_grid, stride=14):
    """
    Create a Basemap whose extent matches the strided (sub-sampled) forecast grid.

    Args:
        forecast_grid: Object providing ``projparams`` (projection settings) and 2D
            ``lon`` / ``lat`` arrays.
        stride (int): Step used to sub-sample ``lon`` and ``lat`` in both dimensions.

    Returns:
        Basemap: Map with its lower-left corner at the first grid point and its
            upper-right corner at the last point of the strided grid.
    """
    proj = forecast_grid.projparams
    # The upper-right corner is the last point kept after sub-sampling.
    strided_lon = forecast_grid.lon[::stride, ::stride]
    strided_lat = forecast_grid.lat[::stride, ::stride]
    map_settings = dict(
        projection=proj["proj"],
        resolution="l",
        rsphere=proj["a"],
        lon_0=proj["lon_0"],
        lat_0=proj["lat_0"],
        lat_1=proj["lat_1"],
        lat_2=proj["lat_2"],
        llcrnrlon=forecast_grid.lon[0, 0],
        llcrnrlat=forecast_grid.lat[0, 0],
        urcrnrlon=strided_lon[-1, -1],
        urcrnrlat=strided_lat[-1, -1],
    )
    return Basemap(**map_settings)


def get_full_basemap(forecast_grid):
    """
    Create a Basemap covering the full extent of the forecast grid.

    Args:
        forecast_grid: Object providing ``projparams`` (projection settings) and 2D
            ``lon`` / ``lat`` arrays.

    Returns:
        Basemap: Map with its lower-left corner at ``lon[0, 0]``/``lat[0, 0]`` and
            its upper-right corner at ``lon[-1, -1]``/``lat[-1, -1]``.
    """
    # An unreachable call to plot_period_ensemble_max used to follow the return
    # statement; it referenced names not defined in this scope and was removed.
    bmap = Basemap(projection=forecast_grid.projparams["proj"], resolution="l",
                   rsphere=forecast_grid.projparams["a"],
                   lon_0=forecast_grid.projparams["lon_0"],
                   lat_0=forecast_grid.projparams["lat_0"],
                   lat_1=forecast_grid.projparams["lat_1"],
                   lat_2=forecast_grid.projparams["lat_2"],
                   llcrnrlon=forecast_grid.lon[0, 0],
                   llcrnrlat=forecast_grid.lat[0, 0],
                   urcrnrlon=forecast_grid.lon[-1, -1],
                   urcrnrlat=forecast_grid.lat[-1, -1])
    return bmap


def plot_period_ensemble_max(forecast_grid, bmap, start_time, end_time, out_path,
                            figsize=(10, 6), contours=np.concatenate([[1] + np.arange(5, 80, 5)]),
                            cmap="inferno"):
    """
    Plot the ensemble maximum hail size forecast over a period of forecast hours.

    The maximum of ``forecast_grid.data`` is taken over its first axis and then over
    the ``start_time:end_time`` slice of the following axis (assumed to be ensemble
    member and forecast hour respectively -- confirm against HailForecastGrid).
    The figure is written to ``<out_path>/<YYYYMMDD>/png/``, which is created if needed.

    Args:
        forecast_grid (): data from input forecast grib file
        bmap (obj): Basemap object
        start_time (int): First forecast hour index of the period (inclusive)
        end_time (int): Last forecast hour index of the period (exclusive)
        out_path (str): Path to where output png files are stored
        figsize (tuple): Figure size passed to ``plt.figure``
        contours (array-like): Contour levels for the filled contour plot.
            NOTE(review): the default expression evaluates to 6, 11, ..., 76 because
            ``[1] + ndarray`` broadcasts instead of prepending; the intent may have been
            ``np.concatenate([[1], np.arange(5, 80, 5)])``. Left unchanged so existing
            output does not shift.
        cmap (str): Currently ignored; the lower half of the "Paired" colormap is
            always used.
    Returns:
        None. A PNG file of the ensemble maximum forecast is written to disk.
    """
    # Use only the lower half of the "Paired" colormap, regardless of `cmap`.
    paired = plt.get_cmap("Paired")
    colors = paired(np.linspace(0, 0.5, paired.N // 2))
    cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list("Lower Half Paired", colors)

    ens_max = forecast_grid.data.max(axis=0)
    ens_max = ens_max[start_time:end_time, :, :].max(axis=0)

    plt.figure(figsize=figsize)
    bmap.drawstates()
    bmap.drawcountries()
    bmap.drawcoastlines()
    x, y = bmap(forecast_grid.lon, forecast_grid.lat)
    plt.contourf(x, y, ens_max, contours, cmap=cmap2, extend="max")

    # Join path components so out_path works with or without a trailing separator,
    # and let makedirs tolerate a directory that already exists.
    date_outpath = os.path.join(out_path, forecast_grid.start_date.strftime("%Y%m%d"), "png")
    os.makedirs(date_outpath, exist_ok=True)

    # Hours shown in the title are offset by 12 and wrapped to a 24-hour clock.
    plt.title("{0} {1} Ensemble Maximum {2} (mm), Valid {3} {4}-{5} UTC".format(
                  forecast_grid.ensemble_name,
                  forecast_grid.ml_model.replace("-", " "),
                  forecast_grid.variable.capitalize(),
                  forecast_grid.start_date.strftime("%d %b %Y"),
                  ((start_time + 12) % 24), ((end_time + 12) % 24)),
              fontweight="bold",
              fontsize=12)
    plt.colorbar(orientation="horizontal", shrink=0.7, fraction=0.05, pad=0.02)

    # The file name uses the offset hours without wrapping (e.g. 12 and 36).
    ens_max_file = os.path.join(date_outpath, "{0}_{1}_ens_max_{2}_time_{3}_{4}.png".format(
        forecast_grid.ensemble_name,
        forecast_grid.ml_model,
        forecast_grid.start_date.strftime("%y%m%d"),
        (start_time + 12), (end_time + 12)))
    plt.savefig(ens_max_file, bbox_inches="tight", dpi=300)
    plt.close()
    return


def plot_period_neighborhood_probability(forecast_grid, bmap, start_time, end_time, radius, 
                                        smoothing, thresholds, stride, out_path, figsize=(10, 6)):
    """
    Plot neighborhood probability forecasts for a period, one figure per threshold.

    For each threshold, ``forecast_grid.period_neighborhood_probability`` is called and
    its result is averaged over the first axis (presumably ensemble members -- confirm
    against HailForecastGrid) before being contoured on the strided grid. Figures are
    written to ``<out_path>/<YYYYMMDD>/png/``, which is created if needed.

    Args:
        forecast_grid (): Data from input forecast grib file
        bmap (obj): Basemap object
        start_time (int): First forecast hour index of the period
        end_time (int): Last forecast hour index of the period
        radius (int): Circular radius from each point in km
        smoothing (int): Width of gaussian smoother in number of grid points
        thresholds (iterable of int): Hail size thresholds; one plot is made for each
        stride (int): Number of grid points to skip for reduced neighborhood grid
        out_path (str): Path to where output png files are stored
        figsize (tuple): Figure size passed to ``plt.figure``
    Returns:
        None. PNG files of the neighborhood probability forecasts are written to disk.
    """
    # Everything below is independent of the threshold, so compute it once.
    n_lon = forecast_grid.lon[::stride, ::stride]
    n_lat = forecast_grid.lat[::stride, ::stride]
    n_x, n_y = bmap(n_lon, n_lat)

    cmap = matplotlib.colors.ListedColormap(['#DBC6BD', '#AD8877', '#FCEA8D',
                                             'gold', '#F76E67', '#F2372E',
                                             '#F984F9', '#F740F7', '#AE7ADD', '#964ADB',
                                             '#99CCFF', '#99CCFF', '#3399FF'])
    levels = [0.02, 0.05, 0.15, 0.225, 0.30, 0.375, 0.45, 0.525, 0.60, 0.70, 0.8, 0.9, 1.0]

    # Join path components so out_path works with or without a trailing separator.
    date_outpath = os.path.join(out_path, forecast_grid.start_date.strftime("%Y%m%d"), "png")
    valid_date = forecast_grid.start_date.strftime("%d %b %Y")
    file_date = forecast_grid.start_date.strftime("%y%m%d")

    for threshold in thresholds:
        plt.figure(figsize=figsize)
        neighbor_prob = forecast_grid.period_neighborhood_probability(
            radius, smoothing, threshold, stride, start_time, end_time)
        neighbor_prob = neighbor_prob.mean(axis=0)

        os.makedirs(date_outpath, exist_ok=True)

        bmap.drawstates()
        bmap.drawcoastlines()
        bmap.drawcountries()

        plt.contourf(n_x, n_y, neighbor_prob, extend="max", cmap=cmap, levels=levels)
        cbar = plt.colorbar(orientation="horizontal", shrink=0.7, fraction=0.05, pad=0.02)
        # Tick labels show the probabilities as percentages.
        cbar.set_ticks([0.02, 0.05, 0.15, 0.30, 0.45, 0.60, 0.80, 1.0])
        cbar.set_ticklabels([2, 5, 15, 30, 45, 60, 80, 100])

        # The backslash before "sigma" is escaped explicitly; the rendered title is
        # unchanged. smoothing * 3 is labelled km, which presumably assumes a 3 km
        # grid spacing -- confirm.
        plt.title("{0} {1} Ens. Max. Neighbor Prob. of {2}>{3:d} mm\nR={4:d} km, $\\sigma$={5:d} km, Valid {6}".format(
                      forecast_grid.ensemble_name,
                      forecast_grid.ml_model.replace("-", " "),
                      forecast_grid.variable.capitalize(),
                      int(threshold),
                      int(radius),
                      int(smoothing * 3),
                      valid_date),
                  fontweight="bold",
                  fontsize=10)
        filename = os.path.join(date_outpath, "{0}_{1}_NMEP_{2:d}_r_{3:d}_s_{4:d}_{5}_time_{6}_{7}.png".format(
            forecast_grid.ensemble_name,
            forecast_grid.ml_model,
            int(threshold),
            int(radius),
            int(smoothing * 3),
            file_date,
            (start_time + 12), (end_time + 12)))
        plt.savefig(filename, bbox_inches="tight", dpi=300)
        plt.close()
    return

def plot_subplots(forecast_grid, bmap, start_time, end_time, out_path, plot_mode, ensemble_members,
                  contours=np.concatenate([[1] + np.arange(5, 80, 5)]), cmap="inferno"):
    """
    Plot the period maximum of each ensemble member as a grid of subplots.

    The members (first axis of ``forecast_grid.data``) are split into two groups and
    one figure is written per group to ``<out_path>/<YYYYMMDD>/png/``. Each panel shows
    the maximum of that member over the ``start_time:end_time`` slice. An empty group
    (e.g. when there is a single member) is skipped.

    Args:
        forecast_grid (): Data from input forecast grib file
        bmap (obj): Basemap object
        start_time (int): First forecast hour index of the period (inclusive)
        end_time (int): Last forecast hour index of the period (exclusive)
        out_path (str): Path to where output png files are stored
        plot_mode (str): Label used in the figure title and file name. Currently only
            ensemble maximum subplots are supported.
        ensemble_members (list of str): Name of ensemble members, used as panel titles
        contours (array-like): Contour levels for the filled contour plots.
            NOTE(review): the default expression evaluates to 6, 11, ..., 76 because
            ``[1] + ndarray`` broadcasts instead of prepending; left unchanged so
            existing output does not shift.
        cmap (str): Currently ignored; the lower half of the "Paired" colormap is
            always used.
    Returns:
        None. PNG files of the member subplots are written to disk.
    """
    # Use only the lower half of the "Paired" colormap, regardless of `cmap`.
    paired = plt.get_cmap("Paired")
    colors = paired(np.linspace(0, 0.5, paired.N // 2))
    cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list("Lower Half Paired", colors)

    n_members = forecast_grid.data.shape[0]
    members_in_half = round(n_members / 2.0)
    separate_images = [list(range(0, members_in_half)),
                       list(range(members_in_half, n_members))]

    # The projected coordinates are the same for every panel.
    x, y = bmap(forecast_grid.lon, forecast_grid.lat)

    # Join path components so out_path works with or without a trailing separator.
    date_outpath = os.path.join(out_path, forecast_grid.start_date.strftime("%Y%m%d"), "png")

    for images in separate_images:
        if not images:
            # Nothing to draw; also avoids using an unset contour handle below.
            continue

        # Keep the original 4x2 layout, adding rows only when a group has more
        # than 8 members so add_subplot does not fail.
        n_rows = max(4, (len(images) + 1) // 2)

        f = plt.figure(figsize=(25, 25))
        plt.subplots_adjust(left=0.2, bottom=0.1, right=0.9, top=0.96, wspace=0.01, hspace=0.05)
        for n, mem in enumerate(images, 1):
            ax = f.add_subplot(n_rows, 2, n)
            bmap.drawstates()
            bmap.drawcountries()
            bmap.drawcoastlines()
            data = ax.contourf(x, y, forecast_grid.data[mem, start_time:end_time, :, :].max(axis=0),
                               contours, cmap=cmap2, extend="max")
            ax.set_title(ensemble_members[mem])
        plt.colorbar(data, orientation='horizontal', shrink=0.7, fraction=0.05, pad=0.02)

        # Placeholders now map one-to-one onto the arguments; previously {3} was
        # repeated, which printed the variable name in place of the valid date.
        f.suptitle("{0} {1} {2} {3} (mm), Valid {4} {5}-{6} UTC".format(
                       forecast_grid.ensemble_name,
                       forecast_grid.ml_model.replace("-", " "),
                       plot_mode,
                       forecast_grid.variable.capitalize(),
                       forecast_grid.start_date.strftime("%d %b %Y"),
                       ((start_time + 12) % 24), ((end_time + 12) % 24)),
                   fontweight="bold",
                   fontsize=10)

        os.makedirs(date_outpath, exist_ok=True)

        # The undefined name `size_dist_mode` was dropped from the arguments; the
        # seven placeholders now match the seven values. The last field is the index
        # of the first member in the group, which keeps the two file names distinct.
        subplot_file = os.path.join(date_outpath, "{0}_{1}_{2}_{3}_time_{4}_{5}_{6}.png".format(
            forecast_grid.ensemble_name,
            forecast_grid.ml_model,
            plot_mode,
            forecast_grid.start_date.strftime("%y%m%d"),
            (start_time + 12), (end_time + 12),
            images[0]))
        plt.savefig(subplot_file, bbox_inches="tight", dpi=300)
        plt.close()
    return
                                     
# Run the plotting workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
