import json
import re
import glob
import h5py
import logging
import numpy as np
from rapid_pe import amrlib

import matplotlib.pyplot as plt
from ligo.lw import utils, lsctables, ligolw

# Configure the root logger at INFO level so the module's logging.info(...)
# messages are shown (basicConfig is a no-op if the root logger already has
# handlers configured).
logging.basicConfig(level=logging.INFO)


class event_info:
    """
    Reader for the event/injection parameters stored in a run directory.

    Parameters
    ----------
    rundir : str
        Run directory containing ``event_info_dict.txt``.
    """

    def __init__(self, rundir):
        self.rundir = rundir

    def load_event_info(self):
        """
        Read ``<rundir>/event_info_dict.txt`` and return its JSON content.

        Returns
        -------
        The decoded JSON value. ``get_event_params`` expects it to be a dict
        with an ``"intrinsic_param"`` string entry.
        """
        with open(self.rundir + "/event_info_dict.txt", encoding="utf-8") as f:
            event_info_dict = json.load(f)
        return event_info_dict

    def get_event_params(self):
        """
        Parse the injected intrinsic parameters from the event info.

        Each value is extracted from the ``"intrinsic_param"`` string with
        the pattern ``<name>=<value>"``.

        Returns
        -------
        dict
            ``mass1`` and ``mass2`` (floats, mandatory) and ``spin1z`` and
            ``spin2z``. When both spins are present they are floats, and
            ``chi_eff`` and ``chi_a`` are added from
            ``amrlib.transform_s1zs2z_chi_eff_chi_a``. Otherwise both spins
            are None and the chi keys are absent.

        Raises
        ------
        AttributeError
            If ``mass1`` or ``mass2`` is missing from the string (the regex
            finds no match).
        """
        intrinsic_param_inj = self.load_event_info()["intrinsic_param"]

        # Masses are mandatory: a missing match makes .group raise.
        mass1_inj = float(re.search(r'mass1=(.+?)"', intrinsic_param_inj).group(1))
        mass2_inj = float(re.search(r'mass2=(.+?)"', intrinsic_param_inj).group(1))
        event_params = {"mass1": mass1_inj, "mass2": mass2_inj}

        # Spins are optional. Check the matches explicitly instead of
        # catching AttributeError around the whole block, so that errors
        # raised inside amrlib are not mistaken for "no spin information".
        spin1z_match = re.search(r'spin1z=(.+?)"', intrinsic_param_inj)
        spin2z_match = re.search(r'spin2z=(.+?)"', intrinsic_param_inj)
        if spin1z_match is None or spin2z_match is None:
            event_params["spin1z"] = None
            event_params["spin2z"] = None
            logging.info("No Spin information found in event_info_dict")
            return event_params

        spin1z_inj = float(spin1z_match.group(1))
        spin2z_inj = float(spin2z_match.group(1))
        event_params["spin1z"] = spin1z_inj
        event_params["spin2z"] = spin2z_inj
        (
            event_params["chi_eff"],
            event_params["chi_a"],
        ) = amrlib.transform_s1zs2z_chi_eff_chi_a(
            mass1_inj, mass2_inj, spin1z_inj, spin2z_inj
        )
        return event_params


def get_grid_info(rundir):
    """
    Collect the grid points and their marginalized log-likelihoods from the
    ILE output files of a run.

    Files are looked up in ``<rundir>/results/`` as
    ``ILE_iteration_<N>-MASS_SET_*-0.xml.gz``, with a fallback to
    ``ILE_iteration_<N>-MASS_SET_*_0_.xml.gz`` when nothing matches. Every
    row of each file's SnglInspiralTable becomes one grid point.

    Parameters
    ----------
    rundir : str
        Run directory.

    Returns
    -------
    dict of str -> numpy.ndarray
        Per-row arrays:

        - ``mass1``, ``mass2``.
        - ``margll``: the row's ``snr`` column.
        - ``iteration_level``: rank of the iteration among those found, in
          increasing iteration order.
        - ``grid_id``: index of the source file within its iteration; it
          restarts at 0 for each iteration.
        - ``spin1z``, ``spin2z``, ``chi_eff``, ``chi_a``.

        ``filename`` has one entry per file, not per row. ``spin1z`` and
        ``spin2z`` are removed when no row carried spin information.
    """
    results_dir = rundir + "/results/"
    all_xml = glob.glob(results_dir + "ILE_iteration_*-MASS_SET_*-0.xml.gz")
    if not all_xml:
        all_xml = glob.glob(
            results_dir + "ILE_iteration_*-MASS_SET_*_0_.xml.gz"
        )
    print(f"Found {len(all_xml)} sample files")

    # Map each "ILE_iteration_<N>" prefix to its integer N. The whole number
    # is captured: a fixed one-character slice would truncate
    # "ILE_iteration_12" to "ILE_iteration_1". Levels are then ordered
    # numerically so that 10 follows 9 rather than 1.
    iteration_pattern = re.compile(r"(ILE_iteration_(\d+))-MASS_SET_")
    iteration_numbers = {}
    for xmlfile in all_xml:
        match = iteration_pattern.search(xmlfile)
        if match is not None:
            iteration_numbers[match.group(1)] = int(match.group(2))
    grid_levels = sorted(iteration_numbers, key=iteration_numbers.get)

    # A tuple (not a set) keeps the key order of the result deterministic.
    keys = (
        "mass1",
        "mass2",
        "spin1z",
        "spin2z",
        "chi_eff",
        "chi_a",
        "margll",
        "grid_id",
        "iteration_level",
        "filename",
    )
    data_dict = {key: [] for key in keys}
    for i, gl in enumerate(grid_levels):
        xml_files = glob.glob(results_dir + gl + "-MASS_SET_*-0.xml.gz")
        if not xml_files:
            xml_files = glob.glob(results_dir + gl + "-MASS_SET_*_0_.xml.gz")
        print(f"Found {len(xml_files)} in grid_level {gl}")
        for gi, xml_file in enumerate(xml_files):
            xmldoc = utils.load_filename(
                xml_file, contenthandler=ligolw.LIGOLWContentHandler
            )
            new_tbl = lsctables.SnglInspiralTable.get_table(xmldoc)
            data_dict["filename"].append(xml_file)
            for row in new_tbl:
                data_dict["mass1"].append(row.mass1)
                data_dict["mass2"].append(row.mass2)
                data_dict["margll"].append(row.snr)
                data_dict["iteration_level"].append(i)
                data_dict["grid_id"].append(gi)
                # Read both spins before appending so spin1z/spin2z always
                # have equal length. The amrlib call stays outside the try
                # so its own errors are not mistaken for missing columns.
                try:
                    spin1z, spin2z = row.spin1z, row.spin2z
                except AttributeError:
                    print("No spin information found in SnglInspiralTable")
                    continue
                data_dict["spin1z"].append(spin1z)
                data_dict["spin2z"].append(spin2z)
                chi_eff, chi_a = amrlib.transform_s1zs2z_chi_eff_chi_a(
                    row.mass1, row.mass2, spin1z, spin2z
                )
                data_dict["chi_eff"].append(chi_eff)
                data_dict["chi_a"].append(chi_a)
    if not data_dict["spin1z"]:
        del data_dict["spin1z"]
        del data_dict["spin2z"]
    return {key: np.array(values) for key, values in data_dict.items()}


def find_sigma(grid_data, param_list, sigma_factor, grid_level=None):
    """
    Find the standard deviation of the Gaussian placed at each grid point.

    For every parameter, the width at a grid point is
    ``sigma_factor[param]`` times the distance, in that parameter alone, to
    the nearest other grid point. Points closer than 1e-5 are treated as
    the same location and ignored. With ``sigma_factor[param] = 0.5`` this
    is half the separation to the nearest neighbour.

    Parameters
    ----------
    grid_data : dict
        Grid as returned by ``get_grid_info``. Must contain every name in
        ``param_list``, plus ``"iteration_level"`` when ``grid_level`` is
        given.
    param_list : iterable of str
        Parameters to compute widths for.
    sigma_factor : dict of str -> float
        Multiplier applied to the nearest-neighbour distance per parameter.
    grid_level : int or None
        If given, only grid points of that iteration level are used, both as
        points and as neighbours.

    Returns
    -------
    dict of str -> numpy.ndarray
        One width per selected grid point, per parameter.

    Raises
    ------
    IndexError
        If a point has no neighbour further than 1e-5 away (e.g. a single
        grid point).
    """
    level_mask = None
    if grid_level is not None:
        # Select rows with a boolean mask on iteration_level. grid_id
        # restarts at 0 for each level, so indexing with it would pick
        # level-0 rows for every level > 0.
        level_mask = np.asarray(grid_data["iteration_level"]) == grid_level

    Sigma = {}
    for param in param_list:
        grid_param = np.asarray(grid_data[param])
        if level_mask is not None:
            grid_param = grid_param[level_mask]
        sigmas = np.empty(len(grid_param))
        for j, value in enumerate(grid_param):
            separations = np.abs(grid_param - value)
            separations = separations[separations > 1e-5]
            if separations.size == 0:
                # IndexError kept for compatibility with the previous
                # behaviour (indexing an empty array).
                raise IndexError(
                    f"grid point {j} has no distinct neighbour in {param}"
                )
            sigmas[j] = sigma_factor[param] * separations.min()
        Sigma[param] = sigmas

    return Sigma


def plot_grid(
    grid_data, param1, param2, plot_dir, event_info=None, grid_level=None
):
    """
    Scatter-plot the grid points in the (param1, param2) plane and save the
    figure as a PNG. Points are coloured by their marginalized
    log-likelihood.

    Parameters
    ----------
    grid_data : dict of numpy.ndarray
        Grid as returned by ``get_grid_info``. Must contain ``param1``,
        ``param2``, ``"margll"`` and ``"iteration_level"``.
    param1, param2 : str
        Keys of ``grid_data`` for the x and y axes (e.g. mass1, mass2,
        mchirp, eta, spin1z, spin2z, mu1, mu2, q, tau0, tau3, mtotal,
        whichever the grid provides).
    plot_dir : str
        Output directory.
    event_info : dict or None
        If given, the point ``(event_info[param1], event_info[param2])`` is
        marked with a red star.
    grid_level : int or None
        Iteration level to plot (0, 1, 2, ...). None plots all levels.

    The colour scale and axis limits always span the full grid, so plots of
    different levels are directly comparable.
    """
    logging.info(
        f"plotting grids for {param1} and {param2} on grid_level={grid_level}"
    )
    Margll = grid_data["margll"]
    if grid_level is not None:
        # Boolean mask on iteration_level. grid_id restarts at 0 for each
        # level and is not a valid index into the full arrays.
        level_mask = np.asarray(grid_data["iteration_level"]) == grid_level
        data1 = grid_data[param1][level_mask]
        data2 = grid_data[param2][level_mask]
        weight = Margll[level_mask]
    else:
        data1 = grid_data[param1]
        data2 = grid_data[param2]
        weight = Margll
    plt.figure()
    plt.scatter(
        data1,
        data2,
        c=weight,
        vmin=np.min(Margll),
        vmax=np.max(Margll),
    )
    if event_info is not None:
        plt.plot(event_info[param1], event_info[param2], "r*")
    plt.xlabel(f"{param1}_d")
    plt.ylabel(f"{param2}_d")
    plt.xlim(
        np.min(grid_data[param1]),
        np.max(grid_data[param1]),
    )
    plt.ylim(
        np.min(grid_data[param2]),
        np.max(grid_data[param2]),
    )
    if grid_level is not None:
        plt.title("grid_level = " + str(grid_level))
    else:
        plt.title("all grids")
    plt.colorbar(label=r"$log(L_{marg})$")
    if grid_level is not None:
        filename = (
            f"{plot_dir}/grid_{param1}"
            f"_{param2}_iteration-{str(grid_level)}.png"
        )
    else:
        filename = f"{plot_dir}/grid_{param1}_{param2}_all.png"
    plt.savefig(filename)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close()
    return


def plot_posterior(
    sample_dict, param, plot_dir, event_info=None, grid_level=None
):
    """
    Plot a 1D histogram of the samples of ``param`` and save it as a PNG.

    Parameters
    ----------
    sample_dict : dict
        Samples keyed by parameter name. For parameters other than
        ``mass1``/``mass2`` it must also contain ``"prior"``, which is
        normalised and used as histogram weights.
    param : str
        Parameter to plot.
    plot_dir : str
        Output directory.
    event_info : dict or None
        If given, a red vertical line is drawn at ``event_info[param]``.
    grid_level : int or None
        Only used for the title and the output filename.
    """
    print(f"plotting posterior for {param} at grid_level={grid_level}")
    samples = sample_dict[param]

    if param in ["mass1", "mass2"]:
        prior = None
    else:
        prior = np.asarray(sample_dict["prior"], dtype=float)
        prior = prior / prior.sum()
    fig, ax = plt.subplots()
    ax.hist(
        samples,
        bins=50,
        weights=prior,
        histtype="step",
        density=True,
        color="g",
    )
    if event_info is not None:
        ax.axvline(x=event_info[param], color="red")
    ax.set_xlabel(f"{param}_d")
    ax.set_ylabel("posterior")
    ax.yaxis.set_ticks([])
    if grid_level is not None:
        ax.set_title("grid_level = " + str(grid_level))
        # Underscore between "detframe" and the parameter name, matching
        # the "_all" filename below.
        filename = (
            f"{plot_dir}/posterior_detframe_"
            f"{param}_iteration-{str(grid_level)}.png"
        )
    else:
        ax.set_title("all grids")
        filename = f"{plot_dir}/posterior_detframe_{param}_all.png"
    fig.savefig(filename)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
    return


def plot_2d_posterior_with_grid(
    sample_dict,
    grid_data,
    distance_coordinates_str,
    plot_dir,
    grid_level=None,
    event_info=None,
):
    """
    Overlay the grid points, coloured by ln(L_marg), with a 2D histogram of
    the posterior samples for two parameters, and save the figure as a PNG.

    Parameters
    ----------
    sample_dict : dict
        Posterior samples. Must contain both parameter names and
        ``"prior"``, which is used as histogram weights.
    grid_data : dict of numpy.ndarray
        Grid as returned by ``get_grid_info``. Must contain both parameter
        names, ``"margll"`` and ``"iteration_level"``.
    distance_coordinates_str : str
        Two parameter names joined by ``"_"``, e.g. ``"mass1_mass2"``.
        NOTE(review): the string is split on every ``"_"`` and only the
        first two parts are used, so names containing underscores (e.g.
        ``chi_eff``) are not supported.
    plot_dir : str
        Output directory.
    grid_level : int or None
        Iteration level of grid points to show. None shows all levels.
    event_info : dict or None
        If given, the true point is marked with a red star.
    """
    distance_coordinates = distance_coordinates_str.split("_")
    param1_name = distance_coordinates[0]
    param2_name = distance_coordinates[1]
    all_weights = grid_data["margll"]
    if grid_level is not None:
        # Boolean mask on iteration_level. grid_id restarts at 0 for each
        # level and is not a valid index into the full arrays.
        level_mask = np.asarray(grid_data["iteration_level"]) == grid_level
        data1 = grid_data[param1_name][level_mask]
        data2 = grid_data[param2_name][level_mask]
        weight = all_weights[level_mask]
    else:
        data1 = grid_data[param1_name]
        data2 = grid_data[param2_name]
        weight = all_weights
    plt.figure()
    plt.scatter(
        data1,
        data2,
        c=weight,
        vmin=np.min(all_weights),
        vmax=np.max(all_weights),
    )
    if event_info is not None:
        plt.plot(event_info[param1_name], event_info[param2_name], "r*")
    plt.xlabel(f"{param1_name}_d")
    plt.ylabel(f"{param2_name}_d")
    plt.colorbar(label=r"$ln(L_{marg})$")

    samples1 = sample_dict[param1_name]
    samples2 = sample_dict[param2_name]
    prior = sample_dict["prior"]
    plt.hist2d(samples1, samples2, bins=50, weights=prior, density=True)
    if grid_level is not None:
        # Concatenate into one label; a second positional argument would be
        # interpreted by plt.title as `fontdict`.
        plt.title("grid_level = " + str(grid_level))
        filename = (
            f"{plot_dir}/{param1_name}_{param2_name}"
            f"_iteration-{str(grid_level)}.png"
        )
    else:
        plt.title("all grids")
        filename = f"{plot_dir}/{param1_name}_{param2_name}_all.png"
    plt.savefig(filename)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close()
    return


def save_m1m2_posterior_samples(sample_dict, save_dir):
    """
    Write the detector-frame mass samples and their grid ids to HDF5.

    The file ``<save_dir>/intrinsic_posterior_samples_detframe.h5`` is
    created or overwritten. It holds the datasets ``mass1_d``, ``mass2_d``
    and ``grid_id``, taken from ``sample_dict["mass1"]``,
    ``sample_dict["mass2"]`` and ``sample_dict["grid_id"]``.
    """
    print("saving posterior samples for m1-m2")
    filename = f"{save_dir}/intrinsic_posterior_samples_detframe.h5"
    # The context manager closes the file even if a dataset write fails.
    with h5py.File(filename, "w") as f:
        f.create_dataset("mass1_d", data=sample_dict["mass1"])
        f.create_dataset("mass2_d", data=sample_dict["mass2"])
        f.create_dataset("grid_id", data=sample_dict["grid_id"])
    return
