#!/usr/bin/env python3
# -----------------------------------------------------------------------------
#   Copyright (C): OpenGATE Collaboration
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

import sys
import numpy as np
import gatetools.phsp as phsp
import gatetools as gt
import click
import os
from matplotlib import pyplot as plt
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Passed to click.command() so that '-h' works as an alias of '--help'.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('filenames', nargs=-1)
@click.option('-n', default=float(1e5), help='Use -1 to read all data')
@click.option('--keys', '-k', help='Plot the given keys (as a str list such that "X Y Z")', default='')
@click.option('--quantile', '-q', default=float(0), help='Restrict histogram to quantile')
@click.option('--nb_bins', '-b', default=int(100), help='Number of bins')
@click.option('--shuffle', '-s', default=False, is_flag=True, help='shuffle samples when loading')
@click.option('--plot2d',
              type=(str, str),
              help='Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ', multiple=True)
@gt.add_options(gt.common_options)
def gt_phsp_plot(filenames, keys, n, quantile, nb_bins, plot2d, shuffle, **kwargs):
    '''
    \b
    Plot histograms

    \b
    <INPUT_FILENAME> : input PHSP root/npy files

    WARNING: if several filenames, they must have the same keys
    '''
    # NOTE: the docstring above is the text click shows for --help, so it is
    # kept short and user-oriented; the '\b' markers must stay real backspace
    # characters (i.e. the docstring must NOT become a raw string).
    #
    # Developer summary:
    #   filenames : tuple of input files; all of them are drawn on the same
    #               figure, one sub-plot per key, overlaid file after file.
    #   keys      : space separated key names; empty means "all keys read".
    #   n         : number of samples handed to phsp.load (-1 = everything).
    #   quantile  : q in [0, 0.5); histogram range is [q, 1-q] quantiles.
    #   nb_bins   : number of bins for 1D and 2D histograms.
    #   plot2d    : tuple of (key1, key2) pairs, one extra sub-plot each.
    #   shuffle   : forwarded to phsp.load.
    #   kwargs    : common gatetools options, forwarded to gt.logging_conf.

    # logger
    gt.logging_conf(**kwargs)

    fig = None
    nb_fig = 0
    # key -> (low, high) histogram range, shared by every file so that the
    # overlaid histograms of a given key use identical bins.
    ranges = {}

    # click delivers an empty tuple when --plot2d is not used; the fallback
    # only guards against a direct call with plot2d=None.
    keys_2d = list(plot2d) if plot2d else []

    for filename in filenames:
        logger.info(filename)

        # load data; 'm' is only used for the figure title below
        # (presumably the total number of entries in the file - TODO confirm)
        data, read_keys, m = phsp.load(filename, n, shuffle)

        # keys to plot: the requested ones, or everything that was read
        ckeys = phsp.str_keys_to_array_keys(keys)
        if len(ckeys) == 0:
            ckeys = read_keys

        # augment data if needed
        if 'angleXY' in keys:
            data, read_keys = phsp.add_angle(data, read_keys, 'X', 'Y')

        # compute fig row/col
        nb_fig = len(ckeys) + len(keys_2d)
        nrow, ncol = phsp.fig_get_nb_row_col(nb_fig)

        # the figure is created once, from the layout of the first file
        if fig is None:
            fig, ax = plt.subplots(nrow, ncol, figsize=(25, 10))

        # 1D histograms
        for i, k in enumerate(ckeys):
            a = phsp.fig_get_sub_fig(ax, i)
            x = data[:, read_keys.index(k)]
            # The range is fixed by the first data set that provides the key
            # (not by comparing file names, which breaks when the same file
            # is listed twice and reloaded with a different shuffle).
            if k not in ranges:
                ranges[k] = (np.quantile(x, quantile),
                             np.quantile(x, 1.0 - quantile))
            # raw string: '\m' and '\s' are not valid Python escapes
            label = r' {} $\mu$={:.2f} $\sigma$={:.2f}'.format(k, np.mean(x), np.std(x))
            a.hist(x, nb_bins,
                   histtype='stepfilled',
                   range=ranges[k],
                   alpha=0.5,
                   label=label)
            a.set_ylabel('Counts')
            a.legend()

        # 2D histograms are placed right after the 1D ones
        for i, k in enumerate(keys_2d, start=len(ckeys)):
            a = phsp.fig_get_sub_fig(ax, i)
            phsp.fig_histo2D(a, data, read_keys, k, nb_bins, 'g')

    # nothing was loaded or nothing to draw: no figure to show
    if nb_fig == 0:
        return

    # remove empty plot
    phsp.fig_rm_empty_plot(nb_fig, ax)

    # title reports the values of the last file that was loaded
    if n == -1:
        n = m
    plt.suptitle('{} Values: {}/{}'.format(filenames, n, m))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# --------------------------------------------------------------------------
# Script entry point: click parses sys.argv and invokes the command.
if __name__ == "__main__":
    gt_phsp_plot()
