#!/usr/bin/env python3
"""
Generates posterior plots from RapidPE/RIFT results
"""

__author__ = "Caitlin Rose, Vinaya Valsan"

import os

import numpy as np
import logging

import rapidpe_rift_pipe.postscript_utils as postutils
from argparse import ArgumentParser
from scipy.stats import multinomial
from rapid_pe import amrlib
from rapid_pe.amrlib import VALID_TRANSFORMS_MASS
from rapid_pe.amrlib import INVERSE_TRANSFORMS_MASS
from rapid_pe.amrlib import BOUND_CHECK_MASS

print("-------------------Plotting intrinsic posteriors----------------------")

logging.basicConfig(level=logging.INFO)


# Command-line interface: one positional run directory, the name of the
# intrinsic coordinate system, an optional output directory and one
# kernel-width factor per grid parameter.
optp = ArgumentParser()
optp.add_argument("input_dir", help="path to event run dir")
optp.add_argument(
    "--distance-coordinates",
    # The value is split on "_" right after parsing, so a missing option can
    # never work; let argparse report it instead of failing on None.split().
    required=True,
    type=str,
    help="coordinates for intrinsic grid",
)
optp.add_argument("--output-dir", default=None, help="directory to save plots")
# --sigma1-factor ... --sigma4-factor share everything except their index.
for sigma_index in range(1, 5):
    optp.add_argument(
        f"--sigma{sigma_index}-factor",
        default=0.5,
        type=float,
        help=f"standard deviation for posterior for param{sigma_index} is "
        "this factor multiplied to grid size",
    )
opts = optp.parse_args()

input_dir = opts.input_dir

results_dir = os.path.join(input_dir, "results/")


distance_coordinates_str = opts.distance_coordinates

# Individual coordinate names, e.g. "mchirp_eta" -> ["mchirp", "eta"].
distance_coordinates = distance_coordinates_str.split("_")

# Build a label such as "sigma1_0p5-sigma2_0p5[-sigma3_...][-sigma4_...]".
# The third and fourth factors are checked independently: chaining them with
# `elif` made the four-coordinate case unreachable.
sigma_label_values = [opts.sigma1_factor, opts.sigma2_factor]
if len(distance_coordinates) >= 3:
    sigma_label_values.append(opts.sigma3_factor)
if len(distance_coordinates) >= 4:
    sigma_label_values.append(opts.sigma4_factor)
sigma_str = "-".join(
    f'sigma{index}_{str(value).replace(".", "p")}'
    for index, value in enumerate(sigma_label_values, start=1)
)

print(f"Sigma values: {sigma_str}")
# Plots go next to the run unless an explicit output directory was given.
output_dir = opts.output_dir if opts.output_dir else input_dir

summary_plots_dir = os.path.join(output_dir, "summary_plots/")
# Create the directory without going through a shell, so paths containing
# spaces or shell metacharacters are handled correctly.
os.makedirs(summary_plots_dir, exist_ok=True)

# NOTE(review): presumably the lower frequency cutoff in Hz used by the
# tau0/tau3 coordinates -- confirm against the run configuration.
f_lower = 40

# Get injection/search point
event_info = postutils.event_info(input_dir)
event_param_dict = event_info.get_event_params()
mass1_inj = event_param_dict["mass1"]
mass2_inj = event_param_dict["mass2"]
spin1z_inj = event_param_dict["spin1z"]
spin2z_inj = event_param_dict["spin2z"]

# A spin value for the event means the grid has four dimensions.
grid_in_4dimensions = spin1z_inj is not None

# Copy the coordinate names: extending an alias would also append the spin
# parameters to `distance_coordinates`, which is reused further down to
# choose what to plot.
grid_param_list = list(distance_coordinates)
if grid_in_4dimensions and distance_coordinates_str != "mu1_mu2_q_s2z":
    # Two-parameter mass coordinates are completed with chi_eff / chi_a.
    grid_param_list += ["chi_eff", "chi_a"]

# Kernel-width factor for every grid parameter, keyed by parameter name.
sigma_factor = {
    grid_param_list[0]: opts.sigma1_factor,
    grid_param_list[1]: opts.sigma2_factor,
}
if grid_in_4dimensions:
    sigma_factor[grid_param_list[2]] = opts.sigma3_factor
    sigma_factor[grid_param_list[3]] = opts.sigma4_factor

# Load the grid results for this run, then express both the event point and
# every grid point in the chosen grid coordinates.
grid_data_dict = postutils.get_grid_info(input_dir)
if distance_coordinates_str != "mu1_mu2_q_s2z":
    # Two-parameter mass coordinates: look the transform up once and apply it
    # to the event masses and to the grid masses.
    mass_transform = VALID_TRANSFORMS_MASS[
        frozenset(distance_coordinates_str.split("_"))
    ]
    event_mass_coords = mass_transform(mass1_inj, mass2_inj)
    (
        event_param_dict[grid_param_list[0]],
        event_param_dict[grid_param_list[1]],
    ) = event_mass_coords
    grid_mass_coords = mass_transform(
        grid_data_dict["mass1"],
        grid_data_dict["mass2"],
    )
    (
        grid_data_dict[grid_param_list[0]],
        grid_data_dict[grid_param_list[1]],
    ) = grid_mass_coords
else:
    # Four-parameter coordinates: masses and aligned spins are transformed
    # together.
    event_mu_coords = amrlib.transform_m1m2s1zs2z_mu1mu2qs2z(
        mass1_inj, mass2_inj, spin1z_inj, spin2z_inj
    )
    (
        event_param_dict[grid_param_list[0]],
        event_param_dict[grid_param_list[1]],
        event_param_dict[grid_param_list[2]],
        event_param_dict[grid_param_list[3]],
    ) = event_mu_coords
    grid_mu_coords = amrlib.transform_m1m2s1zs2z_mu1mu2qs2z(
        grid_data_dict["mass1"],
        grid_data_dict["mass2"],
        grid_data_dict["spin1z"],
        grid_data_dict["spin2z"],
    )
    (
        grid_data_dict[grid_param_list[0]],
        grid_data_dict[grid_param_list[1]],
        grid_data_dict[grid_param_list[2]],
        grid_data_dict[grid_param_list[3]],
    ) = grid_mu_coords

# None means "do not restrict to a single grid level".
use_grid_level = None
sigma_dict = postutils.find_sigma(
    grid_data_dict, grid_param_list, sigma_factor, grid_level=use_grid_level
)


grid_levels = np.unique(grid_data_dict["iteration_level"])

# Every (x, y) parameter pair that gets a grid plot.
grid_plot_axes = [
    ("mass1", "mass2"),
    (distance_coordinates[0], distance_coordinates[1]),
]
if spin1z_inj is not None:
    grid_plot_axes.append(("spin1z", "spin2z"))
# The coordinate system is spelled "mu1_mu2_q_s2z" everywhere else in this
# file; the previous "..._s2q" spelling could never match, so these plots
# were never produced.
if distance_coordinates_str == "mu1_mu2_q_s2z":
    # (mu1, mu2) is already covered by the distance-coordinate pair above.
    grid_plot_axes += [("q", "spin1z"), ("q", "spin2z")]

# One set of plots per grid level.
# NOTE(review): plot_grid receives the position of the level within
# `grid_levels`, not the level value itself; the two coincide only when the
# levels are 0..n-1 -- confirm against postutils.plot_grid.
for level_index in range(len(grid_levels)):
    for x_param, y_param in grid_plot_axes:
        postutils.plot_grid(
            grid_data_dict,
            x_param,
            y_param,
            summary_plots_dir,
            grid_level=level_index,
            event_info=event_param_dict,
        )

# One set of plots without a grid_level argument.
for x_param, y_param in grid_plot_axes:
    postutils.plot_grid(
        grid_data_dict,
        x_param,
        y_param,
        summary_plots_dir,
        event_info=event_param_dict,
    )


def uniform_m1_m2_prior_in_mchirp_eta(mchirp, eta):
    """
    Jacobian |d(mass1, mass2) / d(mchirp, eta)|.

    Evaluates ``|mchirp * eta**(-6/5) / sqrt(1 - 4*eta)|``. Scalars and
    numpy arrays are both accepted; the result follows numpy broadcasting.
    """
    eta_power = np.power(eta, -6.0 / 5.0)
    inverse_root = 1.0 / np.sqrt(1.0 - 4.0 * eta)
    return np.abs(mchirp * eta_power * inverse_root)


def uniform_m1_m2_prior_in_mchirp_q(mchirp, q):
    """
    Jacobian |d(mass1, mass2) / d(mchirp, q)| with q = mass2 / mass1.

    With ``mass1 = mchirp * (1 + q)**(1/5) * q**(-3/5)`` and
    ``mass2 = q * mass1`` the determinant is
    ``mchirp * (1 + q)**(2/5) / q**(6/5)``. Scalars and numpy arrays are
    both accepted.
    """
    # The exponent must be parenthesised: `(1 + q) ** 2 / 5` would square
    # the term and then divide by five.
    p = np.abs(mchirp * (1.0 + q) ** (2.0 / 5.0) / q ** (6.0 / 5.0))
    return p


def uniform_m1_m2_prior_in_mtotal_q(mtotal, q):
    """
    Jacobian |d(mass1, mass2) / d(mtotal, q)|.

    Evaluates ``|mtotal / (1 + q)**2|`` for scalars or numpy arrays.
    """
    one_plus_q_squared = (1 + q) ** 2
    return np.abs(mtotal / one_plus_q_squared)


def uniform_m1_m2_prior_in_tau0_tau3(tau0, tau3, f_lower):
    """
    Jacobian |d(mass1, mass2) / d(tau0, tau3)|.

    ``f_lower`` enters through the two constants ``a0`` and ``a3``, which
    scale as ``(pi * f_lower)**(-8/3)`` and ``(pi * f_lower)**(-5/3)``.
    Scalars and numpy arrays are accepted for ``tau0`` and ``tau3``.
    """
    a0 = 5.0 / (256.0 * (np.pi * f_lower) ** (8.0 / 3.0))
    a3 = np.pi / (8.0 * (np.pi * f_lower) ** (5.0 / 3.0))
    # Ratio that appears both in the numerator and under the square root.
    ratio = (a0 * tau3) / (a3 * tau0)
    root_argument = 1 - ((4 * a3) / (tau3 * ratio ** (2.0 / 3.0)))
    numerator = a0 * (ratio) ** (1.0 / 3.0)
    denominator = tau0**2.0 * tau3 * np.sqrt(root_argument)
    return np.abs(numerator / denominator)


def uniform_m1m2chi1chi2_prior_to_mu1mu2qchi2(mu1, mu2, q, s2z):
    """Return d(mu1, mu2, q, s2z) / d(m1, m2, s1z, s2z)

    The inputs are first converted to component masses with
    ``amrlib.transform_mu1mu2qs2z_m1m2s1zs2z``; the chirp mass and the mass
    ratio ``m2 / m1`` are then recomputed from those masses, so the ``q``
    used in the formula is the recomputed one. The absolute value of the
    resulting ratio is returned.
    """
    # conversion from solar mass to seconds
    MsunToTime = 4.92659 * 10.0 ** (-6.0)
    fref_mu = 200.0
    # coefficients of mu1 and mu2
    mu_coeffs = np.array(
        [
            [0.97437198, 0.20868103, 0.08397302],
            [-0.22132704, 0.82273827, 0.52356096],
        ]
    )
    # Only the component masses of the transform are needed here.
    m1, m2, _, _ = amrlib.transform_mu1mu2qs2z_m1m2s1zs2z(mu1, mu2, q, s2z)
    mc = (m1 * m2) ** (3.0 / 5.0) / (m1 + m2) ** (1.0 / 5.0)
    q = m2 / m1
    eta = amrlib.qToeta(q)
    x = np.pi * mc * MsunToTime * fref_mu

    numerator = (
        m1**2.0 * 4128768.0 * q * (1.0 + q) ** 2.0 * x ** (10.0 / 3.0)
    )

    # 2x2 minors of the coefficient matrix taken against its last column.
    minor_first = (
        mu_coeffs[0, 2] * mu_coeffs[1, 0] - mu_coeffs[0, 0] * mu_coeffs[1, 2]
    )
    minor_second = (
        mu_coeffs[0, 2] * mu_coeffs[1, 1] - mu_coeffs[0, 1] * mu_coeffs[1, 2]
    )
    bracket = (
        252.0 * minor_first * q * eta ** (-3.0 / 5.0)
        + minor_second
        * (743.0 + 2410.0 * q + 743.0 * q**2.0)
        * x ** (2.0 / 3.0)
    )
    denominator = x * 5.0 * (113.0 + 75.0 * q) * bracket
    return np.abs(numerator / denominator)


# Prior (Jacobian) function for each supported coordinate-system name, keyed
# by the same string that is passed to --distance-coordinates.
PRIOR_MAP = {
    "mchirp_eta": uniform_m1_m2_prior_in_mchirp_eta,
    "tau0_tau3": uniform_m1_m2_prior_in_tau0_tau3,
    "mchirp_q": uniform_m1_m2_prior_in_mchirp_q,
    "mtotal_q": uniform_m1_m2_prior_in_mtotal_q,
    # Spelled "..._s2z" like every other use of this coordinate system in the
    # file; the former "..._s2q" key could never be looked up.
    "mu1_mu2_q_s2z": uniform_m1m2chi1chi2_prior_to_mu1mu2qchi2,
}


def get_posterior_samples(grid_data, sigma, grid_level=None):
    """
    Generate posterior samples for params for the given grid_level

    100000 samples are shared out between the grid points with a
    multinomial draw whose probabilities are ``exp(margll - max(margll))``
    normalised to one. Each grid point then contributes that many draws
    from independent normal distributions centred on the grid point, with
    standard deviation ``sigma[param][i]``. Draws that fail the bound
    checks are discarded; the survivors are converted to component masses
    (and spins when available) and the coordinate prior is evaluated.

    Parameters
    ----------
    grid_data : dict
        Must provide "margll", "iteration_level", "grid_id" and one array
        per key of ``sigma``. The presence of "spin1z" switches on spin
        handling. The dict is not modified.
    sigma : dict
        Parameter name -> per-grid-point standard deviation. The key order
        defines param1..param4 (for "mu1_mu2_q_s2z": mu1, mu2, q, s2z).
    grid_level : optional
        When not None, the ``grid_id`` values of the points whose
        iteration level equals ``grid_level`` are used as indices to
        select the grid points.
        NOTE(review): ``sigma`` is still indexed 0..n-1 after the
        selection, so it presumably has to be restricted to the same level
        by the caller -- confirm.

    Returns
    -------
    dict
        Arrays of equal length keyed by parameter name, plus "prior" and
        "grid_id" (index of the grid point each sample was drawn around).

    Notes
    -----
    Reads the module-level ``distance_coordinates_str`` and ``f_lower``.
    All random numbers come from one ``RandomState`` with a fixed seed, so
    repeated calls on the same input give the same samples.
    """
    distance_coordinates = list(sigma.keys())
    sample_dict = {}
    spin_included = "spin1z" in grid_data
    grid_it_level = grid_data["iteration_level"]
    grid_index_list = grid_data["grid_id"]
    margll = grid_data["margll"]
    # Local views of the grid-point coordinates: selecting a level must not
    # overwrite the caller's arrays.
    centers = {param: grid_data[param] for param in distance_coordinates}
    if grid_level is not None:
        grid_inds = grid_index_list[grid_it_level == grid_level]
        margll = margll[grid_inds]
        centers = {
            param: values[grid_inds] for param, values in centers.items()
        }

    # Subtract the maximum before exponentiating to avoid overflow.
    weights = np.exp(margll - np.max(margll))
    weights /= np.sum(weights)

    seed = 12345
    random_state = np.random.RandomState(seed)
    N = multinomial(100000, weights, seed=random_state).rvs(1)[0]
    print(f"Number of samples {N}")

    # Draw from the seeded generator (not the global numpy one) so the
    # Gaussian draws are reproducible too, and concatenate once at the end
    # instead of re-allocating the arrays for every grid point.
    chunks = {param: [] for param in distance_coordinates}
    for i, n_draws in enumerate(N):
        for param in distance_coordinates:
            chunks[param].append(
                random_state.normal(
                    loc=centers[param][i],
                    scale=sigma[param][i],
                    size=n_draws,
                )
            )
    all_random_samples = {
        param: np.concatenate(parts) for param, parts in chunks.items()
    }
    # Index of the originating grid point for every sample.
    grid_id = np.repeat(np.arange(len(N), dtype=float), N)

    param1_samples = all_random_samples[distance_coordinates[0]]
    param2_samples = all_random_samples[distance_coordinates[1]]
    if spin_included:
        param3_samples = all_random_samples[distance_coordinates[2]]
        param4_samples = all_random_samples[distance_coordinates[3]]

    if distance_coordinates_str != "mu1_mu2_q_s2z":
        mask = BOUND_CHECK_MASS[distance_coordinates_str](
            param1_samples, param2_samples
        )
        if spin_included:
            mask &= amrlib.check_spins(param3_samples)
            mask &= amrlib.check_spins(param4_samples)

        # Drop rejected samples from the arrays that are actually used
        # below, so every entry of sample_dict matches grid_id[mask].
        param1_samples = param1_samples[mask]
        param2_samples = param2_samples[mask]
        if spin_included:
            param3_samples = param3_samples[mask]
            param4_samples = param4_samples[mask]

        mass_transform = VALID_TRANSFORMS_MASS[
            frozenset(distance_coordinates_str.split("_"))
        ]
        m1_samples, m2_samples = INVERSE_TRANSFORMS_MASS[mass_transform](
            param1_samples, param2_samples
        )

        prior_function = PRIOR_MAP[distance_coordinates_str]
        if distance_coordinates_str == "tau0_tau3":
            # The tau0/tau3 prior is the only one with a required third
            # argument.
            prior = prior_function(param1_samples, param2_samples, f_lower)
        else:
            prior = prior_function(param1_samples, param2_samples)

        sample_dict["mass1"] = m1_samples
        sample_dict["mass2"] = m2_samples
        sample_dict[distance_coordinates[0]] = param1_samples
        sample_dict[distance_coordinates[1]] = param2_samples
        sample_dict["prior"] = prior
        if spin_included:
            (
                spin1z_samples,
                spin2z_samples,
            ) = amrlib.transform_chi_eff_chi_a_s1zs2z(
                m1_samples, m2_samples, param3_samples, param4_samples
            )

            sample_dict["chi_eff"] = param3_samples
            sample_dict["chi_a"] = param4_samples
            sample_dict["spin1z"] = spin1z_samples
            sample_dict["spin2z"] = spin2z_samples
    else:
        # Parameter order is (mu1, mu2, q, s2z): the mass ratio is the third
        # parameter and the spin the fourth, which is what the checks below
        # test and what the assignments must follow.
        mask = amrlib.check_q(param3_samples)
        mask &= amrlib.check_spins(param4_samples)
        mu1_samples = np.array(param1_samples[mask])
        mu2_samples = np.array(param2_samples[mask])
        q_samples = np.array(param3_samples[mask])
        spin2z_samples = np.array(param4_samples[mask])

        (
            m1_samples,
            m2_samples,
            spin1z_samples,
            spin2z_samples,
        ) = amrlib.transform_mu1mu2qs2z_m1m2s1zs2z(
            mu1_samples, mu2_samples, q_samples, spin2z_samples
        )

        chi_eff_samples, chi_a_samples = amrlib.transform_s1zs2z_chi_eff_chi_a(
            m1_samples,
            m2_samples,
            spin1z_samples,
            spin2z_samples,
        )
        mu1mu2qs2z_prior = uniform_m1m2chi1chi2_prior_to_mu1mu2qchi2(
            mu1_samples, mu2_samples, q_samples, spin2z_samples
        )
        sample_dict["mu1"] = mu1_samples
        sample_dict["mu2"] = mu2_samples
        sample_dict["q"] = q_samples
        sample_dict["spin2z"] = spin2z_samples
        sample_dict["spin1z"] = spin1z_samples
        sample_dict["chi_eff"] = chi_eff_samples
        sample_dict["chi_a"] = chi_a_samples
        sample_dict["mass1"] = m1_samples
        sample_dict["mass2"] = m2_samples
        sample_dict["prior"] = mu1mu2qs2z_prior
    sample_dict["grid_id"] = grid_id[mask]
    return sample_dict


sample_dict = get_posterior_samples(
    grid_data_dict,
    sigma_dict,
    grid_level=use_grid_level,
)

# Parameters that get a posterior plot: the grid coordinates, the component
# masses and, when the event has spins, the spin parameters.
posterior_plot_axis = distance_coordinates + ["mass1", "mass2"]
if spin1z_inj is not None:
    posterior_plot_axis += ["spin1z", "spin2z", "chi_eff", "chi_a"]
# Keep the first occurrence of each name so no parameter is plotted twice.
posterior_plot_axis = list(dict.fromkeys(posterior_plot_axis))

# Write into the directory that was created for the summary plots; building
# the path from input_dir ignored --output-dir and could point at a
# directory that does not exist.
for param in posterior_plot_axis:
    postutils.plot_posterior(
        sample_dict,
        param,
        plot_dir=summary_plots_dir,
        event_info=event_param_dict,
    )

postutils.save_m1m2_posterior_samples(
    sample_dict,
    summary_plots_dir,
)

print(f"All plots saved in {output_dir}")
