#!/usr/bin/env python3
# -----------------------------------------------------------------------------
#   Copyright (C): OpenGATE Collaboration
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

import numpy as np
import gatetools.phsp as phsp
import gatetools as gt
import click
import os
import logging

# Module-level logger used by the command below for its error messages.
logger = logging.getLogger(__name__)

# Settings handed to click.command: expose '-h' as an alias of '--help'.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('input_filename')
@click.option('-n', default=float(-1), help='Use -1 to read all data')
@click.option('--output', '-o', default='AUTO', help='If AUTO, use input_filename.npy')
@click.option('--rm_keys', '-r', help='Remove the given keys (as a str list such that "X Y Z")', default='')
@click.option('--keys', '-k', help='Set the kept keys (as a str list such that "X Y Z")', default='')
@click.option('--mod_key', '-m', type=(str,str), multiple=True, help='Modify the name of a key, for example: --mod_key X Y    will change the key X to Y')
@click.option('--overwrite', '-f', is_flag=True, default=False, help='Overwrite file if already exist')
@click.option('--shuffle', '-s', is_flag=True, default=False, help='Shuffle the (output) data')
@gt.add_options(gt.common_options)
def gt_phsp_convert(input_filename, output, keys, rm_keys, n, overwrite, shuffle, mod_key, **kwargs):
    """
    \b
    Convert to npy

    \b
    <INPUT_FILENAME> : input PHSP root file
    """
    # NOTE: the docstring above is what click prints for --help, so the
    # maintainer-oriented description is kept in comments instead.
    #
    # Parameters (all filled in by click from the command line):
    #   input_filename : path of the phase-space file to read
    #   output         : destination path, or 'AUTO' to reuse the input
    #                    path with its extension replaced by '.npy'
    #   keys           : space separated key names to keep ('' = not set)
    #   rm_keys        : space separated key names to drop ('' = not set)
    #   n              : forwarded to phsp.load as nmax
    #   overwrite      : allow replacing an existing output file
    #   shuffle        : shuffle the data in place before writing
    #   mod_key        : sequence of (old_name, new_name) key renames
    #   kwargs         : common gatetools options, forwarded to
    #                    gt.logging_conf
    #
    # Any invalid input is logged and terminates the process with exit
    # status 1 so that calling scripts can detect the failure.

    # logger
    gt.logging_conf(**kwargs)

    if output == 'AUTO':
        base, _ = os.path.splitext(input_filename)
        output = base + '.npy'

    if not overwrite and os.path.isfile(output):
        logger.error('Error output file already exist: %s', output)
        raise SystemExit(1)

    if os.path.isdir(output):
        logger.error('Error output already exist (its a dir): %s', output)
        raise SystemExit(1)

    _, extension = os.path.splitext(output)
    if extension != '.npy':
        logger.error('Error output extension must by .npy: %s', output)
        raise SystemExit(1)

    keys = phsp.str_keys_to_array_keys(keys)
    rm_keys = phsp.str_keys_to_array_keys(rm_keys)
    if len(keys) > 0 and len(rm_keys) > 0:
        logger.error('Cannot provide both --rm_keys and --keys')
        raise SystemExit(1)

    # load data and keys ; the third returned value (described upstream as
    # the total nb of values) is not needed here
    data, read_keys, _ = phsp.load(input_filename, treename='PhaseSpace', nmax=n)

    # remove or keep keys if needed. The branches are mutually exclusive:
    # remove_keys already returns the reduced data together with its key
    # list, so no further selection must be applied on top of it.
    if len(rm_keys) > 0:
        data, keys = phsp.remove_keys(data, read_keys, rm_keys)
    elif len(keys) > 0:
        data = phsp.select_keys(data, read_keys, keys)
    else:
        keys = read_keys

    # rename keys ; an unknown key is reported like the other input errors
    # instead of surfacing as a ValueError traceback from list.index
    for old_name, new_name in mod_key:
        if old_name not in keys:
            logger.error('Error the key %s does not exist in %s', old_name, keys)
            raise SystemExit(1)
        keys[keys.index(old_name)] = new_name

    # shuffle
    if shuffle:
        np.random.shuffle(data)

    # write
    phsp.save_npy(output, data, keys)


# --------------------------------------------------------------------------
# Script entry point: invoke the click command (click parses sys.argv itself)
if __name__ == '__main__':
    gt_phsp_convert()
