#!/usr/bin/env python
import argparse
import traceback
from hagelslag.util.Config import Config
import pandas as pd
import numpy as np
import os
from glob import glob
from multiprocessing import Pool
from hagelslag.evaluation.ObjectEvaluator import ObjectEvaluator
from hagelslag.evaluation.GridEvaluator import GridEvaluator
from hagelslag.evaluation.NeighborEvaluator import NeighborEvaluator
from hagelslag.data.MRMSGrid import MRMSGrid
from hagelslag.data.ModelOutput import ModelOutput
from hagelslag.data.HailForecastGrid import HailForecastGrid
from datetime import datetime, timedelta
from netCDF4 import Dataset
from os.path import isdir
# Number of score sets already written by the append_scores callback inside evaluate_objects.
# While it is zero, the next write creates the object score CSV files with a header row.
score_counter = 0


def main():
    """
    Command-line entry point.

    Parses the command-line arguments, loads the configuration file, and runs every evaluation
    mode whose flag was given, in the order: object, grid, neighborhood, reduced neighborhood.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="Config file")
    # Boolean switches selecting which evaluations to run; registered in the same order as they are executed.
    mode_flags = (
        ("-o", "--obj", "Perform object-based evaluation."),
        ("-g", "--grid", "Perform grid-based evaluation."),
        ("-n", "--neighbor", "Perform neighborhood probability evaluation."),
        ("-r", "--reduced", "Perform reduced (coarse-grid) neighborhood probability evaluation."),
    )
    for short_flag, long_flag, description in mode_flags:
        parser.add_argument(short_flag, long_flag, action="store_true", help=description)
    parser.add_argument("-p", "--proc", type=int, default=1, help="Number of processors.")
    args = parser.parse_args()

    # Attributes the config file must define before any evaluation is attempted.
    required = [
        "ensemble_name", "ensemble_members", "start_date", "end_date", "start_hour", "end_hour",
        "window_sizes", "time_skip", "model_names", "model_types", "size_thresholds",
        "forecast_json_path", "track_data_csv_path", "forecast_sample_path", "mrms_path",
        "mrms_variable", "obs_mask", "mask_variable", "forecast_thresholds", "forecast_bins",
        "out_path", "obj_scores_file", "grid_scores_file", "coordinate_file", "lon_bounds",
        "lat_bounds",
    ]
    print("loading config")
    config = Config(args.config, required_attributes=required)
    print("config loaded")
    print(args.reduced)
    num_procs = args.proc

    standard_modes = (
        (args.obj, evaluate_objects),
        (args.grid, evaluate_grids),
        (args.neighbor, evaluate_neighborhood_probabilities),
    )
    for requested, evaluator in standard_modes:
        if requested:
            evaluator(config, num_procs)
    if args.reduced:
        print("loading reduced")
        evaluate_reduced_neighborhood(config, num_procs)
    return


def evaluate_objects(config, num_procs):
    """
    Evaluates the individual object forecasts in parallel.

    One task is submitted to a process pool for every (run date, ensemble member) pair. Each time a
    task finishes, its scores are written to one CSV file per model type, named
    config.out_path + config.obj_scores_file + "<model_type>.csv". The first finished task creates
    the files with a header; later tasks append rows without a header.

    Args:
        config: Config object
        num_procs: Number of processors to use

    Returns:
        None
    """
    pool = Pool(num_procs)
    # pd.date_range is used because the DatetimeIndex(start=, end=, freq=) constructor form
    # is no longer accepted by current pandas releases.
    run_dates = pd.date_range(start=config.start_date, end=config.end_date, freq="1D")
    ensemble_members = config.ensemble_members
    score_columns = ["Run_Date", "Ensemble_Name", "Ensemble_Member", "Model_Name", "Model_Type", "Forecast_Hour"]

    def append_scores(score_dict):
        # Pool callback: receives the dict returned by evaluate_object_run.
        global score_counter
        first_write = score_counter == 0
        # dict.items() works on both Python 2 and Python 3, unlike iteritems().
        for model_type, scores in score_dict.items():
            scores.to_csv(config.out_path + config.obj_scores_file + "{0}.csv".format(model_type),
                          mode="w" if first_write else "a", index_label="Index", header=first_write)
        score_counter += 1

    for run_date in run_dates:
        for member in ensemble_members:
            pool.apply_async(evaluate_object_run, (run_date, member, config, score_columns),
                             callback=append_scores)
    pool.close()
    pool.join()


def evaluate_object_run(run_date, ensemble_member, config, score_columns):
    """
    Score the object-based forecasts from a single ensemble member run.

    ROC and reliability statistics are calculated for every model type, model name, and forecast hour.
    CRPS is added for every model type other than "condition", and a maximum-hail sample CRPS is
    added for the "dist" model type.

    Args:
        run_date (datetime.datetime): Date of the initial model run time
        ensemble_member (str): Name of the ensemble member
        config: Config object
        score_columns (list): Metadata columns for the verification DataFrame.

    Returns:
        A dictionary where the keys correspond to each type of model being evaluated (e.g. condition, dist),
        and the values are pandas DataFrames with the verification score values.

    Raises:
        Any exception raised during evaluation is re-raised after its traceback has been printed.
    """
    try:
        run_id = "{0} {1}".format(run_date, ensemble_member)
        print("Starting " + run_id)
        object_eval = ObjectEvaluator(run_date, config.ensemble_name, ensemble_member, config.model_names,
                                      config.model_types, config.forecast_bins, config.dist_thresholds,
                                      config.forecast_json_path, config.track_data_csv_path)
        # Each preparation step is announced and then executed, in this order.
        load_steps = (("Loading forecasts ", object_eval.load_forecasts),
                      ("Loading obs", object_eval.load_obs),
                      ("Merging obs", object_eval.merge_obs))
        for message, step in load_steps:
            print(message + run_id)
            step()

        # One empty score table per model type; the score columns depend on the model type.
        scores = {}
        for model_type in config.model_types:
            if model_type == "condition":
                extra_columns = ["ROC_0", "Rel_0"]
            else:
                threshold_labels = ["{0:d}".format(int(t)) for t in config.object_thresholds]
                extra_columns = ["ROC_" + label for label in threshold_labels]
                extra_columns += ["Rel_" + label for label in threshold_labels]
                extra_columns.append("CRPS")
                if model_type == "dist":
                    extra_columns.append("CRPS_Max")
            scores[model_type] = pd.DataFrame(columns=score_columns + extra_columns)

        forecast_hours = np.arange(config.start_hour, config.end_hour + 1, dtype=int)
        date_str = run_date.strftime("%Y%m%d")
        for model_type in config.model_types:
            is_condition = model_type == "condition"
            is_dist = model_type == "dist"
            for model_name in config.model_names[model_type]:
                for forecast_hour in forecast_hours:
                    print(run_date, ensemble_member, model_type, model_name, forecast_hour)
                    index = "{0}_{1}_{2}_{3}_{4}_{5:d}".format(date_str, config.ensemble_name,
                                                               ensemble_member, model_name, model_type,
                                                               forecast_hour)
                    print(index)
                    query = "(Forecast_Hour == {0:d})".format(forecast_hour)

                    # CRPS values go at the end of the row, after the ROC and reliability entries.
                    crps_values = []
                    if not is_condition:
                        crps_values.append(object_eval.crps(model_type, model_name,
                                                            config.model_names["condition"][0],
                                                            config.condition_threshold, query=query))
                    if is_dist:
                        forecast_max_hail = object_eval.sample_forecast_max_hail(
                            model_name, config.model_names["condition"][0], config.num_max_samples,
                            config.condition_threshold, query=query)
                        obs_max_hail = object_eval.sample_obs_max_hail(model_name, config.num_max_samples,
                                                                       query=query)
                        crps_values.append(object_eval.max_hail_sample_crps(forecast_max_hail, obs_max_hail))

                    # The condition model is scored at the single threshold 1.
                    thresholds = [1] if is_condition else config.object_thresholds
                    rocs = []
                    rels = []
                    for threshold in thresholds:
                        rocs.append(object_eval.roc(model_type, model_name, threshold,
                                                    config.forecast_thresholds, query=query))
                        rels.append(object_eval.reliability(model_type, model_name, threshold,
                                                            config.forecast_thresholds, query=query))

                    metadata = [run_date, config.ensemble_name, ensemble_member, model_name, model_type,
                                forecast_hour]
                    scores[model_type].loc[index] = metadata + rocs + rels + crps_values
        return scores
    except Exception as e:
        print(traceback.format_exc())
        raise e


def evaluate_grids(config, num_procs):
    """
    Evaluates the gridded forecasts in parallel.

    One task is submitted to a process pool for every (window size, run date, ensemble member)
    combination. Each finished task's scores are written to config.out_path + config.grid_scores_file;
    the first finished task creates the file with a header and later tasks append rows.

    Args:
        config: Config object
        num_procs: Number of processors to use

    Returns:
        pandas DataFrame with the scores of every task that returned a result. If no task returned
        a result, an empty DataFrame with the score columns is returned.
    """
    pool = Pool(num_procs)
    # pd.date_range is used because the DatetimeIndex(start=, end=, freq=) constructor form
    # is no longer accepted by current pandas releases.
    run_dates = pd.date_range(start=config.start_date, end=config.end_date, freq="1D")
    ensemble_members = config.ensemble_members
    score_columns = ["Run_Date", "Ensemble_Name", "Ensemble_Member", "Model_Name", "Size_Threshold", "Window_Size",
                     "Window_Start", "Window_End", "ROC", "Reliability"]
    score_list = []

    def append_scores(score_set):
        # Pool callback: receives the DataFrame returned by evaluate_grid_run.
        score_list.append(score_set)
        first_write = len(score_list) == 1
        score_set.to_csv(config.out_path + config.grid_scores_file, mode="w" if first_write else "a",
                         index_label="Index", header=first_write)

    for window_size in config.window_sizes:
        for run_date in run_dates:
            for member in ensemble_members:
                pool.apply_async(evaluate_grid_run, (run_date, member, window_size, config,
                                                     score_columns),
                                 callback=append_scores)
    pool.close()
    pool.join()
    if not score_list:
        # pd.concat raises ValueError on an empty list, so return an empty table instead.
        return pd.DataFrame(columns=score_columns)
    return pd.concat(score_list)


def evaluate_grid_run(run_date, member, window_size, config, score_columns):
    """
    Calculate grid-based verification statistics for a single ensemble member run and window size.

    A GridEvaluator is built from the config, forecasts are loaded and windowed, observations are
    loaded and dilated by config.dilation_radius, and ROC and reliability curves are calculated at
    config.forecast_thresholds.

    :param run_date: Date of the model run; must support strftime.
    :param member: Name of the ensemble member.
    :param window_size: Window size passed to GridEvaluator; formatted as hours in log messages.
    :param config: Config object.
    :param score_columns: Column names of the returned DataFrame.
    :return: pandas DataFrame with one row for every model name, size threshold, and hour window
        present in the ROC curve results.
    """
    try:
        date_str = run_date.strftime("%Y%m%d")
        # Shared tail of every progress message for this run.
        run_label = "{0} {1:02d} hours Run: {2}".format(member, window_size, date_str)
        print("Starting " + run_label)
        grid_eval = GridEvaluator(run_date, config.ensemble_name, member, config.model_names, config.size_thresholds,
                                  config.start_hour, config.end_hour, window_size, config.time_skip,
                                  config.forecast_sample_path, config.mrms_path, config.mrms_variable,
                                  config.obs_mask, config.mask_variable)
        print("Loading forecasts " + run_label)
        grid_eval.load_forecasts()
        print("Window forecasts " + run_label)
        grid_eval.get_window_forecasts()
        print("Load obs " + run_label)
        grid_eval.load_obs()
        grid_eval.dilate_obs(config.dilation_radius)
        print("Scoring ROC " + run_label)
        roc_curves = grid_eval.roc_curves(config.forecast_thresholds)
        print("Scoring Rel " + run_label)
        rel_curves = grid_eval.reliability_curves(config.forecast_thresholds)

        output_scores = pd.DataFrame(columns=score_columns)
        for ml_model_name in roc_curves.keys():
            model_rocs = roc_curves[ml_model_name]
            model_label = ml_model_name.replace(" ", "-")
            for size_threshold in model_rocs.keys():
                threshold_rocs = model_rocs[size_threshold]
                for hour_window in threshold_rocs.keys():
                    window_start = hour_window[0]
                    window_end = hour_window[1]
                    row = [run_date, config.ensemble_name, member, ml_model_name, size_threshold, window_size,
                           window_start, window_end, threshold_rocs[hour_window],
                           rel_curves[ml_model_name][size_threshold][hour_window]]
                    index = "{0}_{1}_{2}_{3}_{4:02d}_{5:02d}_{6:02d}".format(
                        date_str, config.ensemble_name, member, model_label, size_threshold, window_size,
                        window_start)
                    print(index)
                    output_scores.loc[index] = row
        return output_scores
    except Exception as e:
        print(traceback.format_exc())
        raise e


def evaluate_neighborhood_probabilities(config, num_procs):
    """
    Calculate evaluation statistics for a set of neighborhood probability forecasts.

    Every entry in config.neighbor_path is interpreted as a run date, and the *.nc files in the
    matching YYYYMMDD directory are inspected. File names are split on "_": the first component is
    the ensemble name, the second the model name, and the components between the second one and
    "consensus" form the forecast variable. Files whose variable is a key of
    config.neighbor_thresholds are evaluated in a process pool. Files without a "consensus"
    component are skipped with a message instead of aborting the whole evaluation.

    Only full-period scores are written, to config.neighbor_score_path + "period_scores.csv".
    Existing period_scores.csv and hour_scores.csv files are removed before evaluation starts.

    Args:
        config (hagelslag.util.Config object): Object containing configuration parameters as attributes.
        num_procs (int): Number of processors to use for the calculations
    """
    pool = Pool(num_procs)
    run_dates = pd.DatetimeIndex(sorted(os.listdir(config.neighbor_path)))
    print(run_dates)
    period_score_file = config.neighbor_score_path + "period_scores.csv"
    hour_score_file = config.neighbor_score_path + "hour_scores.csv"
    # Remove previous output so the first finished task writes the CSV header.
    for stale_file in (period_score_file, hour_score_file):
        if os.access(stale_file, os.R_OK):
            os.remove(stale_file)

    def save_scores(period_scores):
        # Pool callback: receives the DataFrame returned by evaluate_single_neighborhood.
        first_write = not os.access(period_score_file, os.R_OK)
        period_scores.to_csv(period_score_file, index_label="Index", mode="w" if first_write else "a",
                             header=first_write)

    for run_date in run_dates:
        forecast_files = glob(config.neighbor_path + "{0}/*.nc".format(run_date.strftime("%Y%m%d")))
        for forecast_file in forecast_files:
            file_comps = os.path.basename(forecast_file).split("_")
            if "consensus" not in file_comps:
                # list.index would raise ValueError here and stop every remaining file from being evaluated.
                print("Skipping " + forecast_file + ": no consensus component in file name")
                continue
            ensemble_name = file_comps[0]
            model_name = file_comps[1]
            forecast_variable = "_".join(file_comps[2:file_comps.index("consensus")])
            if forecast_variable in config.neighbor_thresholds:
                pool.apply_async(evaluate_single_neighborhood, (run_date, config.start_hour, config.end_hour,
                                                                ensemble_name,
                                                                model_name, forecast_variable, config.mrms_variable,
                                                                config.neighbor_radii, config.smoothing_radii,
                                                                config.obs_thresholds,
                                                                config.neighbor_thresholds[forecast_variable],
                                                                config.forecast_thresholds,
                                                                config.obs_mask, config.mask_variable,
                                                                config.neighbor_path, config.mrms_path,
                                                                config.coordinate_file, config.lon_bounds,
                                                                config.lat_bounds),
                                 callback=save_scores)
    pool.close()
    pool.join()


def evaluate_single_neighborhood(run_date, start_hour, end_hour, ensemble_name, model_name, forecast_variable,
                                 mrms_variable,
                                 neighbor_radii, smoothing_radii, obs_thresholds, size_thresholds, probability_levels,
                                 obs_mask, mask_variable, forecast_path, mrms_path, coordinate_file, lon_bounds,
                                 lat_bounds):
    """
    Calculate full-period verification statistics for the neighborhood ensemble probabilities
    from a single model run.

    Args:
        run_date (datetime.datetime): Date of the model run
        start_hour (int): First hour being evaluated
        end_hour (int): Last hour being evaluated (inclusive)
        ensemble_name (str): Name of the ensemble system
        model_name (str): Name of the numerical or machine learning model type
        forecast_variable (str): Name of the diagnostic variable or machine learning model
        mrms_variable (str): MRMS variable used for verification
        neighbor_radii (list): Radii in grid points for neighborhood convolution filter
        smoothing_radii (list): Radii of Gaussian smoothing filter
        obs_thresholds (list): Thresholds of the observed variable to test
        size_thresholds (list): Intensity thresholds for neighborhood probabilities
        probability_levels (list): Probability thresholds for the ROC Curve and Reliability diagram
        obs_mask (bool): True if grid mask is being used
        mask_variable (str): Name of the MRMS variable used for masking the grid
        forecast_path (str): Path to neighborhood probability netCDF files
        mrms_path (str): Path to MRMS netCDF files
        coordinate_file (str): Name of the file containing lat-lon coordinates
        lon_bounds (list): lower and upper bounds of longitude bounding box
        lat_bounds (list): lower and upper bounds of latitude bounding box

    Returns:
        The full-period scores returned by NeighborEvaluator.evaluate_period_forecasts.
        Hourly scores are not calculated.

    Raises:
        Any exception raised during evaluation is re-raised after its traceback has been printed.
    """
    try:
        evaluator = NeighborEvaluator(run_date, start_hour, end_hour, ensemble_name, model_name,
                                      forecast_variable, mrms_variable, neighbor_radii, smoothing_radii,
                                      obs_thresholds, size_thresholds, probability_levels, obs_mask,
                                      mask_variable, forecast_path, mrms_path, coordinate_file,
                                      lon_bounds, lat_bounds)
        # Forecasts, observations, and coordinates are all loaded before scoring.
        evaluator.load_forecasts()
        evaluator.load_obs()
        evaluator.load_coordinates()
        return evaluator.evaluate_period_forecasts()
    except Exception as e:
        print(traceback.format_exc())
        raise e


def evaluate_reduced_neighborhood(config, num_procs):
    """
    Run the reduced (coarse-grid) neighborhood evaluation for every run date in parallel.

    One evaluate_reduced_neighborhood_run task is submitted to a process pool for each daily run
    date between config.start_date and config.end_date. No callback is attached to the tasks and
    their return values are not collected.

    Args:
        config: Config object
        num_procs: Number of processors to use

    Returns:
        None
    """
    pool = Pool(num_procs)
    # pd.date_range is used because the DatetimeIndex(start=, end=, freq=) constructor form
    # is no longer accepted by current pandas releases.
    run_dates = pd.date_range(start=config.start_date, end=config.end_date, freq="1D")
    for run_date in run_dates.to_pydatetime():
        print(run_date)
        pool.apply_async(evaluate_reduced_neighborhood_run, (run_date, config.start_hour, config.end_hour,
                                                             config.ensemble_name, config.ensemble_members,
                                                             config.ensemble_variables, config.model_names["dist"],
                                                             "hail", config.mrms_variable, config.neighbor_thresholds,
                                                             config.neighbor_thresholds["dist"], config.obs_thresholds,
                                                             config.stride, config.neighbor_radius,
                                                             config.neighbor_sigma,
                                                             config.ensemble_path, config.ml_grid_path,
                                                             config.mrms_path,
                                                             config.coarse_neighbor_out_path, config.map_file,
                                                             config.us_mask_file,
                                                             config.single_step))
    pool.close()
    pool.join()
    return


def evaluate_reduced_neighborhood_run(run_date, start_hour, end_hour,
                                      ensemble_name, ensemble_members, ensemble_variables, ml_models,
                                      ml_variable, mrms_variable, ensemble_thresholds, ml_thresholds, obs_thresholds,
                                      stride, radius, sigma, ensemble_path, ml_grid_path, mrms_path, out_path,
                                      map_file, mask_file, single_step):
    """
    Build and save the coarse-grid neighborhood probability table for a single model run.

    The table has one row per point of the grid subsampled every `stride` points. It contains the
    mask and map coordinates, the observed MRMS neighborhood probabilities, the neighborhood
    probabilities of every ensemble variable and member (plus the member mean), and those of every
    machine learning model and member (plus the member mean). The table is written to
    out_path + "coarse_neighbor_eval_<ensemble_name>_<YYYYMMDD>.csv".

    Args:
        run_date (datetime.datetime): Date of the model run
        start_hour (int): Hours after run_date at which the evaluation period starts
        end_hour (int): Hours after run_date at which the evaluation period ends
        ensemble_name (str): Name of the ensemble system
        ensemble_members (list): Ensemble member names; the first one is used to load the map information
        ensemble_variables (list): Ensemble variables to evaluate; the first one is used to load the map information
        ml_models (list): Machine learning model names; spaces are replaced with "-" in column names
        ml_variable (str): Variable name passed to HailForecastGrid
        mrms_variable (str): MRMS variable used for verification
        ensemble_thresholds (dict): Integer thresholds for each ensemble variable, keyed by variable name
        ml_thresholds (list): Integer thresholds for the machine learning forecasts
        obs_thresholds (list): Integer thresholds for the observations
        stride (int): Step used to subsample the grid in both dimensions
        radius: Neighborhood radius passed to period_neighborhood_probability
        sigma: Smoothing parameter passed to period_neighborhood_probability (0 is used for the observations)
        ensemble_path (str): Path to the ensemble model output
        ml_grid_path (str): Path to the gridded machine learning forecasts
        mrms_path (str): Path to the MRMS data
        out_path (str): Prefix of the output CSV file
        map_file (str): File passed to ModelOutput.load_map_info
        mask_file (str): netCDF file containing the "usa_mask" variable
        single_step: Passed to ModelOutput when loading each member's data

    Returns:
        None

    Raises:
        Any exception raised during evaluation is re-raised after its traceback has been printed.
    """
    try:
        start_date = run_date + timedelta(hours=start_hour)
        end_date = run_date + timedelta(hours=end_hour)
        run_label = run_date.strftime("%Y%m%d")
        map_data = {}
        mask_obj = Dataset(mask_file)
        try:
            mask_data = mask_obj.variables["usa_mask"][:]
        finally:
            # Close the mask file even when the variable cannot be read.
            mask_obj.close()
        thin_mask = mask_data[::stride, ::stride]
        map_data["us_mask"] = thin_mask.flatten()
        mrms_data = MRMSGrid(start_date, end_date, mrms_variable, mrms_path)
        mrms_data.load_data()
        # This ModelOutput is only used for its map information and grid coordinates.
        mo = ModelOutput(ensemble_name, ensemble_members[0], run_date, ensemble_variables[0],
                         start_date, end_date, ensemble_path)
        mo.load_map_info(map_file)
        for map_var in ["x", "y", "lat", "lon", "i", "j"]:
            map_data[map_var] = getattr(mo, map_var)[::stride, ::stride].flatten()
        thin_indices = np.indices(thin_mask.shape)
        map_data["i_small"] = thin_indices[0].flatten()
        map_data["j_small"] = thin_indices[1].flatten()
        eval_data = pd.DataFrame(map_data)
        eval_data["Run_Date"] = pd.Timestamp(run_date)
        eval_data["Start_Date"] = pd.Timestamp(start_date)
        eval_data["End_Date"] = pd.Timestamp(end_date)

        # Observed neighborhood probabilities.
        for threshold in obs_thresholds:
            print("Loading MRMS " + run_label)
            obs_col = mrms_variable + "_{0:d}".format(threshold)
            # NOTE(review): coordinates are divided by 1000 here but not for the ensemble below;
            # presumably the MRMS method expects different units -- confirm.
            obs_prob = mrms_data.period_neighborhood_probability(radius, 0, threshold, stride,
                                                                 mo.x / 1000.0, mo.y / 1000.0,
                                                                 mo.dx / 1000.0)
            eval_data[obs_col] = obs_prob.ravel()

        # Neighborhood probabilities of the ensemble's own variables, per member and member mean.
        for ens_var in ensemble_variables:
            var_thresholds = ensemble_thresholds[ens_var]
            for member in ensemble_members:
                ens_data = ModelOutput(ensemble_name, member, run_date, ens_var, start_date, end_date,
                                       ensemble_path, single_step)
                ens_data.load_data()
                print(ens_var, ens_data.units)
                if ens_data.units == "m" or ens_var in ["HAIL_MAXK1", "HAIL_MAX2D"]:
                    # Scale by 1000 and relabel the units as mm.
                    ens_data.data = 1000 * ens_data.data
                    ens_data.units = "mm"
                print(ens_var, ens_data.data.max())
                for ens_thresh in var_thresholds:
                    print("Evaluating {0} {1} {2} {3:d} {4}".format(ensemble_name, ens_var, member, ens_thresh,
                                                                    run_label))
                    col_name = "{0}_{1}_{2}_{3:d}".format(ensemble_name, ens_var, member, ens_thresh)
                    ens_prob = ens_data.period_neighborhood_probability(radius, sigma, ens_thresh, stride,
                                                                        mo.x, mo.y, mo.dx)
                    eval_data[col_name] = ens_prob.ravel()
            for ens_thresh in var_thresholds:
                col_names = ["{0}_{1}_{2}_{3:d}".format(ensemble_name, ens_var, m, ens_thresh)
                             for m in ensemble_members]
                mean_col = "{0}_{1}_mean_{2:d}".format(ensemble_name, ens_var, ens_thresh)
                eval_data[mean_col] = eval_data[col_names].mean(axis=1)

        # Neighborhood probabilities of the machine learning forecasts, per member and member mean.
        for ml_model in ml_models:
            ml_name = ml_model.replace(" ", "-")
            ml_data = HailForecastGrid(run_date, start_date, end_date, ensemble_name, ml_name,
                                       ensemble_members, ml_variable, 2, ml_grid_path)
            ml_data.load_data()
            for ml_thresh in ml_thresholds:
                print("Evaluating {0} {1} {2:d} {3}".format(ensemble_name, ml_model, ml_thresh, run_label))
                ml_neighbor_prob = ml_data.period_neighborhood_probability(radius, sigma, ml_thresh, stride)
                for m, member in enumerate(ensemble_members):
                    col_name = "{0}_{1}_{2}_{3:d}".format(ensemble_name, ml_name, member, ml_thresh)
                    eval_data[col_name] = ml_neighbor_prob[m].ravel()
                mean_col_name = "{0}_{1}_{2}_{3:d}".format(ensemble_name, ml_name, "mean", ml_thresh)
                eval_data[mean_col_name] = ml_neighbor_prob.mean(axis=0).ravel()

        eval_data.to_csv(out_path + "coarse_neighbor_eval_{0}_{1}.csv".format(ensemble_name, run_label),
                         index=False, na_rep="NaN")
    except Exception:
        print(traceback.format_exc())
        raise
    return

# Run the command-line entry point only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
