#!/usr/bin/env python
import argparse, pdb
from multiprocessing import Pool
from hagelslag.util.Config import Config
from hagelslag.processing.TrackProcessing import TrackProcessor
from hagelslag.util.make_proj_grids import read_ncar_map_file
from hagelslag.util.create_sector_grid_data import SectorProcessor
from datetime import timedelta
import pandas as pd
import numpy as np
import os
import traceback
from netCDF4 import Dataset, date2num
from os.path import join, exists


def main():
    """
    Command-line entry point for hsdata.

    Reads the configuration file named on the command line, fills in optional
    settings, creates the output directories, and then runs one worker per
    (run date, ensemble member) pair:

    * ``--obs``      -> process_observed_tracks
    * ``--rematch``  -> rematch_ensemble_tracks
    * otherwise      -> process_ensemble_member

    ``--obs`` takes precedence over ``--rematch`` in both execution modes. With
    ``--proc`` > 1 the workers run in a multiprocessing Pool; otherwise they run
    sequentially in this process.
    """
    parser = argparse.ArgumentParser("hsdata - Hagelslag Data Processor")
    parser.add_argument("config", help="Configuration file")
    parser.add_argument("-r", "--rematch", action="store_true", help="Rematch existing forecast and observed tracks")
    parser.add_argument("-o", "--obs", action="store_true", help="Process observed tracks only")
    parser.add_argument("-j", "--json", action="store_true", default=False, help="Output forecast geoJSON files")
    parser.add_argument("-p", "--proc", type=int, default=1,
                        help="Number of processors")
    args = parser.parse_args()
    required = ['dates', 'start_hour', 'end_hour', 'ensemble_members',
                'watershed_variable', 'model_path', "ensemble_name",
                'model_watershed_params', 'object_matcher_params',
                'track_matcher_params', 'size_filter', 'gaussian_window',
                'mrms_path', 'mrms_watershed_params',
                'storm_variables', 'potential_variables', "tendency_variables",
                'variable_statistics', 'model_map_file',
                'csv_path', 'train', 'single_step', "unique_matches", "label_type",
                'closest_matches']
    config = Config(args.config, required_attributes=required)

    # Fill in optional settings BEFORE anything reads them; previously the
    # geojson_path default was applied only after geojson_path had been used.
    if not hasattr(config, "segmentation_approach"):
        config.segmentation_approach = "ew"
    print("Seg approach", config.segmentation_approach)
    if not hasattr(config, "run_date_format"):
        config.run_date_format = "%Y%m%d-%H%M"
    if not hasattr(config, "geojson_path"):
        config.geojson_path = None
    config.json = args.json

    # Create output directories; nc_path and geojson_path are not in `required`,
    # so skip whichever ones the config does not define.
    for out_dir in (config.csv_path, getattr(config, "nc_path", None), config.geojson_path):
        if out_dir is not None and not exists(out_dir):
            os.makedirs(out_dir)

    # Select the worker once so the serial and parallel paths behave identically
    # (the serial path used to ignore --obs).
    if args.obs:
        worker = process_observed_tracks
    elif args.rematch:
        worker = rematch_ensemble_tracks
    else:
        worker = process_ensemble_member

    jobs = [(run_date, member, config)
            for run_date in config.dates
            for member in config.ensemble_members]
    if args.proc > 1:
        pool = Pool(args.proc)
        # NOTE(review): AsyncResults are not collected, so worker exceptions are
        # only visible through the tracebacks the workers print themselves.
        for job in jobs:
            pool.apply_async(worker, job)
        pool.close()
        pool.join()
    else:
        for job in jobs:
            worker(*job)


def process_ensemble_member(run_date, member, config):
    """
    Build forecast (and, when training, observed) storm tracks for one ensemble member run.

    Forecast tracks are always extracted and given their storm, potential, tendency and
    (if configured) future attributes. When ``config.train`` is set, observed MRMS tracks
    are also found. If any exist, they are matched to the forecast tracks, and both sets
    are written out.

    Outputs:
        - Tables go to csv files under ``config.csv_path``.
        - netCDF patch files are written when ``patch_radius`` is configured.
        - geoJSON files are written when ``config.json`` is set.

    Args:
        run_date: datetime object containing the date of the model run
        member: name of the ensemble member
        config: Config object containing model parameters

    Raises:
        Any exception is printed with its traceback and then re-raised.
    """

    def write_tables(tables, source_name, announce):
        # One csv per DataFrame, made world-writable like the other outputs of this script.
        for table_name, table_data in tables.items():
            csv_name = config.csv_path + "{0}_{1}_{2}_{3}.csv".format(
                table_name, source_name, member, run_date.strftime(config.run_date_format))
            if announce:
                print("Output csv file " + csv_name)
            table_data.to_csv(csv_name, na_rep="nan", float_format="%0.5f", index=False)
            os.chmod(csv_name, 0o666)

    try:
        print("Starting", run_date, member)
        start_date = run_date + timedelta(hours=config.start_hour)
        end_date = run_date + timedelta(hours=config.end_hour)
        # Optional settings with their fallbacks.
        mask_file = getattr(config, "mask_file", None)
        match_steps = getattr(config, "match_steps", False)
        patch_radius = getattr(config, "patch_radius", None)

        print("Patch Radius", patch_radius)
        track_proc = TrackProcessor(run_date, start_date, end_date,
                                    config.ensemble_name, member,
                                    config.watershed_variable,
                                    config.model_path, config.model_map_file,
                                    config.model_watershed_params,
                                    config.object_matcher_params,
                                    config.track_matcher_params,
                                    config.size_filter, config.gaussian_window,
                                    segmentation_approach=config.segmentation_approach,
                                    match_steps=match_steps,
                                    mrms_path=config.mrms_path,
                                    mrms_variable=config.mrms_variable,
                                    mrms_watershed_params=config.mrms_watershed_params,
                                    single_step=config.single_step,
                                    mask_file=mask_file,
                                    patch_radius=patch_radius)
        if config.train:
            print("Find obs tracks", run_date, member)
            mrms_tracks = track_proc.find_mrms_tracks()

        print("Find model tracks", run_date, member)
        if patch_radius is None:
            model_tracks = track_proc.find_model_tracks()
        else:
            model_tracks = track_proc.find_model_patch_tracks()

        forecast_data = {}
        if not model_tracks:
            print('No {0} {1} modeled tracks found'.format(run_date, member))
        else:
            print(run_date, member, "Found this many model tracks: {0:d}".format(len(model_tracks)))
            print("Extract model attributes", run_date, member)
            track_proc.extract_model_attributes(model_tracks,
                                                config.storm_variables,
                                                config.potential_variables,
                                                config.tendency_variables,
                                                future_variables=getattr(config, "future_variables", None))
            grid_proj = track_proc.model_grid.proj

            if config.train and len(mrms_tracks) > 0:
                # Pair forecast tracks with observed tracks, then emit both sets.
                if match_steps:
                    pairings = track_proc.match_track_steps(model_tracks, mrms_tracks)
                    track_proc.match_hail_size_step_distributions(model_tracks, mrms_tracks, pairings)
                    errors = None
                else:
                    pairings = track_proc.match_tracks(model_tracks, mrms_tracks,
                                                       unique_matches=config.unique_matches,
                                                       closest_matches=config.closest_matches)
                    track_proc.match_size_distributions(model_tracks, mrms_tracks, pairings)
                    errors = track_proc.calc_track_errors(model_tracks, mrms_tracks, pairings)
                print("Output data", run_date, member)
                forecast_data = make_forecast_track_data(model_tracks, run_date, member, config,
                                                         grid_proj, mrms_tracks, errors)
                if patch_radius is not None:
                    print("Output netCDF", run_date, member)
                    forecast_track_patches_to_netcdf(model_tracks, patch_radius, run_date, member, config)
                if config.json:
                    print("Output json", run_date, member)
                    forecast_tracks_to_json(model_tracks, run_date, member, config, grid_proj,
                                            observed_tracks=mrms_tracks,
                                            track_errors=errors)
                obs_data = make_obs_track_data(mrms_tracks, member, run_date, config, grid_proj)
                if config.json:
                    obs_tracks_to_json(mrms_tracks, member, run_date, config, grid_proj)
                print("Output csv", run_date, member)
                write_tables(obs_data, "obs", announce=False)
            else:
                # Forecast-only output. Progress messages appear only when not training
                # (the training-without-observations case stays quiet).
                verbose = not config.train
                if verbose:
                    print(run_date, member, "Make Forecast Track Data")
                forecast_data = make_forecast_track_data(model_tracks, run_date, member, config, grid_proj)
                if patch_radius is not None:
                    if verbose:
                        print(run_date, member, "Track Data to netCDF")
                    forecast_track_patches_to_netcdf(model_tracks, patch_radius, run_date, member, config)
                if config.json:
                    forecast_tracks_to_json(model_tracks, run_date, member, config, grid_proj)

        write_tables(forecast_data, config.ensemble_name, announce=True)
    except Exception:
        print(traceback.format_exc())
        raise


def process_observed_tracks(run_date, member, config):
    """
    Find observed (MRMS) storm tracks for one model run and write them to csv.

    No forecast tracks are identified here. A TrackProcessor is still built for the
    given member because its model grid projection is passed to the obs-track
    writers. Nothing is written when no observed tracks are found.

    Args:
        run_date: datetime object containing the date of the model run
        member: name of the ensemble member
        config: Config object containing model parameters

    Raises:
        Any exception is printed with its traceback and then re-raised.
    """
    try:
        print("Starting", run_date, member)
        start_date = run_date + timedelta(hours=config.start_hour)
        end_date = run_date + timedelta(hours=config.end_hour)
        mask_file = getattr(config, "mask_file", None)

        track_proc = TrackProcessor(run_date,
                                    start_date,
                                    end_date,
                                    config.ensemble_name,
                                    member,
                                    config.watershed_variable,
                                    config.model_path,
                                    config.model_map_file,
                                    config.model_watershed_params,
                                    config.object_matcher_params,
                                    config.track_matcher_params,
                                    config.size_filter,
                                    config.gaussian_window,
                                    mrms_path=config.mrms_path,
                                    mrms_variable=config.mrms_variable,
                                    mrms_watershed_params=config.mrms_watershed_params,
                                    single_step=config.single_step,
                                    mask_file=mask_file)

        print("Find obs tracks", run_date, member)
        mrms_tracks = track_proc.find_mrms_tracks()
        if len(mrms_tracks) > 0:
            proj = track_proc.model_grid.proj
            obs_data = make_obs_track_data(mrms_tracks, member, run_date, config, proj)
            if config.json:
                obs_tracks_to_json(mrms_tracks, member, run_date, config, proj)
            # NOTE(review): file names here use "%Y%m%d", whereas process_ensemble_member
            # names the same obs tables with config.run_date_format — confirm which
            # naming downstream readers expect before unifying.
            for table_name, table_data in obs_data.items():
                csv_filename = config.csv_path + "{0}_{1}_{2}_{3}.csv".format(table_name,
                                                                              "obs",
                                                                              member,
                                                                              run_date.strftime("%Y%m%d"))
                table_data.to_csv(csv_filename,
                                  na_rep="nan",
                                  float_format="%0.5f",
                                  index=False)
                os.chmod(csv_filename, 0o666)
    except Exception:
        print(traceback.format_exc())
        raise


def rematch_ensemble_tracks(run_date, member, config):
    """
    Re-run matching on forecast and observed tracks previously saved as geoJSON.

    Tracks are loaded from ``config.geojson_path`` and paired using the same
    ``unique_matches`` and ``closest_matches`` options that process_ensemble_member
    passes to ``match_tracks``. The resulting tables are written to csv files in
    ``config.csv_path``, made world-writable (0o666) like the other csv outputs.

    Args:
        run_date: datetime.datetime
            Start date of the model run
        member: str
            Name of the ensemble member
        config: hagelslag.util.Config
            Config object

    Raises:
        Any exception is printed with its traceback and then re-raised.
    """
    try:
        print("Starting", run_date, member)
        start_date = run_date + timedelta(hours=config.start_hour)
        end_date = run_date + timedelta(hours=config.end_hour)
        track_proc = TrackProcessor(run_date,
                                    start_date,
                                    end_date,
                                    config.ensemble_name,
                                    member,
                                    config.watershed_variable,
                                    config.model_path,
                                    config.model_map_file,
                                    config.model_watershed_params,
                                    config.object_matcher_params,
                                    config.track_matcher_params,
                                    config.size_filter,
                                    config.gaussian_window,
                                    mrms_path=config.mrms_path,
                                    mrms_variable=config.mrms_variable,
                                    mrms_watershed_params=config.mrms_watershed_params,
                                    single_step=config.single_step)
        model_tracks = track_proc.load_model_tracks(config.geojson_path)
        mrms_tracks = track_proc.load_mrms_tracks(config.geojson_path)
        date_stamp = run_date.strftime("%Y%m%d")
        if len(model_tracks) > 0 and len(mrms_tracks) > 0:
            # closest_matches is a required config key; pass it so rematching uses
            # the same scheme as the original processing run.
            track_pairings = track_proc.match_tracks(model_tracks, mrms_tracks,
                                                     unique_matches=config.unique_matches,
                                                     closest_matches=config.closest_matches)
            track_proc.match_size_distributions(model_tracks, mrms_tracks, track_pairings)
            track_errors = track_proc.calc_track_errors(model_tracks,
                                                        mrms_tracks,
                                                        track_pairings)
            print("Output data", run_date, member)
            forecast_data = make_forecast_track_data(model_tracks, run_date, member,
                                                     config, track_proc.model_grid.proj, mrms_tracks, track_errors)
            obs_data = make_obs_track_data(mrms_tracks, member, run_date, config, track_proc.model_grid.proj)
            for table_name, table_data in obs_data.items():
                csv_filename = config.csv_path + "{0}_{1}_{2}_{3}.csv".format(table_name,
                                                                              "obs",
                                                                              member,
                                                                              date_stamp)
                table_data.to_csv(csv_filename,
                                  na_rep="nan",
                                  float_format="%0.5f",
                                  index=False)
                os.chmod(csv_filename, 0o666)
        elif len(model_tracks) > 0:
            forecast_data = make_forecast_track_data(model_tracks, run_date, member, config, track_proc.model_grid.proj)
        else:
            forecast_data = {}
        for table_name, table_data in forecast_data.items():
            csv_filename = config.csv_path + "{0}_{1}_{2}_{3}.csv".format(table_name,
                                                                          config.ensemble_name,
                                                                          member,
                                                                          date_stamp)
            table_data.to_csv(csv_filename,
                              na_rep="nan",
                              float_format="%0.5f",
                              index=False)
            os.chmod(csv_filename, 0o666)
    except Exception:
        print(traceback.format_exc())
        raise


def make_forecast_track_data(forecast_tracks, run_date, member, config, proj, observed_tracks=None, track_errors=None):
    """
    Collect per-track and per-time-step statistics for forecast storm tracks.

    Nothing is written to disk here; callers write the returned tables to csv.

    Args:
        forecast_tracks: list of storm tracks found in the forecast model
        run_date: datetime object with the start date and time of the model run
        member: name of the ensemble member
        config: Config object. Always reads ensemble_name, watershed_variable,
            storm_variables, potential_variables, variable_statistics, train and
            unique_matches. Reads tendency_variables, future_variables,
            shape_variables and label_type when present.
        proj: projection callable; ``proj(x, y, inverse=True)`` returns (lon, lat)
        observed_tracks: list of storm tracks found in the observation grid, indexed
            by the observed-track numbers stored in ``track_errors``
        track_errors: pandas DataFrame of track-matching errors. Row f (by position)
            is read for forecast track f. Column 0 holds the matched observed track
            number (NaN when unmatched), and the remaining four columns fill the
            translation/time error columns. Only used when ``config.train`` is set.

    Returns:
        A dict of two pandas DataFrames:
            - 'track_total': one row per forecast track.
            - 'track_step': one row per track time step, or one row per step and
              candidate observation when matches are not unique.
    """
    ensemble_name = config.ensemble_name
    run_stamp = run_date.strftime("%Y%m%d-%H%M")

    forecast_total_track_columns = ["Track_ID", "Run_Date", "Start_Date", "End_Date",
                                    "Duration", "Ensemble_Name",
                                    "Ensemble_Member", "Object_Variable", "Obs_Track_ID",
                                    "Translation_Error_X", "Translation_Error_Y",
                                    "Start_Time_Error", "End_Time_Error"]
    forecast_step_track_columns = ["Step_ID", "Track_ID", "Ensemble_Name", "Ensemble_Member",
                                   "Run_Date", "Valid_Date",
                                   "Forecast_Hour", "Valid_Hour_UTC", "Duration",
                                   "Centroid_Lon", "Centroid_Lat",
                                   "Centroid_X", "Centroid_Y",
                                   "Storm_Motion_U", "Storm_Motion_V"]
    forecast_variables = config.storm_variables + [v + "-potential" for v in config.potential_variables]
    if hasattr(config, "tendency_variables"):
        forecast_variables.extend([v + "-tendency" for v in config.tendency_variables])
    if hasattr(config, "future_variables"):
        forecast_variables.extend([v + "-future" for v in config.future_variables])
    var_stats = [var + "_" + stat for var in forecast_variables for stat in config.variable_statistics]
    if hasattr(config, "shape_variables"):
        var_stats.extend(config.shape_variables)

    # Label columns: gamma fits carry six values per step, otherwise a single hail size.
    has_label_type = hasattr(config, "label_type")
    gamma_labels = has_label_type and config.label_type == "gamma"
    if gamma_labels:
        label_columns = ["Matched", "Max_Hail_Size", "Num_Matches", "Shape", "Location", "Scale"]
    else:
        label_columns = ["Hail_Size"]
    forecast_step_track_columns = forecast_step_track_columns + var_stats + label_columns
    # Placeholder label for steps that have no observation attached.
    empty_label = [0] * len(label_columns)

    forecast_data = dict()
    forecast_data['track_total'] = pd.DataFrame(columns=forecast_total_track_columns)
    forecast_data['track_step'] = pd.DataFrame(columns=forecast_step_track_columns)
    use_track_errors = config.train and track_errors is not None
    track_step_count = 0
    for f, forecast_track in enumerate(forecast_tracks):
        track_id = "{0}_{1}_{2}_{3:02d}_{4:02d}_{5:03d}".format(member,
                                                                config.watershed_variable,
                                                                run_stamp,
                                                                forecast_track.start_time,
                                                                forecast_track.end_time,
                                                                f)
        start_date = run_date + timedelta(seconds=3600 * int(forecast_track.start_time))
        end_date = run_date + timedelta(seconds=3600 * int(forecast_track.end_time))
        duration = (end_date - start_date).total_seconds() / 3600.0 + 1
        if use_track_errors:
            obs_track_id = "None"
            # .iloc replaces the removed DataFrame.ix (pandas >= 1.0).
            obs_track_num = track_errors.iloc[f, 0]
            if not pd.isnull(obs_track_num):
                # The column may be float-typed; list indexing and :d formatting need an int.
                obs_track_num = int(obs_track_num)
                obs_track_id = "obs_{0}_{1}_{2:02d}_{3:02d}_{4:03d}".format(member,
                                                                            run_stamp,
                                                                            observed_tracks[obs_track_num].start_time,
                                                                            observed_tracks[obs_track_num].end_time,
                                                                            obs_track_num)
            track_error_row = [obs_track_id] + track_errors.iloc[f, 1:].tolist()
        else:
            track_error_row = [np.nan] * 5
        forecast_data['track_total'].loc[f] = [track_id, run_date, start_date, end_date, duration,
                                               ensemble_name, member,
                                               config.watershed_variable] + track_error_row
        for s, step in enumerate(forecast_track.times):
            step_id = track_id + "_{0:02d}".format(s)
            step_date = run_date + timedelta(seconds=3600 * int(step))
            valid_hour_utc = step_date.hour
            step_duration = s + 1
            centroid_x, centroid_y = forecast_track.center_of_mass(step)
            centroid_lon, centroid_lat = proj(centroid_x, centroid_y, inverse=True)
            u_motion = forecast_track.u[s]
            v_motion = forecast_track.v[s]
            var_stat_vals = []
            for attribute in forecast_variables:
                for statistic in config.variable_statistics:
                    var_stat_vals.append(forecast_track.calc_attribute_statistic(attribute, statistic, step))
            if hasattr(config, "shape_variables"):
                var_stat_vals.extend(forecast_track.calc_shape_step(config.shape_variables, step))
            record = [step_id, track_id, ensemble_name, member, run_date, step_date, step, valid_hour_utc,
                      step_duration, centroid_lon, centroid_lat, centroid_x, centroid_y,
                      u_motion, v_motion] + var_stat_vals

            observations = forecast_track.observations
            if observations is None:
                hail_labels = [empty_label]
            elif config.unique_matches:
                # A single observation table per track.
                if not has_label_type:
                    hail_labels = [[observations[s]]]
                elif gamma_labels:
                    hail_labels = [observations.loc[step].values.tolist()]
                else:
                    hail_labels = [[observations.loc[step, "Max_Hail_Size"]]]
            else:
                # Non-unique matching: one observation table per candidate match,
                # each producing its own step row.
                hail_labels = []
                for l in range(len(observations)):
                    if gamma_labels:
                        hail_labels.append(observations[l].loc[step].values.tolist())
                    else:
                        hail_labels.append([observations[l].loc[step, "Max_Hail_Size"]])
            for hail_label in hail_labels:
                forecast_data['track_step'].loc[track_step_count] = record + hail_label
                track_step_count += 1
    return forecast_data


def forecast_tracks_to_json(forecast_tracks, run_date, member, config, proj, observed_tracks=None, track_errors=None):
    """
    Write each forecast storm track to its own geoJSON file.

    Output location:
        - Files go to ``<geojson_path><YYYYmmdd>/<member>/``.
        - ``geojson_path`` is concatenated directly, so it should end with a path
          separator.
        - Missing directories are created and chmod-ed to 0o777.
        - Each json file is chmod-ed to 0o666.

    Args:
        forecast_tracks (list): List of STObjects containing forecast track information
        run_date (datetime.datetime):  Date of the model run
        member (str): Name of the ensemble member being processed
        config: Config object
        proj: pyproj object with map projection information for model grid
        observed_tracks: List of STObjects for each observed storm track
        track_errors: DataFrame of space and time offsets between forecast and observed
            tracks. Row f (by position) is read for forecast track f, and column 0 holds
            the matched observed track number (NaN when unmatched). Only used when
            ``config.train`` is set.
    """
    ensemble_name = config.ensemble_name
    run_day = run_date.strftime("%Y%m%d")
    run_stamp = run_date.strftime("%Y%m%d-%H%M")
    include_obs = config.train and track_errors is not None
    for f, forecast_track in enumerate(forecast_tracks):
        track_id = "{0}_{1}_{2}_{3:02d}_{4:02d}_{5:03d}".format(member,
                                                                config.watershed_variable,
                                                                run_stamp,
                                                                forecast_track.start_time,
                                                                forecast_track.end_time,
                                                                f)
        start_date = run_date + timedelta(hours=int(forecast_track.start_time))
        end_date = run_date + timedelta(hours=int(forecast_track.end_time))
        duration = (end_date - start_date).total_seconds() / 3600.0 + 1
        obs_track_id = "None"
        if include_obs:
            # .iloc replaces the removed DataFrame.ix (pandas >= 1.0).
            obs_track_num = track_errors.iloc[f, 0]
            if not pd.isnull(obs_track_num):
                # The column may be float-typed; list indexing and :d formatting need an int.
                obs_track_num = int(obs_track_num)
                obs_track_id = "obs_{0}_{1}_{2:02d}_{3:02d}_{4:03d}".format(member,
                                                                            run_stamp,
                                                                            observed_tracks[obs_track_num].start_time,
                                                                            observed_tracks[obs_track_num].end_time,
                                                                            obs_track_num)
        # Create the date directory, then the member directory, one level at a time.
        full_path = []
        for part in (run_day, member):
            full_path.append(part)
            out_dir = config.geojson_path + "/".join(full_path)
            if not os.access(out_dir, os.R_OK):
                try:
                    os.mkdir(out_dir)
                    os.chmod(out_dir, 0o777)
                except OSError:
                    print("directory already created")

        json_filename = out_dir + "/{0}_{1}_{2}_model_track_{3:03d}.json".format(ensemble_name,
                                                                                 run_day,
                                                                                 member,
                                                                                 f)
        json_metadata = dict(id=track_id,
                             ensemble_name=ensemble_name,
                             ensemble_member=member,
                             duration=duration)
        if include_obs:
            json_metadata['obs_track_id'] = obs_track_id
        forecast_track.to_geojson(json_filename, proj, json_metadata)
        os.chmod(json_filename, 0o666)


def forecast_track_patches_to_netcdf(forecast_tracks, patch_radius, run_date, member, config):
    """
    Write the storm patches of every forecast track for one ensemble member to a netCDF file.

    Every time step of every track becomes one patch along the ``p`` dimension. Each patch is a
    ``(2 * patch_radius) x (2 * patch_radius)`` window holding coordinate and mask fields, the
    configured storm and potential variables, optional future variables and, when ``config.train``
    is set, label columns. Per-patch centroid, time and track bookkeeping variables are written
    along ``p``. The config object's attributes are stored as string global attributes.

    The file is written to ``config.nc_path`` as
    ``<ensemble_name>_<run date>_<member>_model_patches.nc`` with the run date formatted by
    ``config.run_date_format``.

    Args:
        forecast_tracks: List of forecast storm tracks. Each track is read for ``times``, ``i``, ``j``,
            ``x``, ``y``, ``masks`` and an ``attributes`` mapping; ``observations`` is read when
            training.
        patch_radius: Half-width of each patch in grid points.
        run_date: Date of the model run.
        member: Name of the ensemble member.
        config: Config object supplying paths, variable lists and formats.

    Raises:
        ValueError: If ``forecast_tracks`` is empty (no file is created in that case).
        AttributeError: If a config attribute cannot be stored because both its name and its
            ``config_``-prefixed name already exist as global attributes.
    """
    if len(forecast_tracks) == 0:
        # Concatenating the per-track arrays needs at least one track; fail before creating a file.
        raise ValueError("No forecast tracks to write for member {0}".format(member))
    ensemble_name = config.ensemble_name
    run_date_str = run_date.strftime(config.run_date_format)
    patch_size = patch_radius * 2
    # One patch per time step of every track.
    patch_count = sum(f_track.times.size for f_track in forecast_tracks)
    out_filename = join(config.nc_path, "{0}_{1}_{2}_model_patches.nc".format(ensemble_name,
                                                                              run_date_str,
                                                                              member))
    meta_variables = ["lon", "lat", "i", "j", "x", "y", "masks"]
    meta_units = ["degrees_east", "degrees_north", "", "", "m", "m", ""]
    center_vars = ["time", "centroid_lon", "centroid_lat", "centroid_i", "centroid_j", "track_id", "track_step"]
    center_units = ["hours since {0}".format(run_date.strftime("%Y-%m-%d %H:%M:%S")),
                    "degrees_east",
                    "degrees_north",
                    "",
                    "",
                    "",
                    ""]
    label_columns = ["Matched", "Max_Hail_Size", "Num_Matches", "Shape", "Location", "Scale"]
    # (netCDF name suffix, track attribute suffix, variable names) for each family of patch fields.
    patch_groups = [("_curr", "", config.storm_variables),
                    ("_prev", "-potential", config.potential_variables)]
    if hasattr(config, "future_variables"):
        patch_groups.append(("_future", "-future", config.future_variables))
    patch_dims = ("p", "row", "col")

    # The context manager closes the file even if writing fails part way through.
    with Dataset(out_filename, "w") as out_file:
        for dim_name, dim_size in (("p", patch_count), ("row", patch_size), ("col", patch_size)):
            out_file.createDimension(dim_name, dim_size)
        for dim_name, dim_size in (("p", patch_count), ("row", patch_size), ("col", patch_size)):
            out_file.createVariable(dim_name, "i4", (dim_name,))
            out_file.variables[dim_name][:] = np.arange(dim_size)
        out_file.Conventions = "CF-1.6"
        out_file.title = "{0} Storm Patches for run {1} member {2}".format(ensemble_name,
                                                                           run_date_str,
                                                                           member)
        out_file.object_variable = config.watershed_variable

        # Define all variables first, then fill them.
        for meta_variable, units in zip(meta_variables, meta_units):
            dtype = "i4" if meta_variable in ("i", "j", "masks") else "f4"
            m_var = out_file.createVariable(meta_variable, dtype, patch_dims, complevel=1, zlib=True)
            m_var.long_name = meta_variable
            m_var.units = units
        for center_var, units in zip(center_vars, center_units):
            dtype = "i4" if center_var in ("time", "track_id", "track_step") else "f4"
            c_var = out_file.createVariable(center_var, dtype, ("p",), zlib=True, complevel=1)
            c_var.long_name = center_var
            c_var.units = units
        for nc_suffix, _, names in patch_groups:
            for name in names:
                p_var = out_file.createVariable(name + nc_suffix, "f4", patch_dims, complevel=1, zlib=True)
                p_var.long_name = name
                p_var.units = ""
        if config.train:
            for label_column in label_columns:
                dtype = "i4" if label_column in ("Matched", "Num_Matches") else "f4"
                l_var = out_file.createVariable(label_column, dtype, ("p",), zlib=True, complevel=1)
                l_var.long_name = label_column
                l_var.units = ""

        out_file.variables["time"][:] = np.concatenate([f_track.times for f_track in forecast_tracks])
        # Centroid values are the centre pixel of each patch.
        for c_var in ("lon", "lat"):
            out_file.variables["centroid_" + c_var][:] = np.concatenate(
                [np.array(f_track.attributes[c_var])[:, patch_radius, patch_radius]
                 for f_track in forecast_tracks])
        for c_var in ("i", "j"):
            out_file.variables["centroid_" + c_var][:] = np.concatenate(
                [np.array(getattr(f_track, c_var))[:, patch_radius, patch_radius]
                 for f_track in forecast_tracks])
        # track_id is the track's position in forecast_tracks; track_step counts steps from 1.
        out_file.variables["track_id"][:] = np.concatenate([[f] * f_track.times.size
                                                            for f, f_track in enumerate(forecast_tracks)])
        out_file.variables["track_step"][:] = np.concatenate([np.arange(1, f_track.times.size + 1)
                                                              for f_track in forecast_tracks])
        for meta_var in meta_variables:
            if meta_var in ("lon", "lat"):
                patches = [f_track.attributes[meta_var] for f_track in forecast_tracks]
            else:
                patches = [getattr(f_track, meta_var) for f_track in forecast_tracks]
            out_file.variables[meta_var][:] = np.vstack(patches)
        for nc_suffix, attr_suffix, names in patch_groups:
            for name in names:
                out_file.variables[name + nc_suffix][:] = np.vstack([f_track.attributes[name + attr_suffix]
                                                                     for f_track in forecast_tracks])
        if config.train:
            for label_column in label_columns:
                out_file.variables[label_column][:] = _concatenate_track_labels(forecast_tracks, label_column)

        _write_config_attributes(out_file, config)


def _concatenate_track_labels(forecast_tracks, label_column):
    """
    Concatenate one label column over all forecast tracks, one value per track time step.

    A track whose ``observations`` cannot supply the column (missing, ``None``, missing column, or
    a value count that differs from the track's number of time steps) contributes zeros instead.
    Only that track is affected; the labels of other tracks are kept.
    """
    label_arrays = []
    fallback_count = 0
    for f_track in forecast_tracks:
        n_steps = f_track.times.size
        try:
            values = np.asarray(f_track.observations[label_column].values)
        except (AttributeError, KeyError, TypeError):
            values = None
        if values is None or values.shape != (n_steps,):
            fallback_count += 1
            values = np.zeros(n_steps)
        label_arrays.append(values)
    if fallback_count > 0:
        print("Warning: {0} of {1} forecast tracks have no {2} labels; writing 0 for them".format(
            fallback_count, len(forecast_tracks), label_column))
    return np.concatenate(label_arrays)


def _write_config_attributes(out_file, config):
    """
    Store every config setting as a string global attribute of ``out_file``.

    If a name is already taken, the setting is stored as ``config_<name>``. If that name is
    also taken, AttributeError is raised.
    """
    # Skip the run dates and the variable-name list settings.
    skipped = {"dates", "storm_variables", "potential_variables", "tendency_variables"}
    for key, value in config.__dict__.items():
        if key in skipped:
            continue
        value = str(value)
        alt_key = "config_" + key
        if not hasattr(out_file, key):
            setattr(out_file, key, value)
        elif not hasattr(out_file, alt_key):
            # Don't clobber an existing attribute; use the prefixed name instead.
            setattr(out_file, alt_key, value)
        else:
            raise AttributeError("Can't save " + key + ":" + " " + value + " as global attribute. It already exists.")


def make_obs_track_data(obs_tracks, member, run_date, config, proj, track_type='obs'):
    """
    Calculate statistics on observed storm tracks and output the information as dataframes.

    Args:
        obs_tracks: List of observed storm tracks
        member: Name of the ensemble member
        run_date: Date of the model run
        config: Config object
        proj: pyproj object with map projection information
        track_type: When 'obs' (the default), a ``MESH_<statistic>`` column is added for each entry
            of ``config.variable_statistics``; any other value omits those columns. Shape columns
            are added whenever ``config`` has ``shape_variables``.

    Returns:
        A dictionary of pandas dataframes. 'track_total' has one row per observed track and
        'track_step' has one row per time step of each track.
    """
    obs_total_track_columns = ["Obs_Track_ID", "Start_Date", "End_Date", "Duration", "Track_ID"]
    obs_step_track_columns = ["Step_ID", "Obs_Track_ID", "Run_Date", "Valid_Date", "Forecast_Hour",
                              "Valid_Hour_UTC", "Duration",
                              "Centroid_Lon", "Centroid_Lat", "Centroid_X", "Centroid_Y"]
    include_mesh_stats = track_type == 'obs'
    include_shape_vars = hasattr(config, "shape_variables")
    var_stats = []
    if include_mesh_stats:
        var_stats.extend("MESH_" + stat for stat in config.variable_statistics)
    if include_shape_vars:
        var_stats.extend(config.shape_variables)
    obs_step_track_columns = obs_step_track_columns + var_stats

    run_date_str = run_date.strftime("%Y%m%d-%H%M")
    total_rows = []
    step_rows = []
    for o, obs_track in enumerate(obs_tracks):
        obs_track_id = "obs_{0}_{1}_{2:02d}_{3:02d}_{4:03d}".format(member,
                                                                    run_date_str,
                                                                    obs_track.start_time,
                                                                    obs_track.end_time,
                                                                    o)
        start_date = run_date + timedelta(seconds=3600 * int(obs_track.start_time))
        end_date = run_date + timedelta(seconds=3600 * int(obs_track.end_time))
        # Inclusive hour count: a track whose start and end hour are equal lasts 1 hour.
        duration = (end_date - start_date).total_seconds() / 3600.0 + 1
        # Track_ID is the placeholder string "None"; no forecast track is associated here.
        total_rows.append([obs_track_id, start_date, end_date, duration, "None"])
        for s, step in enumerate(obs_track.times):
            step_date = run_date + timedelta(seconds=3600 * int(step))
            centroid_x, centroid_y = obs_track.center_of_mass(step)
            centroid_lon, centroid_lat = proj(centroid_x, centroid_y, inverse=True)
            stat_values = []
            if include_mesh_stats:
                stat_values.extend(obs_track.calc_timestep_statistic(statistic, step)
                                   for statistic in config.variable_statistics)
            if include_shape_vars:
                stat_values.extend(obs_track.calc_shape_step(config.shape_variables, step))
            step_rows.append([obs_track_id + "_{0:02d}".format(s), obs_track_id, run_date, step_date, step,
                              step_date.hour, s + 1, centroid_lon, centroid_lat,
                              centroid_x, centroid_y] + stat_values)

    # Build each DataFrame once from the collected rows. Growing a DataFrame one row at a time
    # with .loc reallocates it on every insert, which is quadratic in the number of rows.
    return {
        'track_total': pd.DataFrame(total_rows, columns=obs_total_track_columns),
        'track_step': pd.DataFrame(step_rows, columns=obs_step_track_columns),
    }


def obs_tracks_to_json(obs_tracks, member, run_date, config, proj):
    """
    Write each observed storm track to its own geoJSON file.

    Files go to ``<geojson_path>/<run date>/<member>/`` and are named
    ``mesh_<run date>_<member>_obs_track_<NNN>.json``, with the run date formatted by
    ``config.run_date_format``. Missing directories are created with mode 0o777, and each JSON
    file is set to mode 0o666. Nothing is created when ``obs_tracks`` is empty.

    Args:
        obs_tracks: List of observed tracks
        member: Name of the ensemble member
        run_date: Date of the model run
        config: Config object
        proj: pyproj map projection

    """
    obs_tracks = list(obs_tracks)
    if not obs_tracks:
        return
    run_date_str = run_date.strftime(config.run_date_format)
    # Track IDs use the same "%Y%m%d-%H%M" scheme as make_obs_track_data and the forecast-track
    # JSON "obs_track_id" metadata, so the same observed track has the same ID in every output.
    id_date_str = run_date.strftime("%Y%m%d-%H%M")

    # The output directory is the same for every track, so create it once.
    json_dir = config.geojson_path
    for part in (run_date_str, member):
        json_dir = join(json_dir, part)
        if not os.access(json_dir, os.R_OK):
            try:
                os.mkdir(json_dir)
                os.chmod(json_dir, 0o777)
            except OSError:
                # Typically another process created the directory in the meantime.
                print("directory already created")

    for o, obs_track in enumerate(obs_tracks):
        obs_track_id = "obs_{0}_{1}_{2:02d}_{3:02d}_{4:03d}".format(member,
                                                                    id_date_str,
                                                                    obs_track.start_time,
                                                                    obs_track.end_time,
                                                                    o)
        start_date = run_date + timedelta(seconds=3600 * int(obs_track.start_time))
        end_date = run_date + timedelta(seconds=3600 * int(obs_track.end_time))
        # Inclusive hour count: a track whose start and end hour are equal lasts 1 hour.
        duration = (end_date - start_date).total_seconds() / 3600.0 + 1
        json_filename = join(json_dir, "{0}_{1}_{2}_obs_track_{3:03d}.json".format("mesh",
                                                                                   run_date_str,
                                                                                   member,
                                                                                   o))
        json_metadata = dict(id=obs_track_id,
                             ensemble_member=member,
                             duration=duration)
        obs_track.to_geojson(json_filename, proj, json_metadata)
        os.chmod(json_filename, 0o666)


if __name__ == "__main__":
    # Run the command-line driver only when this file is executed as a script, not on import.
    main()
