#!/usr/bin/env python
from hagelslag.data.HailForecastGrid import HailForecastGrid
from hagelslag.util.make_proj_grids import * 
from netCDF4 import Dataset
import pandas as pd
import numpy as np
import argparse
import os

# ncepgrib2 is an optional dependency: it is only needed when GRIB2 output is
# requested. ``grib_support`` records whether the import succeeded.
try:
    from ncepgrib2 import Grib2Encode, dump
    grib_support = True
except ImportError:
    # The except clause must name the exception *class*. Naming an instance
    # (``ImportError("...")``) makes Python raise TypeError while matching the
    # exception, so the fallback below would never be reached.
    grib_support = False

def main():
    """
    Command-line entry point.

    Parses the command-line options, loads the hail forecasts with
    load_hail_forecasts, and, if any forecast data were found, writes the
    requested GRIB2 (--grib_out) and/or netCDF (--netcdf_out) products for the
    full 0-24 index period. With --hourly, additional products are written for
    sub-periods starting at indices 4 through 20.

    Exits with an argparse usage error if --members is missing or if
    --calibration is given a value that is not a recognizable boolean.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--run", help="Date of the model run.")
    parser.add_argument("-s", "--start", help="Start Date of the model run time steps.")
    parser.add_argument("-e", "--end", help="End Date of the model run time steps.")
    parser.add_argument("-n", "--ens", help="Name of the ensemble.")
    parser.add_argument("-a", "--map_file", help="Map file")
    parser.add_argument("-m", "--model", help="Name of the machine learning model")
    # Required: the member list is split on commas below, which would
    # otherwise fail with an AttributeError on None.
    parser.add_argument("-b", "--members", required=True, help="Comma-separated list of members.")
    parser.add_argument("-v", "--var", default="hail", help="Variable being plotted.")
    parser.add_argument("-g", "--grib", type=int, default=1, help="GRIB message number.")
    parser.add_argument("-p", "--path", help="Path to GRIB files")
    parser.add_argument("-o", "--out", help="Path where figures are saved.")
    parser.add_argument("-i", "--grib_out", action="store_true", help="Generate grib2 files")
    parser.add_argument("-c", "--netcdf_out", action="store_true", help="Generate netcdf files.")
    # Without a type converter any explicit value (even "True") arrives as a
    # string and fails the ``calibrate is True`` checks in the output functions.
    parser.add_argument("-l", "--calibration", required=False, default=True, type=_parse_bool_flag,
                        help="If True will not smooth output files.")
    parser.add_argument("-y", "--hourly", action="store_true", help="If False will not output hourly forecasts.")

    args = parser.parse_args()
    run_date = pd.Timestamp(args.run).to_pydatetime()
    start_date = pd.Timestamp(args.start).to_pydatetime()
    end_date = pd.Timestamp(args.end).to_pydatetime()
    map_file = args.map_file
    ensemble_members = args.members.split(",")
    calibrate = args.calibration

    neighbor_radius = 42.0
    neighbor_smoothing = 14  # number of gridpoints
    thresholds = np.array([25, 50])  # used in output names as "NMEP_{threshold}mm"
    stride = 1

    print("Loading data")
    forecast_grid = load_hail_forecasts(run_date, start_date, end_date, args.ens, args.model,
                                        ensemble_members, args.var, args.grib, args.path)
    if forecast_grid.data is None:
        print('No grib hail forecasts found')
        return

    if args.grib_out:
        print("Output grib2 files")
        output_grib2_files(forecast_grid, 0, 24, run_date, map_file, args.out,
                           neighbor_radius, neighbor_smoothing, thresholds, stride, calibrate)
        print()
        if args.hourly:
            for hour in range(4, 21):
                output_grib2_files(forecast_grid, hour, hour + 1, run_date, map_file, args.out,
                                   neighbor_radius, neighbor_smoothing, thresholds, stride, calibrate)
                print()
    if args.netcdf_out:
        print("Output netcdf files")
        output_netcdf_files(forecast_grid, 0, 24, args.out, neighbor_radius,
                            neighbor_smoothing, thresholds, stride, map_file, calibrate)
        print()
        if args.hourly:
            for hour in range(4, 21):
                # NOTE(review): this window is 4 steps wide (hour to hour+4)
                # while the GRIB2 branch above uses 1 step (hour to hour+1).
                # Confirm whether the difference is intentional.
                output_netcdf_files(forecast_grid, hour, hour + 4, args.out, neighbor_radius,
                                    neighbor_smoothing, thresholds, stride, map_file, calibrate)
                print()
    print()


def _parse_bool_flag(value):
    """
    Convert a command-line string to a bool for use as an argparse ``type``.

    Args:
        value (str or bool): Text such as "True", "false", "1", "0", "yes", "no"
            (case-insensitive), or an actual bool, which is returned unchanged.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {0!r}".format(value))


def load_hail_forecasts(run_date, start_date, end_date, ensemble_name,
                        ml_model, members, variable, message_number, path):
    """
    Build a HailForecastGrid for the requested run and load its data.

    Args:
        run_date (datetime.datetime): Initial time of the model run
        start_date (datetime.datetime): First forecast time to load
        end_date (datetime.datetime): Last forecast time to load
        ensemble_name (str): Name of the NWP ensemble
        ml_model (str): Name of the machine learning model
        members (list of str): Names of the ensemble members
        variable (str): Machine learning model variable being forecast
        message_number (int): Field in the GRIB2 file to load; the first field is message number 1
        path (str): Top-level GRIB2 directory, with files stored in directories by run_date

    Returns:
        HailForecastGrid: The grid object after its load_data() method has been called.
    """
    grid = HailForecastGrid(
        run_date, start_date, end_date,
        ensemble_name, ml_model, members,
        variable, message_number, path,
    )
    grid.load_data()
    return grid

def output_grib2_files(forecast_grid,start_time,end_time,run_date,map_file,out_path,
                       radius,smoothing,thresholds,stride,calibrate):
    """
    Writes GRIB2 products for one time window of the forecast.

    If calibrate is True, a single "Ensemble-Maximum" file is written: the
    maximum of forecast_grid.data over its first axis, then over the
    start_time:end_time slice of the next axis. Otherwise one
    "NMEP_{threshold}mm" file is written per threshold, from
    forecast_grid.period_neighborhood_probability averaged over its first axis.

    Args:
        forecast_grid (HailForecastGrid): data from input forecast grib file
        start_time (int): Beginning index of the window (used to slice the data)
        end_time (int): End index of the window (exclusive when slicing)
        run_date (datetime.datetime): Date of the initial time of the model run
        map_file (str): Map associated with the input/output grib files
        out_path (str): Path to where output grib files are stored. It is
            concatenated directly with the date directory, so it must end
            with a path separator.
        radius (int): circular radius from each point in km
        smoothing (int): width of Gaussian smoother in number of grid points
        thresholds (iterable of int): Hail size thresholds
        stride (int): number of grid points to skip for reduced neighborhood grid
        calibrate (bool): Selects ensemble maximum (True) or NMEP (otherwise) output

    Returns:
        None. Files are written under out_path + "<YYYYMMDD>/grib/".
    """
    date_outpath = out_path + '{0}/grib/'.format(forecast_grid.start_date.strftime("%Y%m%d"))

    # exist_ok avoids the check-then-create race when several jobs write to
    # the same date directory at once.
    os.makedirs(date_outpath, exist_ok=True)

    # The window bounds are shifted by 12 before being passed on;
    # write_grib2_file documents its hours as lying in the 12-36 range.
    if calibrate is True:
        ens_max = forecast_grid.data.max(axis=0)
        ens_max = ens_max[start_time:end_time, :, :].max(axis=0)
        write_grib2_file(ens_max, forecast_grid, start_time + 12, end_time + 12,
                         map_file, run_date, date_outpath, "Ensemble-Maximum", stride)
    else:
        print()
        for threshold in thresholds:
            neighbor_prob = forecast_grid.period_neighborhood_probability(radius, smoothing, threshold,
                                                                          stride, start_time, end_time)
            neighbor_prob = neighbor_prob.mean(axis=0)
            write_grib2_file(neighbor_prob, forecast_grid, start_time + 12, end_time + 12,
                             map_file, run_date, date_outpath, "NMEP_{0}mm".format(threshold), stride)

def write_grib2_file(data,forecast_grid,start_time,end_time,
                    map_file,run_date,out_path,plot_mode,stride):
    """
    Encodes a 2D forecast field as a single GRIB2 message and writes it to disk.

    The projection and grid parameters are read from map_file, using
    read_arps_map_file when the name ends in "map" and read_ncar_map_file
    otherwise. The output file is named
    "<ensemble>_Hail_<model>_<plot_mode>_<yymmdd>_Hours_<start>-<end>.grib2".

    Args:
        data (ndarray): forecast field to encode; its last two dimensions give the grid size
        forecast_grid (HailForecastGrid): data from input forecast grib file
        start_time (int): Beginning hour between 12-36 period for output grib files
        end_time (int): End hour between 12-36 period for output grib files
        map_file (str): Map associated with the input/output grib files
        run_date (datetime.datetime): Date of the initial time of the model run
        out_path (str): Directory prefix for the output file; concatenated directly with the file name
        plot_mode (str): Type of plot, either ensemble max or probability
        stride (int): number of grid points to skip for reduced neighborhood grid

    Returns:
        None
    """
    if map_file[-3:] == "map":
        proj_dict, grid_dict = read_arps_map_file(map_file)
    else:
        proj_dict, grid_dict = read_ncar_map_file(map_file)

    lscale = 1e6  # latitude/longitude values are multiplied by this before encoding
    grib_id_start = [7, 0, 14, 14, 2]
    # np.prod replaces np.product, which was removed in NumPy 2.0.
    # NOTE(review): the trailing 30 and the layout of gdtmp1 below look like
    # GRIB2 grid definition template 3.30 - confirm against the GRIB2 tables.
    gdsinfo = np.array([0, np.prod(data.shape[-2:]), 0, 0, 30], dtype=np.int32)
    lon_0 = proj_dict["lon_0"]
    sw_lon = grid_dict["sw_lon"]
    # Negative longitudes are shifted into the 0-360 range before encoding.
    if lon_0 < 0:
        lon_0 += 360
    if sw_lon < 0:
        sw_lon += 360

    gdtmp1 = [1, 0, proj_dict['a'], 0, float(proj_dict['a']), 0, float(proj_dict['b']),
            data.shape[-1], data.shape[-2], grid_dict["sw_lat"] * lscale,
            sw_lon * lscale, 0, proj_dict["lat_0"] * lscale,
            lon_0 * lscale,
            grid_dict["dx"] * 1e3 * stride, grid_dict["dy"] * 1e3 * stride, 0b00000000, 0b01000000,
            proj_dict["lat_1"] * lscale,
            proj_dict["lat_2"] * lscale, -90 * lscale, 0]
    pdtmp1 = np.array([1, 31, 4, 0, 31, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1], dtype=np.int32)
    drtmp1 = np.array([0, 0, 4, 8, 0], dtype=np.int32)
    time_list = list(run_date.utctimetuple()[0:6])
    grib_objects = Grib2Encode(0, np.array(grib_id_start + time_list + [2, 1], dtype=np.int32))
    grib_objects.addgrid(gdsinfo, gdtmp1)
    # The end hour is stored in the product template; start_time is only used
    # in the file name.
    pdtmp1[8] = end_time
    pdtmp1[-2] = 0
    grib_objects.addfield(1, pdtmp1, 0, drtmp1, data)
    grib_objects.end()
    filename = out_path + "{0}_Hail_{1}_{2}_{3}_Hours_{4}-{5}.grib2".format(
                                                        forecast_grid.ensemble_name,
                                                        forecast_grid.ml_model,
                                                        plot_mode,
                                                        forecast_grid.start_date.strftime("%y%m%d"),
                                                        start_time,end_time)

    print("Writing to " + filename)
    # The context manager closes the file even if the write fails.
    with open(filename, "wb") as grib_file:
        grib_file.write(grib_objects.msg)

def output_netcdf_files(forecast_grid,start_time,end_time,out_path,radius,
                        smoothing,thresholds,stride,map_file,calibrate):
    """
    Writes one netCDF neighborhood probability file per threshold for one time window.

    For each threshold, forecast_grid.period_neighborhood_probability is called
    and its result averaged over the first axis before being written with
    write_netcdf_file. If calibrate is True, smoothing is forced to 0 and the
    files are labelled "NMEP_unsmoothed_<threshold>"; otherwise they are
    labelled "NMEP_smoothed_<threshold>".
    Ensemble Maximum hail could be added in the future.

    Args:
        forecast_grid (HailForecastGrid): data from input forecast grib file
        start_time (int): Beginning index of the window
        end_time (int): End index of the window
        out_path (str): Path to where output netcdf files are stored. It is
            concatenated directly with the date directory, so it must end
            with a path separator.
        radius (int): circular radius from each point in km
        smoothing (int): width of Gaussian smoother in number of grid points
        thresholds (iterable of int): Hail size thresholds
        stride (int): number of grid points to skip for reduced neighborhood grid
        map_file (str): Map associated with the input/output grib files
        calibrate (bool): If true, will not smooth forecast data

    Returns:
        None. Files are written under out_path + "<YYYYMMDD>/netcdf/".
    """
    if calibrate is True:
        smoothing = 0
        mode_prefix = "NMEP_unsmoothed_{0}"
    else:
        mode_prefix = "NMEP_smoothed_{0}"

    date_outpath = out_path + '{0}/netcdf/'.format(forecast_grid.start_date.strftime("%Y%m%d"))
    # exist_ok avoids the check-then-create race when several jobs write to
    # the same date directory at once.
    os.makedirs(date_outpath, exist_ok=True)

    # Subsample the coordinates with the same stride passed to the
    # neighborhood probability calculation.
    n_lon = forecast_grid.lon[::stride, ::stride]
    n_lat = forecast_grid.lat[::stride, ::stride]

    for threshold in thresholds:
        neighbor_prob = forecast_grid.period_neighborhood_probability(radius, smoothing, threshold,
                                                                      stride, start_time, end_time)
        neighbor_prob = neighbor_prob.mean(axis=0)
        # The window bounds are shifted by 12 for the file name;
        # write_netcdf_file documents its hours as lying in the 12-36 range.
        write_netcdf_file(neighbor_prob, forecast_grid, start_time + 12, end_time + 12, date_outpath,
                          n_lon, n_lat, mode_prefix.format(threshold), map_file)

def write_netcdf_file(neighbor_prob,forecast_grid,start_time,end_time,out_path,
                    data_lon,data_lat,plot_mode,map_file):
    """
    Writes a 2D neighborhood probability field and its coordinates to a netCDF file.

    The file contains the variables "Longitude", "Latitude" and "Data" on
    dimensions ("x", "y") sized from neighbor_prob.shape, plus global
    attributes describing the map projection read from map_file. The file is
    named "<ensemble>_Hail_<model>_<plot_mode>_<yymmdd>_Hours_<start>-<end>.nc".

    Args:
        neighbor_prob (ndarray): Neighborhood probabilities over given time period (2D)
        forecast_grid (HailForecastGrid): data from input forecast grib file
        start_time (int): Beginning hour between 12-36 period, used in the file name
        end_time (int): End hour between 12-36 period, used in the file name
        out_path (str): Directory prefix for the output file; concatenated directly with the file name
        data_lon (ndarray): Array of longitudes for forecast data
        data_lat (ndarray): Array of latitudes for forecast data
        plot_mode (str): Label for the product, used in the file name
        map_file (str): Map associated with the input/output grib files

    Returns:
        None
    """
    if map_file[-3:] == "map":
        proj_dict, grid_dict = read_arps_map_file(map_file)
    else:
        proj_dict, grid_dict = read_ncar_map_file(map_file)

    out_filename = out_path + "{0}_Hail_{1}_{2}_{3}_Hours_{4}-{5}.nc".format(
                                                        forecast_grid.ensemble_name,
                                                        forecast_grid.ml_model,
                                                        plot_mode,
                                                        forecast_grid.start_date.strftime("%y%m%d"),
                                                        start_time,end_time)
    # The context manager closes the dataset even if one of the writes below
    # raises, so a partially written file does not keep its handle open.
    with Dataset(out_filename, "w") as out_file:
        out_file.createDimension("x", neighbor_prob.shape[0])
        out_file.createDimension("y", neighbor_prob.shape[1])
        out_file.createVariable("Longitude", "f4", ("x", "y"))
        out_file.createVariable("Latitude", "f4", ("x", "y"))
        out_file.createVariable("Data", "f4", ("x", "y"))
        out_file.variables["Longitude"][:, :] = data_lon
        out_file.variables["Latitude"][:, :] = data_lat
        out_file.variables["Data"][:, :] = neighbor_prob
        # Projection parameters are stored as global attributes.
        out_file.projection = proj_dict["proj"]
        out_file.lon_0 = proj_dict["lon_0"]
        out_file.lat_0 = proj_dict["lat_0"]
        out_file.lat_1 = proj_dict["lat_1"]
        out_file.lat_2 = proj_dict["lat_2"]

    print("Writing to " + out_filename)

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
