#!/usr/bin/env python
import numpy as np
import pandas as pd
from hagelslag.data import ModelOutput
from hagelslag.util.make_proj_grids import read_ncar_map_file, read_arps_map_file, make_proj_grids
import argparse
from multiprocessing import Pool
from datetime import datetime, timedelta
from scipy.spatial.distance import cdist
from netCDF4 import Dataset, chartostring, num2date
import traceback
from os.path import exists


def main():
    """
    Command-line entry point: read the config file and extract every run date / ensemble member pair.

    The ``-c/--config`` option names the config file handed to read_config_file; the resulting dictionary
    must contain "run_dates" and "ensemble_members" and is also forwarded in full as keyword arguments to
    extract_model_run. With ``-p/--proc`` greater than 1 the extractions are submitted to a multiprocessing
    pool of that size; otherwise they run one after another in this process.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-c", "--config", help="Config file name")
    arg_parser.add_argument("-p", "--proc", type=int, default=1, help="Number of processors")
    options = arg_parser.parse_args()
    config = read_config_file(options.config)
    run_dates = config["run_dates"]
    ensemble_members = config["ensemble_members"]
    # Flatten the run date x member combinations, run dates varying slowest.
    jobs = [(run_date, ensemble_member)
            for run_date in run_dates
            for ensemble_member in ensemble_members]
    if args_proc_is_parallel(options.proc):
        pool = Pool(options.proc)
        for job in jobs:
            # Results are never collected, so worker failures only show up through
            # the traceback that extract_model_run prints.
            pool.apply_async(extract_model_run, args=job, kwds=config)
        pool.close()
        pool.join()
    else:
        for run_date, ensemble_member in jobs:
            extract_model_run(run_date, ensemble_member, **config)
    return


def args_proc_is_parallel(num_procs):
    """Return True when more than one processor was requested."""
    return num_procs > 1


def read_config_file(filename):
    """
    Execute a Python-syntax config file and return the ``config`` dictionary it builds.

    The file is run in a namespace seeded with a copy of this module's globals (so it can use the modules
    imported here) plus an empty ``config`` dictionary. The file may either fill that dictionary in place
    (``config["key"] = value``) or rebind the name (``config = dict(...)``); both styles are returned.

    Args:
        filename: Path to the config file.

    Returns:
        Whatever object is bound to the name ``config`` after the file has run.
    """
    with open(filename) as config_file:
        config_text = config_file.read()
    # An explicit namespace is required: a bare exec() inside a function cannot rebind
    # the function's local variables in Python 3, so "config = ..." would be lost.
    namespace = dict(globals())
    namespace["config"] = {}
    # NOTE(security): this runs arbitrary code from the file; only use trusted config files.
    exec(config_text, namespace)
    return namespace["config"]


def load_mesonet_station_info(filename, valid_date):
    """
    Read the station metadata CSV and keep the stations whose date window contains valid_date.

    A row is kept when its "datc" value is strictly before valid_date and its "datd" value is strictly
    after it. (Presumably these are commissioning / decommissioning dates; confirm against the file.)

    Args:
        filename: Path to the station CSV file; must contain "datc" and "datd" date columns.
        valid_date: Date used to select stations; anything numpy.datetime64 accepts.

    Returns:
        pandas DataFrame of the selected rows, renumbered from 0. The rows' original positions are
        retained in an "index" column.
    """
    cutoff = np.datetime64(valid_date)
    station_table = pd.read_csv(filename, parse_dates=["datc", "datd"])
    starts_before = station_table["datc"].values < cutoff
    ends_after = cutoff < station_table["datd"].values
    selected = station_table.loc[starts_before & ends_after]
    return selected.reset_index()


def load_mesonet_nc_files(start_date, end_date, variables, stations, path):
    """
    Load hourly mesonet observations from daily netCDF files into one long-format table.

    One row is created for every (hour, station) pair between start_date and end_date, ordered by hour and
    then by the order of `stations`. Values are read from files named ``<path>mesonet.YYYYMMDD.nc``, one per
    calendar day in the range. Pairs whose time or station does not appear in a file keep missing values.

    Args:
        start_date: First observation hour.
        end_date: Last observation hour.
        variables: Names of the netCDF variables to load; each becomes a column.
        stations: Station IDs to load.
        path: Prefix of the netCDF files. It is concatenated directly with the file name, so a directory
            must end with a path separator.

    Returns:
        pandas DataFrame with the columns "Valid_Date", "Station_ID", and one column per variable.
    """
    # pd.DatetimeIndex(start=..., end=..., freq=...) was removed in pandas 1.0; date_range is the
    # supported constructor. A Timedelta frequency avoids the "H" vs "h" alias differences.
    obs_hours = pd.date_range(start=start_date, end=end_date, freq=pd.Timedelta(hours=1))
    all_hours = [obs_hour for obs_hour in obs_hours for _ in stations]
    all_stations = [station for _ in obs_hours for station in stations]
    obs_data = pd.DataFrame(index=range(len(all_hours)),
                            columns=["Valid_Date", "Station_ID"] + list(variables))
    obs_data["Valid_Date"] = all_hours
    obs_data["Station_ID"] = all_stations
    for obs_date in obs_hours.normalize().unique():
        filename = path + "mesonet." + obs_date.strftime("%Y%m%d") + ".nc"
        meso_file_data = Dataset(filename)
        try:
            station_names = chartostring(meso_file_data.variables["stationName"][:])
            time_var = meso_file_data.variables["time_nominal"]
            meso_dates = pd.DatetimeIndex(num2date(time_var[:], units=time_var.units))
            station_index = pd.Index(station_names)
            # Rows of the output table whose hour and station are both present in this file.
            in_file = (obs_data["Valid_Date"].isin(meso_dates) &
                       obs_data["Station_ID"].isin(station_index)).values
            if not in_file.any():
                continue
            # Look every selected row up by its own (time, station) position in the file.
            # The output rows follow the caller's station order, which need not be sorted,
            # so copying a sorted block of values would attach them to the wrong stations.
            time_pos = meso_dates.get_indexer(obs_data.loc[in_file, "Valid_Date"])
            station_pos = station_index.get_indexer(obs_data.loc[in_file, "Station_ID"])
            for variable in variables:
                # The file variable is transposed so rows are times and columns are stations.
                var_frame = pd.DataFrame(meso_file_data.variables[variable][:].T,
                                         columns=station_names, index=meso_dates)
                obs_data.loc[in_file, variable] = var_frame.values[time_pos, station_pos]
        finally:
            # Release the file handle even if a variable is missing or malformed.
            meso_file_data.close()
    return obs_data


def extract_model_run(run_date, ensemble_member, **kwargs):
    """
    Extracts CAM model output at each specified station for a given range of forecast times for a single member of
    a single model run.

    Two CSV files are written to out_path: ``site_data_<ensemble>_<member>_<YYYYMMDD>.csv`` with the value at
    the grid point nearest each station plus neighborhood statistics, and
    ``neighbor_data_<ensemble>_<member>_<YYYYMMDD>.csv`` with every grid point value in the square
    neighborhood around each station. Both hold one row per (valid hour, station) pair and include the
    observed variables. Nothing is written if no directory for the member and run date exists in ens_path.

    Args:
        run_date: Initialization time of the model run, as a datetime or an object with to_pydatetime().
        ensemble_member: Name of the ensemble member.

    Keyword Args:
        ensemble_name: Name of the ensemble system. "SSEF" selects the ARPS map file reader; any other
            value selects the NCAR map file reader.
        start_hour: First forecast hour to extract (integer offset from run_date).
        end_hour: Last forecast hour to extract (integer offset from run_date).
        single_step: Passed through to ModelOutput.
        ens_path: Prefix of the ensemble output directories; concatenated directly with the member and date.
        site_variables: List of model variables sampled at the nearest grid point.
        neighbor_variables: Model variables sampled over the neighborhood around each station.
        neighbor_radius: Number of grid points on each side of the nearest grid point in the neighborhood.
        neighbor_stats: Names of pandas DataFrame reduction methods (called with axis=1) applied to each
            neighborhood, producing "<variable>_<stat>" columns in the site file.
        map_filename: Map projection file describing the model grid.
        station_info_file: CSV file of station metadata read by load_mesonet_station_info.
        obs_path: Prefix of the mesonet netCDF files read by load_mesonet_nc_files.
        obs_vars: Observed variables to append to both output files.
        out_path: Prefix of the output CSV files; concatenated directly with the file name.

    Returns:
        None

    Raises:
        Exception: Any error is printed with its traceback and then re-raised.
    """
    try:
        print("Starting {0} {1}".format(run_date, ensemble_member))
        ensemble_name = kwargs["ensemble_name"]
        # Exact type check on purpose: pandas Timestamps subclass datetime and must still be converted.
        if type(run_date) is not datetime:
            run_date = run_date.to_pydatetime()
        start_hour = kwargs["start_hour"]
        end_hour = kwargs["end_hour"]
        start_date = run_date + timedelta(hours=start_hour)
        end_date = run_date + timedelta(hours=end_hour)
        # pd.DatetimeIndex(start=..., end=..., freq=...) was removed in pandas 1.0.
        valid_dates = pd.date_range(start=start_date, end=end_date, freq=timedelta(hours=1))
        single_step = kwargs["single_step"]
        ens_path = kwargs["ens_path"]
        site_variables = kwargs["site_variables"]
        neighbor_stats = kwargs["neighbor_stats"]
        neighbor_variables = kwargs["neighbor_variables"]
        neighbor_radius = kwargs["neighbor_radius"]
        map_filename = kwargs["map_filename"]
        obs_path = kwargs["obs_path"]
        obs_vars = kwargs["obs_vars"]
        out_path = kwargs["out_path"]
        run_date_str = run_date.strftime("%Y%m%d")
        # Two directory layouts are accepted: <member>/<date> and <date>/<member>.
        reg_path_exists = exists(ens_path + ensemble_member + "/" + run_date_str)
        alt_path_exists = exists(ens_path + run_date_str + "/" + ensemble_member)
        if not reg_path_exists and not alt_path_exists:
            print(ensemble_member + " " + run_date_str + " not found.")
            return
        mesonet_info = load_mesonet_station_info(kwargs["station_info_file"], run_date)
        if ensemble_name == "SSEF":
            proj_dict, grid_dict = read_arps_map_file(map_filename)
        else:
            proj_dict, grid_dict = read_ncar_map_file(map_filename)
        map_info = make_proj_grids(proj_dict, grid_dict)
        # Nearest grid point to each station, using Euclidean distance in lon/lat space.
        model_coords = np.vstack((map_info["lon"].ravel(), map_info["lat"].ravel())).T
        nearest_grid_index = cdist(mesonet_info[["elon", "nlat"]].values, model_coords).argmin(axis=1)
        nearest_row, nearest_col = np.unravel_index(nearest_grid_index, map_info["lon"].shape)
        metadata_cols = ["Valid_Date", "Run_Date", "Forecast_Hour", "Station_ID", "Lon", "Lat", "Elevation",
                         "Ensemble_Name", "Ensemble_Member", "Row", "Col"]
        num_timesteps = end_hour - start_hour + 1
        num_sites = nearest_row.size
        # Rows of both tables are ordered by valid time first and station second.
        record_index = np.arange(num_timesteps * num_sites)
        offsets = range(-neighbor_radius, neighbor_radius + 1)
        neighbor_columns = ["{0}_R_{1:d}_C_{2:d}".format(var, r, c)
                            for var in neighbor_variables
                            for r in offsets
                            for c in offsets]
        site_data = pd.DataFrame(index=record_index, columns=metadata_cols + site_variables)
        neighbor_data = pd.DataFrame(index=record_index, columns=metadata_cols + neighbor_columns)
        site_data["Run_Date"] = run_date
        site_data["Valid_Date"] = np.repeat(valid_dates, num_sites)
        site_data["Forecast_Hour"] = np.repeat(np.arange(start_hour, end_hour + 1), num_sites)
        site_data["Station_ID"] = np.tile(mesonet_info["stid"].values, (num_timesteps,))
        site_data.loc[:, ["Lon", "Lat", "Elevation"]] = np.tile(mesonet_info[["elon", "nlat", "elev"]].values,
                                                                (num_timesteps, 1))
        site_data["Ensemble_Name"] = ensemble_name
        site_data["Ensemble_Member"] = ensemble_member
        site_data["Row"] = np.tile(nearest_row, (num_timesteps,))
        site_data["Col"] = np.tile(nearest_col, (num_timesteps,))
        for m_col in metadata_cols:
            neighbor_data[m_col] = site_data[m_col]
        for variable in np.union1d(site_variables, neighbor_variables):
            print("Extracting {0} from {1} {2}".format(variable, run_date_str, ensemble_member))
            mo = ModelOutput(ensemble_name, ensemble_member, run_date, variable, start_date, end_date, ens_path,
                             single_step=single_step)
            mo.load_data()
            if variable in site_variables:
                # mo.data is indexed (time, row, col); the result is (time, site), flattened time-major.
                site_data[variable] = mo.data[:, nearest_row, nearest_col].ravel()
            if variable in neighbor_variables:
                num_rows, num_cols = mo.data.shape[1], mo.data.shape[2]
                var_columns = []
                for r in offsets:
                    rows = nearest_row + r
                    for c in offsets:
                        cols = nearest_col + c
                        # Neighborhood cells that fall off the grid become NaN instead of
                        # wrapping around or breaking the extraction for every station.
                        on_grid = (rows >= 0) & (rows < num_rows) & (cols >= 0) & (cols < num_cols)
                        values = mo.data[:, np.clip(rows, 0, num_rows - 1), np.clip(cols, 0, num_cols - 1)]
                        if not on_grid.all():
                            values = values.astype(float)
                            values[:, ~on_grid] = np.nan
                        column = "{0}_R_{1:d}_C_{2:d}".format(variable, r, c)
                        neighbor_data[column] = values.ravel()
                        var_columns.append(column)
                # Compute each statistic once over all rows. Assigning it station by station
                # replaces the whole column each time and leaves only the last station filled.
                neighborhood = neighbor_data[var_columns]
                for stat in neighbor_stats:
                    site_data[variable + "_" + stat] = getattr(neighborhood, stat)(axis=1)
        obs_data = load_mesonet_nc_files(start_date, end_date, obs_vars, mesonet_info["stid"].values, obs_path)
        for obs_var in obs_vars:
            site_data[obs_var] = obs_data[obs_var]
            neighbor_data[obs_var] = obs_data[obs_var]
        site_data.to_csv(out_path + "site_data_{0}_{1}_{2}.csv".format(ensemble_name, ensemble_member,
                                                                       run_date_str),
                         index=False, na_rep="nan")
        neighbor_data.to_csv(out_path + "neighbor_data_{0}_{1}_{2}.csv".format(ensemble_name, ensemble_member,
                                                                               run_date_str),
                             index=False, na_rep="nan")
    except Exception:
        # Print here because exceptions raised in pool workers are otherwise never shown.
        print(traceback.format_exc())
        raise
    return


if __name__ == "__main__":
    main()
