#!/usr/bin/env python
import numpy as np
from hagelslag.processing import extract_storm_patches, label_storm_objects
import pygrib
from multiprocessing import Pool
import argparse
from os.path import join, exists
import pandas as pd
import traceback
from scipy.ndimage import gaussian_filter
import pkg_resources
from netCDF4 import Dataset, date2num


def main():
    """
    Parse the command-line options and extract storm patches for every run date and ensemble member.

    One task per (run date, member) pair is submitted to a multiprocessing pool. Each task calls
    extract_ncar_member_storms, which writes one netCDF file of storm patches. Ensemble members
    1 through 10 are always processed.

    The results of the asynchronous tasks are not collected, so a failure in one task does not stop
    the others; the worker function prints its own traceback before re-raising.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-s", "--start_date", help="Start Run Date")
    parser.add_argument("-e", "--end_date", help="End Run Date")
    parser.add_argument("-t", "--start_hour", type=int, default=1, help="Start Forecast Hour")
    parser.add_argument("-n", "--end_hour", type=int, default=48, help="End Forecast Hour")
    parser.add_argument("-l", "--label_var", type=int, help="Variable Number used for storm labeling")
    parser.add_argument("-v", "--storm_vars", help="Comma-separated list of storm variable numbers")
    parser.add_argument("-a", "--env_vars", help="Comma-separated list of environment variable numbers")
    parser.add_argument("-i", "--in_path", help="Path to NCAR Ensemble GRIB2 files.")
    parser.add_argument("-o", "--out_path", help="Path to output netCDF files.")
    parser.add_argument("-r", "--patch_radius", type=int, default=32,
                        help="Radius of the patch being extracted around each storm centroid.")
    parser.add_argument("-f", "--finder_method", default="ew",
                        help="Method used to find objects. Options: 'ew' or 'hyst'.")
    parser.add_argument("--min_intensity", default=10, type=int, help="Minimum object intensity")
    parser.add_argument("--max_intensity", default=100, type=int, help="Maximum object intensity")
    parser.add_argument("--min_area", default=10, type=int, help="Minimum area")
    parser.add_argument("--max_area", default=100, type=int, help="Maximum area")
    parser.add_argument("--max_range", default=100, type=int, help="Maximum range")
    parser.add_argument("--increment", default=1, type=int, help="Quantization Increment")
    parser.add_argument("--smoother_sigma", default=1, type=int, help="Gaussian smoother standard deviation")
    parser.add_argument("-p", '--proc', type=int, default=1, help="Number of processors")
    args = parser.parse_args()
    # pd.date_range is the supported way to build a daily index; the
    # DatetimeIndex(start=..., end=..., freq=...) constructor form was removed in pandas 1.0.
    run_dates = pd.date_range(start=args.start_date, end=args.end_date, freq="1D")
    members = np.arange(1, 11)
    storm_vars = np.array(args.storm_vars.split(","), dtype=int)
    print(storm_vars)
    env_vars = np.array(args.env_vars.split(","), dtype=int)
    print(env_vars)
    # Start the worker processes only after the inputs parsed cleanly, so that a bad
    # argument does not leave idle workers behind.
    pool = Pool(args.proc)
    for run_date in run_dates:
        for member in members:
            proc_args = (run_date, member, args.start_hour, args.end_hour, args.label_var, storm_vars,
                         env_vars, args.in_path, args.out_path, args.patch_radius, args.finder_method,
                         args.min_intensity, args.max_intensity, args.min_area, args.max_area,
                         args.max_range, args.increment, args.smoother_sigma)
            pool.apply_async(extract_ncar_member_storms, proc_args)
    pool.close()
    pool.join()


def extract_ncar_member_storms(run_date, member, start_hour, end_hour, label_var, storm_vars, env_vars,
                               data_path, out_path, patch_radius, object_finder_method, min_intensity, max_intensity,
                               min_area, max_area, max_range, increment, smoother_sigma):
    """
    Extract storm patches from one NCAR ensemble member run and save them to a netCDF file.

    For each forecast hour, the labeling field is read from the member's GRIB2 file, masked, smoothed,
    and passed to label_storm_objects. A patch is extracted around each storm object, and the requested
    storm variables (same forecast hour) and environment variables (previous forecast hour) are attached
    to every patch. All patches from the run are then written to
    ncar_ens_storm_patches_<YYYYMMDDHH>_mem_<MM>.nc inside out_path.

    Args:
        run_date: Model initialization time. Must provide strftime and to_pydatetime and support adding
            a pd.Timedelta (e.g. a pd.Timestamp).
        member: Integer ensemble member number used in the input and output file names.
        start_hour: First forecast hour to process (inclusive).
        end_hour: Last forecast hour to process (inclusive).
        label_var: GRIB message number of the field used to identify storm objects.
        storm_vars: Iterable of GRIB message numbers read from the same forecast hour as the storm.
        env_vars: Iterable of GRIB message numbers read from the previous forecast hour's file.
        data_path: Root directory of the NCAR ensemble GRIB2 files.
        out_path: Directory in which the netCDF file is written.
        patch_radius: Half-width of each patch in grid points; also the width of the border that is
            zeroed in the labeling field.
        object_finder_method: Object finding method passed to label_storm_objects.
        min_intensity: Minimum intensity passed to label_storm_objects.
        max_intensity: Maximum intensity passed to label_storm_objects.
        min_area: Minimum object area passed to label_storm_objects.
        max_area: Maximum object area passed to label_storm_objects.
        max_range: Maximum range passed to label_storm_objects.
        increment: Quantization increment passed to label_storm_objects.
        smoother_sigma: Standard deviation of the Gaussian filter applied to the labeling field.

    Returns:
        None. The result is the netCDF file written to out_path.

    Raises:
        OSError: If the GRIB2 file for a requested forecast hour does not exist.
        Exception: Any other error is printed with its traceback and re-raised.
    """
    try:
        run_day_str = run_date.strftime("%Y-%m-%d")
        run_stamp = run_date.strftime("%Y%m%d%H")
        print("Starting {0} {1} at {2}".format(run_day_str, member,
                                               pd.Timestamp("now").strftime("%Y-%m-%d %H:%M:%S")))
        forecast_hours = np.arange(start_hour, end_hour + 1)
        all_storm_patches = []
        all_forecast_hours = []
        all_valid_times = []
        grib_info = load_ncar_grib_info()
        us_mask = load_ncar_us_mask()
        member_path = join(data_path, run_stamp, "post", "mem_{0:d}".format(member))
        grib_name_format = "ncar_3km_{0}_mem{1:d}_f{2:03d}.grb2"
        for forecast_hour in forecast_hours:
            ncar_grib_filename = join(member_path, grib_name_format.format(run_stamp, member, forecast_hour))
            # Environment variables are taken from the file of the previous forecast hour.
            ncar_env_filename = join(member_path, grib_name_format.format(run_stamp, member, forecast_hour - 1))
            if not exists(ncar_grib_filename):
                raise OSError(ncar_grib_filename + " not found")
            ncar_grib_file = pygrib.open(ncar_grib_filename)
            try:
                label_message = ncar_grib_file[label_var]
                storm_values = label_message.values
                # Zero a border patch_radius wide on every side of the labeling field
                # (presumably so that extracted patches never run off the grid) and
                # apply the US mask.
                storm_values[:patch_radius] = 0
                storm_values[:, :patch_radius] = 0
                storm_values[-patch_radius:] = 0
                storm_values[:, -patch_radius:] = 0
                storm_values *= us_mask
                print("Extracting storm objects Run: {0} Member: {1} FH: {2}".format(run_day_str,
                                                                                     member, forecast_hour))
                storm_label_grid = label_storm_objects(gaussian_filter(storm_values, smoother_sigma),
                                                       object_finder_method,
                                                       min_intensity, max_intensity,
                                                       min_area=min_area, max_area=max_area, max_range=max_range,
                                                       increment=increment)
                j_grid, i_grid = np.meshgrid(np.arange(storm_values.shape[1]), np.arange(storm_values.shape[0]))
                # Grid indices are scaled by 3 to match the dx=3 passed to extract_storm_patches.
                x_grid = j_grid * 3
                y_grid = i_grid * 3
                storm_patches = extract_storm_patches(storm_label_grid, storm_values, x_grid, y_grid, [forecast_hour],
                                                      dx=3, dt=1, patch_radius=patch_radius)[0]
                print("Found {3} storm objects Run: {0} Member: {1} FH: {2}".format(run_day_str,
                                                                                    member, forecast_hour,
                                                                                    len(storm_patches)))
                if len(storm_patches) > 0:
                    # The coordinate grids are only needed when there is at least one patch.
                    lats, lons = label_message.latlons()
                    all_valid_times.extend([run_date + pd.Timedelta(hours=forecast_hour)] * len(storm_patches))
                    all_forecast_hours.extend([forecast_hour] * len(storm_patches))
                    for storm_patch in storm_patches:
                        storm_patch.extract_attribute_array(lats, "latitude")
                        storm_patch.extract_attribute_array(lons, "longitude")

                    print("Extracting storm vars Run: {0} Member: {1} FH: {2}".format(run_day_str,
                                                                                      member, forecast_hour))
                    for storm_var in storm_vars:
                        storm_var_vals = ncar_grib_file[int(storm_var)].values
                        for storm_patch in storm_patches:
                            storm_patch.extract_attribute_array(storm_var_vals, str(storm_var) + "_current")
                    print("Extracting env vars Run: {0} Member: {1} FH: {2}".format(run_day_str,
                                                                                    member, forecast_hour))
                    ncar_env_file = pygrib.open(ncar_env_filename)
                    try:
                        for env_var in env_vars:
                            env_var_vals = ncar_env_file[int(env_var)].values
                            for storm_patch in storm_patches:
                                storm_patch.extract_attribute_array(env_var_vals, str(env_var) + "_prev")
                    finally:
                        ncar_env_file.close()
                    all_storm_patches.extend(storm_patches)
            finally:
                # Release the GRIB file handle even if extraction fails part-way.
                ncar_grib_file.close()
        print("Saving storm patches {0} {1}".format(run_day_str, member))
        # join() keeps the file inside out_path whether or not out_path ends with a path separator.
        out_file_name = join(out_path, "ncar_ens_storm_patches_{1}_mem_{0:02d}.nc".format(member, run_stamp))
        out_file = Dataset(out_file_name, mode="w", format="NETCDF4")
        try:
            num_patches = len(all_storm_patches)
            out_file.createDimension("p", size=num_patches)
            out_file.createDimension("y", size=patch_radius * 2)
            out_file.createDimension("x", size=patch_radius * 2)
            out_file.Conventions = "CF-1.6"
            out_file.title = "NCAR Ensemble Storm Patches for run {0}".format(run_stamp)
            out_file.institution = "National Center for Atmospheric Research"
            lon_var = out_file.createVariable("longitude", "f4", dimensions=("p", "y", "x"), zlib=True)
            lon_var.long_name = "longitude"
            lon_var.units = "degrees_east"
            lat_var = out_file.createVariable("latitude", "f4", dimensions=("p", "y", "x"), zlib=True)
            lat_var.long_name = "latitude"
            lat_var.units = "degrees_north"
            p_var = out_file.createVariable("p", "u4", dimensions=("p",))
            p_var.long_name = "Storm Patch Number"
            p_var.units = ""
            p_var[:] = np.arange(num_patches)
            y_var = out_file.createVariable("y", "f4", dimensions=("y",))
            y_var.long_name = "y-coordinate in patch space"
            y_var.units = "km"
            y_var[:] = np.arange(0, patch_radius * 2 * 3, 3)
            x_var = out_file.createVariable("x", "f4", dimensions=("x",))
            x_var.long_name = "x-coordinate in patch space"
            x_var.units = "km"
            x_var[:] = np.arange(0, patch_radius * 2 * 3, 3)
            row_var = out_file.createVariable("row", "i4", dimensions=("p", "y", "x"), zlib=True)
            row_var.long_name = "NCAR Ensemble row numbers"
            row_var.units = ""
            col_var = out_file.createVariable("column", "i4", dimensions=("p", "y", "x"), zlib=True)
            col_var.long_name = "NCAR Ensemble column numbers"
            col_var.units = ""
            fh_var = out_file.createVariable("forecast_hour", "i2", dimensions=("p",))
            fh_var.long_name = "Hours since forecast was initialized"
            fh_var.units = "hours since {0}".format(run_date.strftime("%Y-%m-%d %H:%M:%S"))
            fh_var[:] = all_forecast_hours
            valid_date_var = out_file.createVariable("valid_date", "i4", dimensions=("p",))
            valid_date_var.long_name = "Time when forecast was valid"
            valid_date_var.units = "hours since 2015-01-01 00:00:00"
            valid_date_var[:] = date2num(pd.DatetimeIndex(all_valid_times).to_pydatetime(),
                                         units=valid_date_var.units)
            run_date_var = out_file.createVariable("run_date", "i4", dimensions=("p",))
            run_date_var.long_name = "Time when model run was initialized"
            run_date_var.units = "hours since 2015-01-01 00:00:00"
            run_date_var[:] = date2num([run_date.to_pydatetime()] * num_patches, units=run_date_var.units)
            mask_var = out_file.createVariable("mask", "u1", dimensions=("p", "y", "x"))
            mask_var.long_name = "Storm Mask (1=Storm, 0=Background)"
            mask_var.units = ""
            nc_label_var = out_file.createVariable(grib_info.loc[label_var, "var_name"], "f4",
                                                   dimensions=("p", "y", "x"), zlib=True)
            nc_label_var.long_name = grib_info.loc[label_var, "long_name"]
            nc_label_var.units = grib_info.loc[label_var, "units"]
            # Each patch holds a single time step, so index 0 selects its only 2D field.
            for s, storm_patch in enumerate(all_storm_patches):
                nc_label_var[s] = storm_patch.timesteps[0]
                mask_var[s] = storm_patch.masks[0]
                lon_var[s] = storm_patch.attributes["longitude"][0]
                lat_var[s] = storm_patch.attributes["latitude"][0]
                row_var[s] = storm_patch.i[0]
                col_var[s] = storm_patch.j[0]
            for storm_var in storm_vars:
                nc_storm_var = out_file.createVariable(grib_info.loc[storm_var, "var_name"] + "_current",
                                                       "f4", dimensions=("p", "y", "x"), zlib=True)
                nc_storm_var.long_name = grib_info.loc[storm_var, "long_name"] + " from current hour"
                nc_storm_var.units = grib_info.loc[storm_var, "units"]
                nc_storm_var.coordinates = "p latitude longitude"
                for s, storm_patch in enumerate(all_storm_patches):
                    nc_storm_var[s] = storm_patch.attributes[str(storm_var) + "_current"][0]
            for env_var in env_vars:
                nc_env_var = out_file.createVariable(grib_info.loc[env_var, "var_name"] + "_prev",
                                                     "f4", dimensions=("p", "y", "x"), zlib=True)
                nc_env_var.long_name = grib_info.loc[env_var, "long_name"] + " from previous hour"
                nc_env_var.units = grib_info.loc[env_var, "units"]
                nc_env_var.coordinates = "p latitude longitude"
                for s, storm_patch in enumerate(all_storm_patches):
                    nc_env_var[s] = storm_patch.attributes[str(env_var) + "_prev"][0]
        finally:
            # Always close the output file so the handle is not leaked if a write fails.
            out_file.close()
        print("Completed {0} {1} at {2}".format(run_day_str, member,
                                                pd.Timestamp("now").strftime("%Y-%m-%d %H:%M:%S")))
    except Exception:
        # The caller in this file discards the async result, so print the traceback here
        # before propagating the error.
        print(traceback.format_exc())
        raise


def load_ncar_grib_info():
    """
    Load the NCAR ensemble GRIB table (wgrib2 output) packaged with hagelslag and split it into
    descriptive columns.

    The table is colon-separated with no header. Its first field becomes the index (presumably the
    GRIB message number, since callers look rows up by variable number), field 3 holds the variable
    description, and field 5 holds the level.

    Returns:
        pandas DataFrame with the columns short_name, long_name, units, level, and var_name, where
        var_name is the lower-cased long name and level joined by underscores.
    """
    table_path = pkg_resources.resource_filename(pkg_resources.Requirement.parse("hagelslag"),
                                                 "mapfiles/ncar_grib_table.txt")
    inventory = pd.read_table(table_path, sep=":", header=None, index_col=0)
    # Whitespace-split description: first token is the short name, last token is the
    # bracketed units, and everything in between is the long name.
    description_tokens = inventory[3].str.split()
    column_names = ["short_name", "long_name", "units", "level", "var_name"]
    grib_data = pd.DataFrame(index=inventory.index, columns=column_names, dtype=str)
    grib_data["short_name"] = description_tokens.str[0]
    grib_data["long_name"] = description_tokens.str[1:-1].str.join(" ")
    grib_data["units"] = description_tokens.str[-1].str.strip("[]")
    grib_data["level"] = inventory[5]
    long_name_part = grib_data["long_name"].str.replace(" ", "_")
    level_part = grib_data["level"].str.replace(" ", "_")
    grib_data["var_name"] = long_name_part.str.cat(level_part, sep="_").str.lower()
    return grib_data


def load_ncar_us_mask():
    """
    Load the US mask grid for the NCAR ensemble domain that is packaged with hagelslag.

    Returns:
        Array read from the "usa_mask" variable of mapfiles/ncar_2015_us_mask.nc.
    """
    us_mask_filename = pkg_resources.resource_filename(pkg_resources.Requirement.parse("hagelslag"),
                                                       "mapfiles/ncar_2015_us_mask.nc")
    us_mask_file = Dataset(us_mask_filename)
    try:
        # Slicing with [:] reads the data into memory, so it stays valid after the file is closed.
        us_mask = us_mask_file.variables["usa_mask"][:]
    finally:
        # Close the file even if the variable is missing or the read fails.
        us_mask_file.close()
    return us_mask

# Run the command-line entry point only when this file is executed as a script, so that
# importing the module (as multiprocessing workers may do) does not start a new extraction.
if __name__ == "__main__":
    main()
