#!/bin/env python

import argparse
import json
import sys

import numpy as np

# Put a local pyGWB checkout at the front of the import search path. This has
# to run before the `pygwb` imports below so that they are looked up there
# first.
# NOTE(review): hard-coded, user-specific absolute path -- on machines where
# this directory does not exist the imports fall through to whatever else is
# on sys.path; confirm whether this override should remain in the script.
module_path_2 = '/home/max.lalleman/public_html/Code/New_Clone_pyGWB/pygwb'
sys.path.insert(0,module_path_2)

import bilby

import pygwb.pe as pe
from pygwb.baseline import Baseline
from pygwb.detector import Interferometer


def load_npz(fname_path):
    """Load a point-estimate spectrum and its sigma from an ``.npz`` archive.

    Bins are discarded when the point estimate is NaN, the point estimate is
    exactly zero, or the sigma is exactly zero. The same mask is applied to
    the frequencies when the archive provides them.

    Parameters
    ----------
    fname_path : str
        Path to an ``.npz`` file containing ``point_estimate_spectrum`` and
        ``sigma_spectrum`` arrays, and optionally a ``frequencies`` array.

    Returns
    -------
    tuple
        ``(freqs, Y_f, sigma_f)`` where ``Y_f`` is the real part of the
        filtered point-estimate spectrum, ``sigma_f`` the filtered sigma
        spectrum and ``freqs`` the filtered frequencies, or ``None`` when the
        archive has no ``frequencies`` entry.
    """
    # The context manager closes the underlying zip archive; the arrays are
    # read into memory on indexing, so they stay valid afterwards.
    with np.load(fname_path) as data:
        ptEst_ff = np.real(data["point_estimate_spectrum"])
        sigma_ff = data["sigma_spectrum"]
        freqs = data["frequencies"] if "frequencies" in data.files else None

    # np.logical_and only combines two inputs (a third positional argument is
    # the ``out`` buffer), so the three conditions are chained with ``&``.
    goodInds = ~np.isnan(ptEst_ff) & (ptEst_ff != 0) & (sigma_ff != 0)

    if freqs is not None:
        freqs = freqs[goodInds]
    Y_f = ptEst_ff[goodInds]
    sigma_f = sigma_ff[goodInds]
    return (freqs, Y_f, sigma_f)
    
def _build_parser():
    """Return the argument parser describing the script's command-line options."""
    pe_parser = argparse.ArgumentParser(add_help=True)
    pe_parser.add_argument(
        "--path_to_file",
        help="Path to data file (or pickled baseline) to use for analysis.",
        action="store",
        type=str,
        required=True,
    )
    pe_parser.add_argument(
        "--ifos",
        help="List of names of two interferometers for which you want to run PE, default is H1 and L1. Not needed when running from a pickled baseline.",
        action="store",
        type=str,
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--Model",
        # The model name is matched case-sensitively, so the help text spells
        # the default exactly as it must be typed.
        help="Model for which to run the PE. \n Default is \"Power-Law\". ",
        action="store",
        type=str,
        required=False,
        default="Power-Law",
    )
    pe_parser.add_argument(
        "--Model_Parameters_Dictionary_Priors",
        help="The required parameters for the model and their priors. \n Default value is power-law parameters omega_ref with LogUniform bilby prior from 1e-11 to 1e-8 and alpha with Uniform bilby prior from -4 to 4.",
        type=str,
        action="store",
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--Non_Prior_Arguments",
        help="A dictionary with the parameters of the model that are not associated with a prior, such as the reference frequency for Power-Law. Default value is reference frequency at 25 Hz for the power-law model.",
        type=str,
        action="store",
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--output_dir",
        help="Output directory of the PE (sampler). Default: ./PE_Output.",
        type=str,
        action="store",
        required=False,
        default='./PE_Output',
    )
    pe_parser.add_argument(
        "--injection_parameters",
        help="The injected parameters.",
        required=False,
        default="None",
        type=str,
        action="store",
    )
    pe_parser.add_argument(
        "--quantiles",
        help="The quantiles used for plotting in plot corner, default is [0.05, 0.95].",
        type=str,
        action="store",
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--f0",
        help="If no frequencies are saved in the loaded npz file, you can give f0, fhigh and df to the script. This is f0",
        type=float,
        action="store",
        default=None,
    )
    pe_parser.add_argument(
        "--fhigh",
        help="If no frequencies are saved in the loaded npz file, you can give f0, fhigh and df to the script. This is fhigh",
        type=float,
        action="store",
        default=None,
    )
    pe_parser.add_argument(
        "--df",
        help="If no frequencies are saved in the loaded npz file, you can give f0, fhigh and df to the script. This is df",
        type=float,
        action="store",
        default=None,
    )
    return pe_parser


def _validate_parameters(provided, required):
    """Check that the provided parameter names match the model's required ones.

    Raises AttributeError for a name the model does not know and ValueError
    when a required name is missing.
    """
    for key in provided:
        if key not in required:
            raise AttributeError(
                f"The parameter {key} you provided is not associated with the given model. \n These are the required parameters: {required}."
            )

    if not all(key in provided for key in required):
        raise ValueError(
            f"Not all required parameters of the model are present in one of the parameters dictionaries. \n These are the required parameters: {required}."
        )


def _load_baseline(script_args):
    """Build the baseline from an npz spectrum file or load it from a pickle.

    Raises ValueError for inconsistent npz-related options and TypeError for
    an unsupported file extension.
    """
    path = script_args.path_to_file

    if path.endswith("npz"):
        frequencies, point_estimate_spectrum, sigma_spectrum = load_npz(path)

        if frequencies is None:
            grid = (script_args.f0, script_args.fhigh, script_args.df)
            if any(value is None for value in grid):
                raise ValueError(
                    "The npz file contains no frequencies; please provide --f0, --fhigh and --df."
                )
            f0, fhigh, df = grid
            frequencies = np.arange(f0, fhigh + df, df)

        # load_npz drops bad bins from the spectra, so a grid regenerated from
        # f0/fhigh/df can end up with a different number of bins; fail early
        # with a clear message instead of deep inside the likelihood.
        if len(frequencies) != len(point_estimate_spectrum):
            raise ValueError(
                f"The number of frequencies ({len(frequencies)}) does not match the number of point estimates ({len(point_estimate_spectrum)})."
            )

        if script_args.ifos == "None":
            ifos = ['H1', 'L1']
        else:
            ifos = json.loads(script_args.ifos)
            if not isinstance(ifos, (list, tuple)) or len(ifos) < 2:
                raise ValueError(
                    "--ifos must be a JSON list with the names of two interferometers."
                )

        ifo_1 = Interferometer.get_empty_interferometer(ifos[0])
        ifo_2 = Interferometer.get_empty_interferometer(ifos[1])

        base_12 = Baseline.from_interferometers([ifo_1, ifo_2])

        if not base_12._orf_polarization_set:
            base_12.orf_polarization = "tensor"

        base_12.frequencies = frequencies
        base_12.point_estimate_spectrum = point_estimate_spectrum
        base_12.sigma_spectrum = sigma_spectrum
        return base_12

    if path.endswith((".p", ".pickle")):
        # NOTE(review): unpickling executes arbitrary code from the file; only
        # use pickled baselines that come from a trusted source.
        return Baseline.load_from_pickle(path)

    raise TypeError(
        "The provided file format is currently not supported, try a npz file or a pickled baseline instead."
    )


def main():
    """Run parameter estimation on a spectrum file or pickled baseline.

    Parses the command line, checks the requested model and its parameters,
    builds the baseline, runs the bilby ``dynesty`` sampler with the model as
    likelihood and finally produces a corner plot of the result.

    Raises
    ------
    ValueError
        If the model is not supported, a required model parameter is missing,
        or the npz-related options are inconsistent.
    AttributeError
        If a provided parameter does not belong to the chosen model.
    TypeError
        If the input file is neither an npz file nor a pickled baseline.
    """
    # Unknown command-line options are deliberately ignored (parse_known_args).
    script_args = _build_parser().parse_known_args()[0]

    dictionary_Models = {"Power-Law": pe.PowerLawModel,
                         "Broken-Power-Law": pe.BrokenPowerLawModel,
                         "Triple-Broken-Power-Law": pe.TripleBrokenPowerLawModel,
                         "Smooth-Broken-Power-Law": pe.SmoothBrokenPowerLawModel,
                         "Schumann": pe.SchumannModel,
                         "Parity-Violation": pe.PVPowerLawModel,
                         "Parity-Violation-2": pe.PVPowerLawModel2}

    dictionary_parameters = {"Power-Law": ["omega_ref", "alpha", "fref"],
                             "Broken-Power-Law": ["omega_ref", "fbreak", "alpha_1", "alpha_2"],
                             "Triple-Broken-Power-Law": ["omega_ref", "alpha_1", "alpha_2", "alpha_3", "fbreak1", "fbreak2"],
                             "Smooth-Broken-Power-Law": ["omega_ref", "fbreak", "alpha_1", "alpha_2", "delta"],
                             "Schumann": ["kappa_H1", "kappa_L1", "beta_L1", "beta_H1"],
                             "Parity-Violation": ["omega_ref", "alpha", "Pi"],
                             "Parity-Violation-2": ["omega_ref", "alpha", "beta"]}

    if script_args.Model not in dictionary_Models:
        raise ValueError(
            "The model you provided is not supported."
        )
    class_model = dictionary_Models[script_args.Model]

    if script_args.Model_Parameters_Dictionary_Priors == 'None':
        # Raw string: "\O" is not a valid escape sequence in a normal literal.
        dict_loaded = {"omega_ref": bilby.core.prior.LogUniform(1e-11, 1e-8, r"$\Omega_{ref}$"),
                       "alpha": bilby.core.prior.Uniform(-4, 4, "$\\alpha$")}
    else:
        dict_loaded = json.loads(script_args.Model_Parameters_Dictionary_Priors)

    # JSON can only carry strings/numbers, not prior objects. PriorDict turns
    # those into bilby priors (the same conversion run_sampler applies to a
    # plain dict), so that .minimum/.maximum are available further down.
    priors = bilby.core.prior.PriorDict(dict_loaded)

    if script_args.Non_Prior_Arguments == 'None':
        # This default only fits the "Power-Law" entry of the parameter table;
        # for other models the validation below rejects "fref" unless
        # --Non_Prior_Arguments is given explicitly.
        extra_kwargs = {"fref": 25}
    else:
        extra_kwargs = json.loads(script_args.Non_Prior_Arguments)

    provided_parameters = list(priors) + [key for key in extra_kwargs if key not in priors]
    _validate_parameters(provided_parameters, dictionary_parameters[script_args.Model])

    base_12 = _load_baseline(script_args)

    kwargs = {"baselines": [base_12], "model_name": script_args.Model}
    kwargs.update(extra_kwargs)
    model = class_model(**kwargs)

    if script_args.injection_parameters == "None":
        injection_parameters = dict(omega_ref=10**(-8.67985), alpha=2/3)
    else:
        injection_parameters = json.loads(script_args.injection_parameters)

    # Plot ranges follow the prior bounds, in the order the priors were given.
    range_list = [(priors[key].minimum, priors[key].maximum) for key in priors]

    hlv = bilby.run_sampler(
        likelihood=model,
        priors=priors,
        sampler='dynesty',
        npoints=1000,
        walks=10,
        maxmcmc=10000,
        outdir=script_args.output_dir,
        label='hlv',
        injection_parameters=injection_parameters,
    )

    if script_args.quantiles == "None":
        quantiles = [0.05, 0.95]
    else:
        quantiles = json.loads(script_args.quantiles)

    print(quantiles)

    hlv.plot_corner(show_titles=True, title_fmt='.3g', use_math_text=True, quantiles=quantiles, range=range_list)

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
