#!/bin/env python

import argparse
import copy
import os
from os import listdir
from os.path import isfile, join
from pathlib import Path

import numpy as np
from loguru import logger
from tqdm import tqdm

from pygwb.baseline import Baseline
from pygwb.constants import h0 as pygwb_h0
from pygwb.omega_spectra import OmegaSpectrum, reweight_spectral_object
from pygwb.parameters import Parameters
from pygwb.postprocessing import (
    calc_Y_sigma_from_Yf_sigmaf,
    combine_spectra_with_sigma_weights,
)


def sorting_function(item):
    """Return the number embedded in a ``point_estimate_sigma_<start>-...`` name.

    The text following the first occurrence of ``point_estimate_sigma_`` is
    taken, and everything from the next ``-`` onwards is discarded; what
    remains is converted to ``numpy.float64``.

    Parameters
    ----------
    item : str
        File name or path containing ``point_estimate_sigma_``.

    Returns
    -------
    numpy.float64
        The numeric token between the marker and the following ``-``.

    Raises
    ------
    ValueError
        If the extracted token is not a valid float (this includes the case
        where the marker is absent and the token is therefore empty).
    """
    _, _, after_marker = item.partition("point_estimate_sigma_")
    start_token, _, _ = after_marker.partition("-")
    return np.float64(start_token)


"""
SCRIPT TO COMBINE PYGWB_PIPE RUN OUTPUTS.
Currently only works with npz - this will be updated for compatibility with other formats as we go along.
"""


def main():
    """Combine the per-job ``.npz`` outputs of a pygwb_pipe run.

    Command-line arguments
    ----------------------
    --data_path : directory holding the per-job files (names starting with
        ``point``). Required.
    --alpha : spectral index used for the final re-weighting. Given as a
        string that is evaluated as a Python expression. Required.
    --fref : reference frequency used for the final re-weighting.
    --param_file : parameter file loaded into ``Parameters``.
    --h0 : value of h0 for the final results (default ``pygwb.constants.h0``).
    --out_path : output directory (default ``./``).

    Files written to ``out_path``
    -----------------------------
    * ``point_estimate_sigma_<t0>-<t1>_UNWEIGHTED.npz`` - per-job and
      per-segment point estimates and sigmas.
    * ``delta_sigma_cut_<t0>-<t1>.npz`` - concatenated naive/slide/delta
      sigmas, bad GPS times and their times.
    * ``Y_spectrum_<t0>-<t1>.hdf5`` and ``sigma_spectrum_<t0>-<t1>.hdf5`` -
      sigma-weighted combined spectra as ``OmegaSpectrum`` objects.
    * ``point_estimate_sigma_spectra_alpha_<alpha>_fref_<fref>_<t0>-<t1>.npz``
      - re-weighted point estimate, sigma and spectra.

    Raises
    ------
    FileNotFoundError
        If ``data_path`` contains no file whose name starts with ``point``.
    """
    combine_parser = argparse.ArgumentParser()
    combine_parser.add_argument(
        "--data_path",
        help="Path to data files folder.",
        action="store",
        type=Path,
        required=True,
    )
    combine_parser.add_argument(
        "--alpha",
        help="Spectral index alpha to use for spectral re-weighting.",
        action="store",
        type=str,
        required=True,
    )
    combine_parser.add_argument(
        "--fref",
        help="Reference frequency to use when presenting results.",
        action="store",
        type=int,
    )
    combine_parser.add_argument(
        "--param_file", help="Parameter file", action="store", type=str
    )
    combine_parser.add_argument(
        "--h0",
        help="Value of h0 to use. Default is pygwb.constants.h0.",
        action="store",
        type=float,
        required=False,
    )
    combine_parser.add_argument(
        "--out_path", help="Output path.", action="store", type=Path, required=False
    )
    combine_args = combine_parser.parse_args()
    # Any falsy h0 (not given, or 0) falls back to the package default.
    if not combine_args.h0:
        combine_args.h0 = pygwb_h0
    if combine_args.out_path is None:
        combine_args.out_path = Path("./")
    # NOTE(security): ``eval`` executes arbitrary Python taken from the command
    # line. It is kept so that expressions such as "2/3" remain accepted, but
    # it must never be fed untrusted input.
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
    # exact type it used to alias.
    combine_args.alpha = float(eval(combine_args.alpha))
    params = Parameters()
    params.update_from_file(combine_args.param_file)

    files_ptest = [
        os.path.join(combine_args.data_path, f)
        for f in listdir(combine_args.data_path)
        if isfile(join(combine_args.data_path, f)) and f.startswith("point")
    ]
    if not files_ptest:
        # Fail with an explicit message instead of an IndexError further down.
        raise FileNotFoundError(
            f"No files starting with 'point' found in {combine_args.data_path}."
        )

    files_ptest.sort(key=sorting_function)
    times = [int(sorting_function(file_path)) for file_path in files_ptest]
    # The frequency array is read from the first file only.
    with np.load(files_ptest[0]) as first_file:
        frequencies = first_file["frequencies"]
    # The mask stored in the files is not applied; ``True`` is passed instead.
    #frequency_mask = np.load(files_ptest[0])["frequency_mask"]
    frequency_mask = True

    # spectral objects
    Y_j = []
    sigma_j = []
    Y_spectra_j = []
    sigma_spectra_j = []
    Y_seg = []
    sigma_seg = []
    # DQ objects
    naive_sigmas_j = []
    slide_sigmas_j = []
    delta_sigmas_j = []
    badGPStimes_j = []
    times_j = []

    pt_est_sigma_unweighted_path = os.path.join(
        combine_args.out_path,
        f"point_estimate_sigma_{times[0]}-{times[-1]}_UNWEIGHTED.npz",
    )
    delta_sigma_cut_output_path = os.path.join(
        combine_args.out_path,
        f"delta_sigma_cut_{times[0]}-{times[-1]}.npz",
    )
    pt_est_sigma_spectra_path = os.path.join(
        combine_args.out_path,
        "point_estimate_sigma_spectra_alpha_{:.1f}".format(combine_args.alpha)
        + f"_fref_{combine_args.fref}_{times[0]}-{times[-1]}.npz",
    )

    logger.info('Unpacking files...')
    for file in tqdm(files_ptest):
        # Close each archive once its arrays are in memory, so that handles do
        # not pile up over a long run.
        with np.load(file, allow_pickle=True) as data_file:
            Y_j.append(data_file["point_estimate"])
            sigma_j.append(data_file["sigma"])
            Y_spectra_j.append(data_file["point_estimate_spectrum"])
            sigma_spectra_j.append(data_file["sigma_spectrum"])
            naive_sigmas_j.append(data_file["naive_sigma_values"].T)
            slide_sigmas_j.append(data_file["slide_sigma_values"].T)
            delta_sigmas_j.append(data_file["delta_sigma_values"].T)
            times_j.append(data_file["delta_sigma_times"])
            badGPStimes_j.append(data_file["badGPStimes"])

            Y_s, sigma_s = calc_Y_sigma_from_Yf_sigmaf(
                data_file["point_estimate_spectrogram"],
                data_file["sigma_spectrogram"],
                frequency_mask=frequency_mask,
            )
        Y_seg.append(Y_s)
        sigma_seg.append(sigma_s)

    Y_seg = np.concatenate(Y_seg)
    sigma_seg = np.concatenate(sigma_seg)
    Y_j = np.array(Y_j)
    sigma_j = np.array(sigma_j)
    np.savez(
        pt_est_sigma_unweighted_path,
        point_estimate=Y_j,
        sigma=sigma_j,
        point_estimate_per_seg=Y_seg,
        sigma_per_seg=sigma_seg,
    )
    logger.info(
        f"saved file with unweighted point estimate and sigma values for all times in run:\n {pt_est_sigma_unweighted_path}."
    )

    naive_sigmas_j = np.concatenate(naive_sigmas_j)
    slide_sigmas_j = np.concatenate(slide_sigmas_j)
    delta_sigmas_j = np.concatenate(delta_sigmas_j)
    times_j = np.concatenate(times_j)
    badGPStimes_j = np.concatenate(badGPStimes_j)
    np.savez(
        delta_sigma_cut_output_path,
        naive_sigmas=naive_sigmas_j,
        slide_sigmas=slide_sigmas_j,
        delta_sigmas=delta_sigmas_j,
        badGPStimes=badGPStimes_j,
        times=times_j,
    )
    # Report the delta-sigma file that was just written (the message used to
    # print the unweighted point-estimate path by mistake).
    logger.info(
        f"saved file with all sigma information related to the delta sigma cut for all times in run:\n {delta_sigma_cut_output_path}."
    )

    Y_spectrum_combined, sigma_spectrum_combined = combine_spectra_with_sigma_weights(
        np.array(Y_spectra_j), np.array(sigma_spectra_j)
    )

    # Combined spectra, labelled with the alpha/fref from the parameter file.
    Y_spectrum = OmegaSpectrum(
        Y_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="Y_spectrum",
        frequencies=frequencies,
    )
    sigma_spectrum = OmegaSpectrum(
        sigma_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="sigma_spectrum",
        frequencies=frequencies,
    )

    Y_spectrum.write(
        os.path.join(combine_args.out_path, f"Y_spectrum_{times[0]}-{times[-1]}.hdf5")
    )
    sigma_spectrum.write(
        os.path.join(
            combine_args.out_path, f"sigma_spectrum_{times[0]}-{times[-1]}.hdf5"
        )
    )
    logger.info(
        f"Saved file with combined point estimate and sigma OmegaSpectrum objects for this run. These are weighted with alpha={params.alpha}"
    )

    # Broadband estimate at the alpha/fref requested on the command line.
    Y_estimate, sigma_estimate = calc_Y_sigma_from_Yf_sigmaf(
        Y_spectrum,
        sigma_spectrum,
        frequency_mask=frequency_mask,
        alpha=combine_args.alpha,
        fref=combine_args.fref,
    )
    # Rescale from the h0 stored on the spectra to the requested h0.
    Y_estimate *= (Y_spectrum.h0 / combine_args.h0) ** 2
    sigma_estimate *= (sigma_spectrum.h0 / combine_args.h0) ** 2

    logger.info(
        "Final point estimate re-weighted with alpha={:.2f}".format(combine_args.alpha)
        + f" at reference frequency fref={combine_args.fref} with h0={combine_args.h0}:\n [{Y_estimate} +/- {sigma_estimate}]"
    )

    # Fresh spectrum objects are built for re-weighting, so the ones written
    # to disk above are left untouched.
    Y_reweight_spectrum = OmegaSpectrum(
        Y_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="Y_spectrum",
        frequencies=frequencies,
    )
    Y_reweight_spectrum.reweight(
        new_alpha=combine_args.alpha, new_fref=combine_args.fref
    )
    sigma_reweight_spectrum = OmegaSpectrum(
        sigma_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="sigma_spectrum",
        frequencies=frequencies,
    )
    sigma_reweight_spectrum.reweight(
        new_alpha=combine_args.alpha, new_fref=combine_args.fref
    )
    Y_reweight_spectrum.reset_h0(new_h0=combine_args.h0)
    sigma_reweight_spectrum.reset_h0(new_h0=combine_args.h0)

    np.savez(
        pt_est_sigma_spectra_path,
        point_estimate=Y_estimate,
        sigma_estimate=sigma_estimate,
        point_estimate_spectrum=Y_reweight_spectrum.value,
        sigma_spectrum=sigma_reweight_spectrum.value,
        point_estimates_seg_UW=Y_seg,
        sigmas_seg_UW=sigma_seg,
    )
    logger.info(
        f"Saved file with re-weighted point estimate and sigma values and spectra:\n {pt_est_sigma_spectra_path}."
    )
    # No explicit ``exit()``: returning normally ends the script with status 0
    # and does not rely on the ``site``-injected ``exit`` helper.


# Run the combine step only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
