#!/usr/bin/env python
from netCDF4 import Dataset
import numpy as np
import argparse
from hagelslag.data.WRFModelGrid import WRFModelGrid
from hagelslag.processing import Hysteresis
from scipy.ndimage import find_objects, gaussian_filter
from glob import glob
from datetime import datetime
import traceback
from mpi4py import MPI
import os.path
def main():
    """Parse the command line and extract storm objects from every wrfout file.

    The wrfout files matching ``<path>wrfout_<dom>_*`` are sorted and dealt out
    round-robin across the MPI ranks of ``COMM_WORLD``; each rank calls
    ``extract_wrf_timestep`` on its own share of the files.

    Raises:
        ValueError: if the text following ``wrfout_<dom>_`` in a file name
            does not parse as ``%Y-%m-%d_%H:%M:%S``.
    """
    wrf_vars = ["TEMP", "XLAT", "XLONG", "T2", "Q2", "U10", "V10", "PSFC", "U", "V", "W", "GPH",
                "PRES", "HGT", "REFL_10CM",
                "QVAPOR", "QRAIN", "QICE", "QSNOW", "QGRAUP", "QHAIL", "QCLOUD",
                "QNRAIN", "QNICE", "QNSNOW", "QNGRAUPEL", "QNHAIL", "QNCLOUD",
                "QICE2", "QNICE2", "QVGRAUPEL", "QNGRAUPEL2", "QVGRAUPEL2", "V_ICE",
                "D_ICE", "RHO_ICE"]
    parser = argparse.ArgumentParser()
    # Without a path the glob pattern cannot be built, so fail at parse time
    # with a usage message instead of a TypeError on ``None + str``.
    parser.add_argument("-p", "--path", required=True, help="Path to wrfout files")
    parser.add_argument("-o", "--out", help="Path where WRF object netCDF files are written")
    parser.add_argument("-v", "--var", default="REFL_10CM", help="Variable for object finding")
    parser.add_argument("-u", "--up", default=45, type=float, help="Upper threshold for hysteresis")
    parser.add_argument("-l", "--low", default=20, type=float, help="Lower threshold for hysteresis")
    parser.add_argument("-s", "--smooth", default=1, type=int, help="Smoothing radius")
    parser.add_argument("-a", "--area", default=16, type=int, help="Minimum object size")
    parser.add_argument("-d", "--dom", default="d01", help="WRF domain")
    # NOTE: --nprocs is accepted but never read; the work split below is
    # driven by the MPI communicator size.
    parser.add_argument("-n", "--nprocs", default=1, type=int, help="Number of processors")
    parser.add_argument("-e", "--ext", default=",".join(wrf_vars),
                        help="Variables to be extracted from WRF as comma separated list")
    args = parser.parse_args()

    file_prefix = "wrfout_{0}_".format(args.dom)
    # args.path is used as a plain string prefix, so it is expected to end
    # with a path separator.
    wrf_paths = sorted(glob(args.path + file_prefix + "*"))
    # Drop blank entries (e.g. from a trailing comma) so they are not treated
    # as variable names.
    wrf_vars = [name.strip() for name in args.ext.split(",") if name.strip()]

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    # Round-robin split: rank r handles files r, r + size, r + 2 * size, ...
    for wrf_path in wrf_paths[rank::size]:
        # The timestamp is whatever follows the "wrfout_<dom>_" prefix in the
        # file name, independent of the directory and of the domain length.
        date_str = os.path.basename(wrf_path).partition(file_prefix)[2]
        valid_date = datetime.strptime(date_str, "%Y-%m-%d_%H:%M:%S")
        print("Proc {0:d} of {1:d} loading {2}".format(int(rank), int(size), valid_date))
        extract_wrf_timestep(valid_date, args.var, args.up, args.low, args.smooth, args.area,
                             args.dom, args.path, wrf_vars, args.out)


def extract_wrf_timestep(valid_date, variable, upper_threshold, lower_threshold, smoothing, min_size,
                         domain, wrf_path, wrf_vars, out_path):
    """Find objects at one valid time and, if any exist, write their boxes.

    Calls ``find_wrf_objects`` with the object-finding settings and, when at
    least one object slice comes back, hands the result to
    ``extract_wrf_boxes`` together with the list of variables to extract.
    Any exception has its traceback printed before being re-raised.
    """
    try:
        found = find_wrf_objects(valid_date, variable, upper_threshold, lower_threshold,
                                 smoothing, min_size, domain, wrf_path)
        object_slices, grid_shape, global_attributes, time_var, time_attrs = found
        if len(object_slices) == 0:
            # Nothing was detected at this time, so no files are written.
            return
        extract_wrf_boxes(object_slices, grid_shape, global_attributes, valid_date,
                          wrf_vars, domain, wrf_path, out_path, time_var, time_attrs)
    except Exception as err:
        # Print the traceback here as well as propagating the error.
        print(traceback.format_exc())
        raise err


def find_wrf_objects(valid_date, variable, upper_threshold, lower_threshold, smoothing, min_size, domain, path):
    """Locate objects in one WRF variable at one valid time.

    The variable is loaded through ``WRFModelGrid`` and reduced to a 2-D
    field: a 4-D grid takes its first time index and the maximum over the
    next axis, a 3-D grid takes its first time index. The field is optionally
    Gaussian-smoothed, labelled with ``Hysteresis(lower_threshold,
    upper_threshold)``, size-filtered with ``min_size`` and converted to
    bounding slices with ``scipy.ndimage.find_objects``.

    Args:
        valid_date: datetime of the model output to load.
        variable: name of the variable used for object finding.
        upper_threshold: upper hysteresis threshold.
        lower_threshold: lower hysteresis threshold.
        smoothing: Gaussian filter sigma; values <= 0 disable smoothing.
        min_size: minimum size passed to ``Hysteresis.size_filter``.
        domain: WRF domain string passed to ``WRFModelGrid``.
        path: location of the wrfout files passed to ``WRFModelGrid``.

    Returns:
        Tuple ``(object_slices, grid_shape, global_attributes, time_var,
        time_attrs)``. ``grid_shape`` is the shape of the variable as loaded,
        before any reduction.

    Raises:
        ValueError: if the variable could not be loaded (grid is ``None``).
    """
    date_label = valid_date.strftime("%Y%m%d %H:%M")
    print("Loading " + variable + " " + date_label)
    wrf_grid = WRFModelGrid(valid_date, variable, domain, path)
    var_data, var_attrs = wrf_grid.load_full_grid()
    if var_data is None:
        # Fail with a readable message rather than an AttributeError below.
        raise ValueError("Variable {0} could not be loaded at {1}".format(variable, date_label))
    time_var, time_attrs = wrf_grid.load_time_var()
    global_attributes = wrf_grid.get_global_attributes()
    # NOTE(review): extract_wrf_boxes reads grid_shape[1] as the vertical
    # level count, which only holds when the variable is 4-D.
    grid_shape = var_data.shape
    if len(var_data.shape) == 4:
        var_data = var_data[0].max(axis=0)
    elif len(var_data.shape) == 3:
        # Drop the leading time axis so the slices returned below are
        # (south_north, west_east) pairs, as extract_wrf_boxes indexes them.
        var_data = var_data[0]
    print("Finding objects at " + date_label)
    hyst = Hysteresis(lower_threshold, upper_threshold)
    if smoothing > 0:
        label_grid = hyst.label(gaussian_filter(var_data, smoothing, mode="constant"))
    else:
        label_grid = hyst.label(var_data)
    filtered_grid = hyst.size_filter(label_grid, min_size)
    object_slices = find_objects(filtered_grid, filtered_grid.max())
    print("Found {0:d} objects at ".format(len(object_slices)) + date_label)
    return object_slices, grid_shape, global_attributes, time_var, time_attrs


def _wrf_object_filename(out_path, domain, valid_date, index):
    """Return the output netCDF file name for object number ``index``."""
    # out_path is a plain string prefix, so it is expected to end with a
    # path separator.
    return out_path + "wrfobjects_{0}_{1}_{2:04d}.nc".format(domain,
                                                             valid_date.strftime("%Y-%m-%d_%H%M%S"),
                                                             index)


def _load_wrf_variable(wrf_file_obj, wrf_var):
    """Load ``wrf_var`` as ``(grid, attrs)``, deriving TEMP, GPH and PRES from raw fields."""
    if wrf_var == "TEMP":
        wrf_file_obj.variable = "P"
        pres = wrf_file_obj.load_full_grid()[0]
        wrf_file_obj.variable = "PB"
        pres += wrf_file_obj.load_full_grid()[0]
        wrf_file_obj.variable = "T"
        temp, var_attrs = wrf_file_obj.load_full_grid()
        var_attrs["description"] = "Temperature"
        # (T + 300) scaled by (p / 100000) ** (287 / 1005); presumably the
        # conversion from perturbation potential temperature to temperature.
        grid = (temp + 300.0) * (pres / 100000.0) ** (287.0 / 1005.0)
        return grid, var_attrs
    if wrf_var == "GPH":
        wrf_file_obj.variable = "PH"
        grid, var_attrs = wrf_file_obj.load_full_grid()
        wrf_file_obj.variable = "PHB"
        grid += wrf_file_obj.load_full_grid()[0]
        # (PH + PHB) / 9.81, reported in metres.
        grid /= 9.81
        var_attrs["description"] = "Geopotential height"
        var_attrs["units"] = "m"
        return grid, var_attrs
    if wrf_var == "PRES":
        wrf_file_obj.variable = "P"
        grid, var_attrs = wrf_file_obj.load_full_grid()
        wrf_file_obj.variable = "PB"
        grid += wrf_file_obj.load_full_grid()[0]
        # (P + PB) * 0.01, reported in hPa.
        grid *= 0.01
        var_attrs["description"] = "Pressure"
        var_attrs["units"] = "hPa"
        return grid, var_attrs
    wrf_file_obj.variable = wrf_var
    return wrf_file_obj.load_full_grid()


def extract_wrf_boxes(object_slices, grid_shape, global_attributes, valid_date,
                      wrf_vars, domain, wrf_path, out_path, time_var, time_attrs):
    """Write one netCDF file per object holding every requested variable.

    A first pass creates each object file with its dimensions, box-extent
    attributes, time variables and x/y coordinates. A second pass loads each
    variable once for the whole grid and appends the sub-box for every object
    to that object's file, so only one full grid is held at a time.

    Args:
        object_slices: sequence of ``(south_north, west_east)`` slice pairs.
        grid_shape: shape of the object-finding variable; index 1 is used as
            the size of the ``bottom_top`` dimension.
        global_attributes: mapping copied onto every output file; ``DX`` (if
            present) sets the coordinate spacing, otherwise 3000.0 is used.
        valid_date: datetime of the model output.
        wrf_vars: names of the variables to extract. ``TEMP``, ``GPH`` and
            ``PRES`` are derived from raw fields.
        domain: WRF domain string.
        wrf_path: location of the wrfout files passed to ``WRFModelGrid``.
        out_path: string prefix (directory with trailing separator) for output.
        time_var: time value(s); written multiplied by 60.
        time_attrs: mapping of attributes copied onto the time variables.
    """
    out_filenames = [_wrf_object_filename(out_path, domain, valid_date, index)
                     for index in range(len(object_slices))]
    if "DX" in global_attributes.keys():
        dx = global_attributes["DX"]
    else:
        dx = 3000.0
    # NOTE(review): the y coordinate also uses DX; confirm DY never differs.

    for out_filename, object_slice in zip(out_filenames, object_slices):
        rows, cols = object_slice[0], object_slice[1]
        # The context manager closes the file even if a write below fails.
        with Dataset(out_filename, mode="w") as out_data:
            for attribute in global_attributes.keys():
                setattr(out_data, attribute, global_attributes[attribute])
            setattr(out_data, "Conventions", "CF-1.6")
            out_data.createDimension("time", 1)
            out_data.createDimension("south_north", rows.stop - rows.start)
            out_data.createDimension("west_east", cols.stop - cols.start)
            out_data.createDimension("bottom_top", grid_shape[1])

            setattr(out_data, "SOUTH_NORTH_BOX_START_C", rows.start)
            setattr(out_data, "SOUTH_NORTH_BOX_END_C", rows.stop)
            setattr(out_data, "WEST_EAST_BOX_START_C", cols.start)
            setattr(out_data, "WEST_EAST_BOX_END_C", cols.stop)

            for t_var_name in ["XTIME", "time"]:
                tv = out_data.createVariable(t_var_name, "i4", ("time",))
                # Values are scaled by 60 and the units text is switched from
                # "minutes" to "seconds" to match.
                tv[:] = time_var * 60
                for k, v in time_attrs.items():
                    setattr(tv, k, v)
                tv.long_name = "time"
                tv.standard_name = "time"
                if "units" in time_attrs:
                    tv.units = tv.units.replace("minutes", "seconds")

            # Build coordinates from integer indices: np.arange with a float
            # step can return one element too many or too few, which would not
            # match the dimension length.
            x_var = out_data.createVariable("x", "f4", ("west_east",))
            x_var[:] = np.arange(cols.start, cols.stop) * dx
            x_var.units = "m"
            x_var.long_name = "x"
            x_var.standard_name = "x"
            x_var.axis = "X"
            y_var = out_data.createVariable("y", "f4", ("south_north",))
            y_var[:] = np.arange(rows.start, rows.stop) * dx
            y_var.units = "m"
            y_var.long_name = "y"
            y_var.standard_name = "y"
            y_var.axis = "Y"

    if not wrf_vars:
        # Nothing to extract; avoid indexing an empty variable list below.
        return

    wrf_file_obj = WRFModelGrid(valid_date, wrf_vars[0], domain, wrf_path)
    for wrf_var in wrf_vars:
        print("Extracting {0} at {1}".format(wrf_var, valid_date.strftime("%Y-%m-%d_%H:%M:%S")))
        wrf_var_grid, var_attrs = _load_wrf_variable(wrf_file_obj, wrf_var)
        if wrf_var_grid is None:
            # The variable could not be loaded; skip it as before.
            continue
        if "description" in var_attrs:
            var_attrs["long_name"] = var_attrs["description"]
        for out_filename, object_slice in zip(out_filenames, object_slices):
            rows, cols = object_slice[0], object_slice[1]
            with Dataset(out_filename, mode="a") as out_data:
                if wrf_var == "GPH":
                    # Geopotential height is also stored as the "z" coordinate.
                    z_var = out_data.createVariable("z", "f4", ("bottom_top", "south_north", "west_east"), zlib=True)
                    z_var[:] = wrf_var_grid[0, :, rows, cols]
                    z_var.long_name = "z"
                    z_var.standard_name = "z"
                    z_var.units = "m"
                    z_var.axis = "Z"
                    z_var.positive = "up"
                if len(wrf_var_grid.shape) == 4:
                    var_dims = ("time", "bottom_top", "south_north", "west_east")
                    var_data = wrf_var_grid[:, :, rows, cols]
                else:
                    var_dims = ("time", "south_north", "west_east")
                    var_data = wrf_var_grid[:, rows, cols]
                var = out_data.createVariable(wrf_var, "f4", var_dims, zlib=True)
                var[:] = var_data
                for k, v in var_attrs.items():
                    setattr(var, k, v)
        # Release the full grid before the next variable is loaded.
        del wrf_var_grid
    return


if __name__ == "__main__":
    # Resolve and report the directory this script lives in, then run the extraction.
    script_location = os.path.abspath(__file__)
    package_directory = os.path.split(script_location)[0]
    print(package_directory)
    main()
