""" Functions for calibrating the optimal collision energy setting
    for your experiment.
"""
import pandas as pd

from inspire.constants import (
    ACCESSION_STRATUM_KEY,
    BOLD_TEXT,
    ENDC_TEXT,
    ENGINE_SCORE_KEY,
    KNOWN_PTM_WEIGHTS,
    LABEL_KEY,
    OKBLUE_TEXT,
    OKCYAN_TEXT,
    PTM_ID_KEY,
    PTM_NAME_KEY,
    PTM_SEQ_KEY,
    SCAN_KEY,
    SOURCE_KEY,
    SPECTRAL_ANGLE_KEY,
    UNDERLINE_TEXT,
)
from inspire.input.mgf import process_mgf_file
from inspire.input.msp import msp_to_df
from inspire.input.mzml import process_mzml_file
from inspire.input.search_results import generic_read_df
from inspire.feature_creation import combine_spectral_data
from inspire.predict_spectra import predict_spectra
from inspire.prepare import write_prosit_input_df
from inspire.spectral_features import calculate_spectral_features
from inspire.utils import (
    check_bad_mods,
    get_ox_flag,
    get_cam_flag,
    remove_source_suffixes,
)

# Candidate collision energy settings tried during calibration:
# the integers 20 to 40 inclusive.
COLLISION_ENERGY_RANGE = list(range(20, 41))

def _get_top_hits(config):
    """ Function to select the best scoring search hits to use for collision
        energy calibration.

    Hits carrying any modification other than 'Oxidation (M)' or
    'Carbamidomethyl (C)' are discarded, as are hits outside accession
    stratum 0 (when that column is present). Of the remainder, only hits
    with label 1 and an engine score above the 95th percentile are kept,
    capped at the 1000 highest scoring.

    Parameters
    ----------
    config : inspire.config.Config
        The Config object for the experiment.

    Returns
    -------
    search_df : pd.DataFrame
        The selected hits (with an added 'unknownModifications' column).
    mods_df : pd.DataFrame
        The modifications DataFrame returned by generic_read_df.
    """
    search_df, mods_df = generic_read_df(config, save_dfs=False, overwrite_reduce=True)

    # Collect the IDs (as strings) of every modification that is neither
    # methionine oxidation nor cysteine carbamidomethylation.
    ptm_names = mods_df[PTM_NAME_KEY]
    unrecognised_df = mods_df[
        ~((ptm_names == 'Oxidation (M)') | (ptm_names == 'Carbamidomethyl (C)'))
    ]
    unknown_mods = {str(ptm_id) for ptm_id in unrecognised_df[PTM_ID_KEY].tolist()}

    # Flag and drop hits whose modification sequence uses one of those IDs.
    search_df['unknownModifications'] = search_df[PTM_SEQ_KEY].apply(
        check_bad_mods, args=(unknown_mods,)
    )
    search_df = search_df[~search_df['unknownModifications']]

    if ACCESSION_STRATUM_KEY in search_df.columns:
        in_first_stratum = search_df[ACCESSION_STRATUM_KEY] == 0
        search_df = search_df[in_first_stratum]

    # TODO: the label filter is applied after the cut-off is computed, so rows
    # whose label is not 1 still contribute to the 95th percentile (flagged
    # in the code as something to move earlier).
    score_cutoff = search_df[ENGINE_SCORE_KEY].quantile(0.95)
    above_cutoff = search_df[ENGINE_SCORE_KEY] > score_cutoff
    has_label_one = search_df[LABEL_KEY] == 1
    search_df = search_df[above_cutoff & has_label_one]

    if len(search_df) > 1000:
        search_df = search_df.nlargest(1000, columns=[ENGINE_SCORE_KEY])

    return search_df, mods_df

def prepare_calibration(config):
    """ Function to write the Prosit input needed for collision energy
        calibration, once per candidate collision energy.

    Parameters
    ----------
    config : inspire.config.Config
        The Config object for the experiment.

    Returns
    -------
    top_hits_df : pd.DataFrame
        The top scoring hits selected by _get_top_hits.
    mods_df : pd.DataFrame
        The modifications DataFrame returned alongside the hits.
    """
    top_hits_df, mods_df = _get_top_hits(config)

    # overwrite is True only for the first collision energy written
    # (presumably so later energies are added to the same output - confirm
    # against write_prosit_input_df).
    first_write = True
    for collision_energy in COLLISION_ENERGY_RANGE:
        write_prosit_input_df(
            top_hits_df,
            mods_df,
            config,
            collision_energy,
            'calibrationInput',
            overwrite=first_write,
        )
        first_write = False

    return top_hits_df, mods_df

def calibrate(config):
    """ Function to calibrate the optimal collision energy for Prosit input.

    The top scoring hits are selected and written as Prosit input for every
    candidate collision energy, predictions are requested via
    predict_spectra, and the predicted spectra are compared with the
    experimental scans. The mean spectral angle per collision energy is
    written to collisionEnergyStats.csv in the output folder and the
    collision energy with the highest mean is printed.

    Parameters
    ----------
    config : inspire.config.Config
        The Config object for the experiment.

    Raises
    ------
    ValueError
        If there are no scan files to read for the selected hits.
    """
    print(
        OKCYAN_TEXT +
        '\tSelecting top hits...' +
        ENDC_TEXT
    )
    target_df, mods_df = prepare_calibration(config)
    predict_spectra(config, 'calibrate')

    prosit_df = msp_to_df(
        f'{config.output_folder}/calibrationPredictions.msp', 'prosit', None,
    )

    ox_flag = get_ox_flag(mods_df)
    cam_flag = get_cam_flag(mods_df)

    # Map modification IDs to their mass weights (0 means unmodified).
    mods_dict = {
        0: 0.0,
        int(cam_flag): KNOWN_PTM_WEIGHTS['Carbamidomethyl (C)'],
        int(ox_flag): KNOWN_PTM_WEIGHTS['Oxidation (M)']
    }

    use_combined_scans = config.combined_scans_file is not None
    if use_combined_scans:
        scan_files = [remove_source_suffixes(config.combined_scans_file)]
    else:
        scan_files = target_df[SOURCE_KEY].unique().tolist()

    scan_dfs = []
    for scan_file in scan_files:
        # With a single combined scans file every hit is looked up in it,
        # otherwise only the hits belonging to this source file are.
        if use_combined_scans:
            file_hits_df = target_df
        else:
            file_hits_df = target_df[target_df[SOURCE_KEY] == scan_file]

        scans = set(file_hits_df[SCAN_KEY].unique().tolist())
        scan_path = f'{config.scans_folder}/{scan_file}.{config.scans_format}'
        if config.scans_format == 'mzML':
            scan_df = process_mzml_file(scan_path, scans)
        else:
            scan_df = process_mgf_file(
                scan_path,
                scans,
                config.scan_title_format,
                config.source_files,
                combined_source_file=use_combined_scans,
            )
        scan_dfs.append(scan_df.drop_duplicates(subset=[SOURCE_KEY, SCAN_KEY]))

    if not scan_dfs:
        raise ValueError(
            'No scan files found for the hits selected for collision energy calibration.'
        )

    combined_scan_df = pd.concat(scan_dfs)
    print(
        OKCYAN_TEXT +
        '\t\tCombining all spectral data...' +
        ENDC_TEXT
    )
    # Combine against the full set of selected hits: the per-file subset from
    # the loop above only covers the last scan file processed.
    combined_df = combine_spectral_data(
        target_df,
        combined_scan_df,
        prosit_df,
        ox_flag,
        'prosit',
    )
    print(
        OKCYAN_TEXT +
        '\t\tCalculating spectral angles...' +
        ENDC_TEXT
    )
    combined_df = combined_df.apply(
        lambda x : calculate_spectral_features(
            x,
            mods_dict,
            config.mz_accuracy,
            config.mz_units,
            None,
            '1',
            config.delta_method,
            config.spectral_predictor,
            minimal_features=True,
        ),
        axis=1
    )
    results_df = combined_df.groupby('collisionEnergy', as_index=False)[
        SPECTRAL_ANGLE_KEY
    ].mean()
    results_df.columns = ['collisionEnergy', SPECTRAL_ANGLE_KEY]

    # idxmax returns an index label, so look the row up by label with .loc.
    best_row = results_df[SPECTRAL_ANGLE_KEY].idxmax()
    optimal_collision_energy = results_df.loc[best_row, 'collisionEnergy']
    print(
        OKBLUE_TEXT + BOLD_TEXT + UNDERLINE_TEXT +
        f'\n\t---> Optimal Collision Energy Setting: {optimal_collision_energy} <---\n' +
        ENDC_TEXT
    )
    results_df.to_csv(
        f'{config.output_folder}/collisionEnergyStats.csv',
        index=False
    )
