#!/bin/env python

import argparse
import distutils
import json
import os
import sys
from pathlib import Path
from typing import List

import bilby
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger

from pygwb.baseline import Baseline
from pygwb.detector import Interferometer
from pygwb.parameters import Parameters, ParametersHelp


def help_arguments(parent):
    """Build a parser exposing every ``Parameters`` field as an option.

    The returned parser inherits all arguments of ``parent`` and adds one
    ``--<name>`` option for each entry in ``Parameters.__annotations__``,
    with the help text taken from ``ParametersHelp[<name>].help``.

    Parameters
    ----------
    parent : argparse.ArgumentParser
        Parser whose arguments are inherited by the returned parser.

    Returns
    -------
    argparse.ArgumentParser
        Parser containing the parent's options plus one optional argument
        per annotated ``Parameters`` field.
    """
    annotations = getattr(Parameters, "__annotations__", {})
    full_parser = argparse.ArgumentParser(parents=[parent])
    for field_name, field_type in annotations.items():
        options = {"help": ParametersHelp[field_name].help, "required": False}
        if field_type == List:
            # Bare ``List`` annotations accept one or more string values.
            options["type"] = str
            options["nargs"] = '+'
        else:
            # Every other annotation is used directly as the converter.
            options["type"] = field_type
        full_parser.add_argument(f"--{field_name}", **options)
    return full_parser

# Spellings accepted for boolean command-line values (compared lower-cased).
_TRUTHY_STRINGS = frozenset({"y", "yes", "t", "true", "on", "1"})
_FALSY_STRINGS = frozenset({"n", "no", "f", "false", "off", "0"})


def _str_to_bool(value):
    """Convert a textual truth value such as ``"yes"`` or ``"0"`` to a bool.

    Stands in for ``distutils.util.strtobool``: ``distutils`` is deprecated
    (PEP 632) and a bare ``import distutils`` does not load its ``util``
    submodule, so the attribute lookup could fail.

    Raises
    ------
    ValueError
        If ``value`` is not one of the recognised spellings.
    """
    normalised = value.lower()
    if normalised in _TRUTHY_STRINGS:
        return True
    if normalised in _FALSY_STRINGS:
        return False
    raise ValueError(f"invalid truth value {value!r}")


def _flag_or_default(raw_value, default=True):
    """Return ``default`` when the option was not given, else parse it as a bool."""
    if not raw_value:
        return default
    return _str_to_bool(raw_value)


def _build_pipe_parser():
    """Create the parser for the options that control the script itself."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--param_file", help="Parameter file to use for analysis.", action="store", type=str, required=False
    )
    parser.add_argument(
        "--output_path", help="Location to save output to.", action="store", type=str, required=False
    )
    parser.add_argument(
        "--calc_pt_est", help="Calculate omega point estimate and sigma from data.", action="store", type=str, required=False
    )
    parser.add_argument(
        "--apply_dsc", help="Apply delta sigma cut when calculating final output.", action="store", type=str, required=False
    )
    parser.add_argument(
        "--pickle_out", help="Pickle output Baseline of the analysis.", action="store", type=str, required=False
    )
    parser.add_argument(
        "--wipe_ifo", help="Wipe interferometer data to reduce size of pickled Baseline.", action="store", type=str, required=False
    )
    return parser


def main():
    """Run the analysis for one pair of interferometers from the command line.

    Steps, in order:

    1. Parse the script options; every remaining argument is forwarded to
       ``Parameters.update_from_arguments`` (after the optional parameter
       file has been loaded).
    2. Save the final parameter set next to the other outputs.
    3. Build the two interferometers (optionally gating their data) and the
       baseline, then set its PSDs/CSDs and run the delta sigma cut.
    4. Either compute and save the point estimate, sigma, spectra and
       (optionally) a pickle of the baseline, or only save the averaged
       PSDs and CSD.

    The boolean options ``--calc_pt_est``, ``--apply_dsc``, ``--pickle_out``
    and ``--wipe_ifo`` take a textual truth value and default to ``True``.

    Raises
    ------
    ValueError
        If a boolean option has an unrecognised value, or if the frequency
        spacing of the baseline differs from ``params.frequency_resolution``.
    """
    params = Parameters()
    pipe_parser = _build_pipe_parser()

    # ``pipe_parser`` is created without ``-h``; this extended parser carries
    # the help option and lists the Parameters fields too. Its parse result
    # is intentionally discarded.
    help_args = help_arguments(pipe_parser)
    help_args.parse_known_args()  # for help

    script_args, parameters_args = pipe_parser.parse_known_args()

    if script_args.param_file:
        params.update_from_file(script_args.param_file)
    params.update_from_arguments(parameters_args)
    logger.info("Successfully imported parameters from paramfile and input.")

    if script_args.output_path:
        output_path = Path(script_args.output_path)
        # exist_ok avoids the check-then-create race, e.g. when several runs
        # share one output directory.
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path('')

    if script_args.param_file:
        param_file = Path(script_args.param_file)
        outfile_path = f"{output_path}/{param_file.stem}_final{param_file.suffix}"
    else:
        outfile_path = Path(f"{output_path}/parameters_final.ini")
    params.save_paramfile(str(outfile_path))
    logger.info(f"Saved final param file at {outfile_path}.")

    pipe_calculate_point_estimate = _flag_or_default(script_args.calc_pt_est)
    pipe_apply_dsc = _flag_or_default(script_args.apply_dsc)
    pipe_pickle_baseline = _flag_or_default(script_args.pickle_out)
    wipe_ifo = _flag_or_default(script_args.wipe_ifo)

    ifo_1 = Interferometer.from_parameters(params.interferometer_list[0], params)
    ifo_2 = Interferometer.from_parameters(params.interferometer_list[1], params)
    logger.info("Loaded up interferometers with selected parameters.")

    if params.gate_data:
        logger.info("Applying gates.")
        gate_params = {
            "gate_tzero": params.gate_tzero,
            "gate_tpad": params.gate_tpad,
            "gate_threshold": params.gate_threshold,
            "cluster_window": params.cluster_window,
            "gate_whiten": params.gate_whiten,
        }
        ifo_1.gate_data_apply(**gate_params)
        logger.info(f"Gates for IFO 1:{ifo_1.gates}")
        ifo_2.gate_data_apply(**gate_params)
        logger.info(f"Gates for IFO 2:{ifo_2.gates}")

    base_12 = Baseline.from_parameters(ifo_1, ifo_2, params)
    logger.info(
        f"Baseline with interferometers {ifo_1.name}, {ifo_2.name} initialised."
    )

    logger.info("Setting PSDs and CSDs of the baseline...")
    base_12.set_cross_and_power_spectral_density(params.frequency_resolution)
    base_12.set_average_power_spectral_densities()
    base_12.set_average_cross_spectral_density()

    logger.info("... done.")

    # Check nothing has gone wrong in the frequency handling: the spacing of
    # the baseline frequencies must match the requested resolution.
    deltaF = base_12.frequencies[1] - base_12.frequencies[0]
    if abs(deltaF - params.frequency_resolution) > 1e-6:
        raise ValueError("Frequency resolution in PSD/CSD is different than requested.")

    base_12.calculate_delta_sigma_cut(
        delta_sigma_cut=params.delta_sigma_cut,
        alphas=params.alphas_delta_sigma_cut,
        fref=params.fref,
        flow=params.flow,
        fhigh=params.fhigh,
        return_naive_and_averaged_sigmas=True,
    )

    logger.info(
        f"times flagged by the delta sigma cut as badGPStimes:{base_12.badGPStimes}"
    )

    if pipe_calculate_point_estimate:
        logger.info("calculating the point estimate and sigma...")

        base_12.set_point_estimate_sigma(
            alpha=params.alpha,
            fref=params.fref,
            flow=params.flow,
            fhigh=params.fhigh,
            apply_dsc=pipe_apply_dsc,
        )

        logger.success(
            f"Ran stochastic search over times {int(params.t0)}-{int(params.tf)}"
        )
        logger.success(f"\tPOINT ESTIMATE: {base_12.point_estimate:e}")
        logger.success(f"\tSIGMA: {base_12.sigma:e}")

        data_file_name = f"{output_path}/point_estimate_sigma_{int(params.t0)}-{int(params.tf)}"

        logger.info(
            "Saving point_estimate and sigma spectrograms, spectra, and final values to file."
        )
        logger.info("Saving average psds and csd to file.")
        base_12.save_point_estimate_spectra(
            params.save_data_type,
            data_file_name,
        )
        csd_file_name = f"{output_path}/csds_psds_{int(params.t0)}-{int(params.tf)}"
        base_12.save_psds_csds(
            params.save_data_type,
            csd_file_name,
        )
        if pipe_pickle_baseline:
            logger.info("Pickling the baseline.")
            pickle_name = f"{output_path}/{base_12.name}_{int(params.t0)}-{int(params.tf)}.pickle"
            base_12.save_to_pickle(pickle_name, wipe=wipe_ifo)

    else:
        logger.info("Saving average psds and csd to file.")

        # Write under output_path like every other product of the script,
        # rather than into the current working directory.
        data_file_name = f"{output_path}/psds_csds_{int(params.t0)}-{int(params.tf)}"

        # NOTE(review): this branch calls npz_save_psds_csds directly instead
        # of save_psds_csds with params.save_data_type as the branch above
        # does -- confirm that this is intended.
        base_12.npz_save_psds_csds(
            data_file_name,
            base_12.frequencies,
            base_12.csd,
            base_12.interferometer_1.average_psd,
            base_12.interferometer_2.average_psd,
        )

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
