#! /usr/bin/env python3
"""
Create a plot for an energy spectrum for a certain peaking
time obtained with the CaenN6725 digitizer
"""
import matplotlib
matplotlib.use('Agg')

import uproot as up
import dashi as d
import numpy as np
import hepbasestack as hep
import pylab as p
import HErmes as he
import hepbasestack.layout as lo
import concurrent.futures as fut
import os
import os.path
import re
import time

import tqdm

from dactylos.analysis.waveform_analysis import WaveformAnalysis
from dactylos.analysis.noisemodel import noise_model, fit_noisemodel, extract_parameters_from_noisemodel   
from dactylos.analysis.spectrum_fits import fit_file
from dactylos.analysis.utils import get_stripname, plot_waveform 
# call ImageMagick with subprocess
# FIXME: there are python bindings

import shlex
import subprocess


from dataclasses import dataclass
from copy import deepcopy as copy

@dataclass
class PeakingTimeRun:
    """
    Result of a single spectrum fit for one channel at one peaking time.

    In this script the instances are filled from the values returned
    by fit_file. All fields default to a placeholder (-1 or the empty
    string), meaning "not set".
    """
    # shaping/peaking time of the run (the script logs the peaking
    # sequence in nanoseconds)
    ptime    : int = -1
    # digitizer channel
    channel  : int = -1
    # detector id
    detid    : int = -1
    # fitted fwhm (the script stores params[1] of the fit result here)
    fwhm     : float = -1
    # error on the fwhm (params_err['fwhm0'] of the fit result)
    fwhm_err : float = -1
    # name of the plot file produced by the fit; the script joins it
    # with the output directory. The default is the empty string, so
    # that no matplotlib figure is created as a side effect of merely
    # defining the class.
    figname  : str = ''


# global plot setup - these calls act on module import
# use the "present" plot style provided by hepbasestack
hep.visual.set_style_present()
# NOTE(review): presumably enables the plotting extensions of dashi - confirm
d.visual()

# module wide logger; 20 is presumably the loglevel (INFO) - confirm
logger = hep.logger.get_logger(20)

# switch off the interactive mode of matplotlib
p.ioff()

########################################################################



if __name__ == '__main__':
    # Command line driver. Depending on the given switches, the energy
    # spectra are obtained from
    #   * text files written by MC2/windows (--mc2),
    #   * the energy in dactylos root files, one file per
    #     shaping time (default), or
    #   * the waveforms in dactylos root files (--waveform-analysis).
    # Each spectrum is fitted with fit_file, then the noise model is
    # fitted per channel, the individual plots are combined with
    # ImageMagick and a summary text file is written to the outdir.

    import sys
    import argparse
    from glob import glob

    parser = argparse.ArgumentParser(description='Analyze data taken by N6725 digitizer.\n\
                                     \t\tThis script can do several things:\n\
                                     \t\t1) Analyze the histograms as saved by the MC2 windows\n\
                                     \t\t   software in textfiles (one line per bin)\n\
                                     \t\t2) Analyze dactylos data - root files with the\n\
                                     \t\t   energy from the digitizer internal shaping algorithm\n\
                                     \t\t3) Analyze the waveform data taken by dactylos - root files\n\
                                     \t\t   with waveform entries')
    parser.add_argument('--mc2',
                        dest='mc2',
                        default=False, action='store_true',\
                        help='Analyze data taken with mc2/windows.\
                              Currently, this relies on a certain file structure.')
    parser.add_argument('--trapezoid',
                        dest='trapezoid',
                        default=False, action='store_true',\
                        help='Use trapezoid for shaping instead of\
                              the Gaussian.'
                       )
    parser.add_argument('-i','--indir', dest='indir',
                        default=".",type=str,
                        help='glob all rootfiles in this directory for input"')
    parser.add_argument('-o','--outdir', dest='outdir',
                        default="firxraydata-outdir",type=str,
                        help='save plots in a directory named "outdir"')
    parser.add_argument('-j','--jobs', dest='njobs',
                        default=12,type=int,
                        help='number of jobs to use')
    parser.add_argument('--gaussian-shaping-times', nargs='+',\
                        help='Specify a sequence of peking times for the Gaussian\
                              shaper. If not given, a reasonable default is used',\
                        dest='gaussian_shaping_times',\
                        default=[],required=False)
    parser.add_argument('--channels', nargs='+',\
                        help='The sequence of channels to analyze. Default to 0,1,2,3,4,5,6,7',\
                        dest='channels',\
                        default=[0,1,2,3,4,5,6,7],required=False)
    parser.add_argument('--shaper-order', dest='shaper_order',
                        default=4,type=int,
                        help='order parameter of the Gaussian shaper. This parameter only has an effect together with --waveform-analysis.')
    parser.add_argument('--shaper-different-orders',
                        dest='shaper_different_orders',
                        default=False, action='store_true',\
                        help='Change the order for the shaper from 4 to 7.\
                              It is 4 for pt < 5ms and 7 otherwise.')
    parser.add_argument('--waveform-analysis', dest='waveform_analysis',
                        default=False, action='store_true',
                        help='Use the waveform information to generate the noisemodel plot.')
    parser.add_argument('--plot-waveforms', dest='plot_waveforms',
                        default=False, action='store_true',
                        help='plot the individual waveforms as png')

    args = parser.parse_args()

    # create the outdir, it is fine if it exists already
    try:
        os.makedirs(args.outdir, exist_ok=True)
    except OSError as e:
        logger.warning(f'Can not create directory {args.outdir}! {e}')

    if args.plot_waveforms and not args.waveform_analysis:
        raise ValueError("Can not plot waveforms without the --waveform-analysis switch!")

    # collect the input files. The filenames encode the
    # peaking time and the detector id (and the strip for mc2),
    # which get extracted with a regular expression
    if args.mc2:
        infiles = []
        for subfolder in glob(os.path.join(args.indir, '*')):
            infiles += glob(os.path.join(subfolder, '*.txt'))
        regex = r'pt_(?P<stime>[0-9]*)_(mus|mu_s|ns)\/Sh(?P<detid>[0-9]*)_(?P<strip>[A-H]).txt'
    else:
        infiles = glob(os.path.join(args.indir, '*.root'))
        regex = r'(?P<detid>[0-9]*)_(s|p)time(?P<stime>[0-9]*).root'
    pattern = re.compile(regex)

    # files which do not follow the naming scheme can
    # not be analyzed, so drop them right away
    file_metadata = dict()
    for infile in infiles:
        match = pattern.search(infile)
        if match is None:
            logger.warning(f'Can not extract metadata from filename {infile}, skipping it!')
            continue
        file_metadata[infile] = match.groupdict()
    infiles = list(file_metadata)
    if not infiles:
        sys.exit(f'No usable input files found in {args.indir}!')

    # FIXME: this is buggy in case mutliple detectors are mixed.
    # The noise model fit, the combined plots and the summary file
    # use the detector id of the last input file
    detid = int(file_metadata[infiles[-1]]['detid'])
    if len({meta['detid'] for meta in file_metadata.values()}) > 1:
        logger.warning(f'Found more than one detector id in the input files, will label the results with {detid}')

    # honor the requested number of jobs for the worker pool
    run_parallel = args.njobs > 1
    executor = fut.ProcessPoolExecutor(max_workers=args.njobs) if run_parallel else None
    future_to_ptime = dict()

    # store the results
    datasets = []

    # channels to analyze
    active_channels = [int(k) for k in args.channels]

    # mapping detector strips <-> digitizer channels
    ch_name_to_int = {'A' : 0, 'B' : 1, 'C' : 2, 'D' : 3,\
                      'E' : 4, 'F' : 5, 'G' : 6, 'H' :7}

    def _record_result(result, ptime):
        """Store the tuple returned by fit_file as PeakingTimeRun in datasets"""
        fit_detid, fit_ch, params, params_err, figname = result
        datasets.append(PeakingTimeRun(detid = fit_detid,\
                                       channel = fit_ch,\
                                       fwhm = params[1],\
                                       fwhm_err = params_err['fwhm0'],\
                                       ptime = ptime,\
                                       figname = figname))

    def _dispatch(kwargs, ptime):
        """Submit fit_file to the pool or run it directly for a single job"""
        if run_parallel:
            future_to_ptime[executor.submit(fit_file, **kwargs)] = ptime
        else:
            _record_result(fit_file(**kwargs), ptime)

    def _run_command(command):
        """Run an external command (list of arguments), warn if it fails"""
        logger.info(f"Executing {' '.join(command)}")
        try:
            proc = subprocess.run(command)
        except OSError as e:
            logger.warning(f'Can not execute {command[0]}! {e}')
            return
        if proc.returncode != 0:
            logger.warning(f'{command[0]} exited with return code {proc.returncode}')

    def _montage(images, tile, mplotname):
        """Combine images in one panel, saved as pdf, png and jpg"""
        for suffix in ('.pdf', '.png', '.jpg'):
            _run_command(['montage', '-geometry', '+4+4', '-tile', tile]
                         + list(images) + [mplotname + suffix])
        # shrink the jpg so it will go in google docs presentation
        _run_command(['convert', f'{mplotname}.jpg', '-resize', '50%',
                      '-normalize', f'{mplotname}.jpg'])

    if args.waveform_analysis:
        logger.info("Will do anlysis directly on the waveforms")
        if args.trapezoid:
            logger.warning("Will use trapezoid for shaping instead of Gaussian - experimental feature!")
        analysis = WaveformAnalysis(njobs=args.njobs,\
                                    active_channels=active_channels,\
                                    use_simple_trapezoid_shaper=args.trapezoid,\
                                    shaper_order=args.shaper_order,
                                    adjust_shaper_order_dynamically=args.shaper_different_orders)
        analysis.set_infiles(infiles)

        # in case of the waveform analysis, the peaking times
        # are not taken from the filenames, but from the
        # sequence given to the shaper
        analysis.read_waveforms()
        nevents = analysis.get_eventcounts()
        recordlengths = analysis.get_recordlengths()
        logger.info(f'Seeing recordlenghts of {recordlengths}')
        logger.info(f'Seeing nevents of {nevents}')
        for k in analysis.active_channels:
            logger.info(f" Record length in ch {k} ; {len(analysis.channel_data[k][0])}")

        if args.gaussian_shaping_times:
            peaking_sequence = np.array([int(k) for k in args.gaussian_shaping_times])
        else:
            # NOTE(review): 759 looks like a typo for 750 - confirm
            peaking_sequence = np.array([500,759,1000,2000,3000,4000,5000, 7000,8000,10000,12000,14000,16000,18000,20000,25000,30000])

        logger.info(f'Setting peaking sequence {peaking_sequence} nano seconds')
        analysis.set_peakingtime_sequence(peaking_sequence)

        for ch in analysis.active_channels:
            if args.plot_waveforms:
                analysis.plot_waveforms(ch, 10)
            channel_peakingtimes = analysis.analyze(ch)

            logger.info('Shaping completed!')
            logger.info('Submitting fitting jobs...')

            for ptime in channel_peakingtimes.keys():
                kwargs = {'energy'     : channel_peakingtimes[ptime],\
                          'channel'    : ch,\
                          'plot_dir'   : args.outdir,\
                          'ptime'      : ptime,\
                          'detid'      : detid,\
                          'file_type'  : None
                }
                _dispatch(kwargs, ptime)
    else:
        # in case we do the analysis on the energies as calculated
        # by the digitizer we will have one file per shaping time
        for infile in infiles:
            logger.info(f"Loading {infile}")
            metadata = file_metadata[infile]
            ptime = int(metadata['stime']) # shaping/peaking time
            file_detid = int(metadata['detid']) # detector id

            if args.mc2:
                if ptime < 500:
                    ptime *= 1000 # convert to ns
                # mc2 writes one textfile per strip
                file_channels = [ch_name_to_int[metadata['strip']]]
                file_type = '.txt'
            else:
                # data taken with dactylos, each root file
                # has all the 8 channels saved
                file_channels = active_channels
                file_type = '.root'

            for ch in file_channels:
                kwargs = {'infilename' : infile,\
                          'channel'    : ch,\
                          'plot_dir'   : args.outdir,\
                          'ptime'      : ptime,\
                          'detid'      : file_detid,\
                          'file_type'  : file_type
                }
                _dispatch(kwargs, ptime)

    # get the results in the parallel case
    if run_parallel:
        for future in tqdm.tqdm(fut.as_completed(future_to_ptime), total=len(future_to_ptime)):
            ptime = future_to_ptime[future]
            try:
                result = future.result()
            except Exception as e:
                logger.warning(f'Can not fit file for ptime {ptime}, exception {e}')
                continue
            _record_result(result, ptime)
        executor.shutdown()
        logger.info(f'In total, {len(datasets)} datasets were processed ')

    # prepare all the plots we want to combine in
    # one for a single peaking time or a single channel
    pktimes_savenames = {data.ptime : [] for data in datasets}
    ch_savenames = {ch : [] for ch in active_channels}

    # create the model fit plots
    for ch in active_channels:
        # sort from smallest to largest peaking time. The datasets
        # are sorted as a whole, so that each plot stays together
        # with its own peaking time
        ch_datasets = sorted((data for data in datasets if data.channel == ch),
                             key=lambda data: data.ptime)
        logger.info(f'{len(ch_datasets)} datasets available for channel {ch}')
        xs = np.array([data.ptime for data in ch_datasets])
        ys = np.array([data.fwhm for data in ch_datasets])
        ys_err = np.array([data.fwhm_err for data in ch_datasets])

        try:
            # NOTE(review): a former comment at this place said not to use
            # minuit until the errors are understood, however use_minuit=True
            # is what is passed - confirm which one is intended
            fit_noisemodel(xs, ys, ys_err, ch, detid, use_minuit=True, plotdir=args.outdir)
        except Exception as e:
            logger.warning(f"Can not fit noisemodel for {detid} ch {ch}, exception {e}")
            continue

        # remember the plots of the individual fits
        for data in ch_datasets:
            savename = os.path.join(args.outdir, data.figname)
            ch_savenames[ch].append(savename)
            pktimes_savenames[data.ptime].append(savename)

    # combine all the figures for a specific ch in
    # one panel with imagemagick
    for ch, savenames in ch_savenames.items():
        if not savenames:
            # nothing to combine for this channel
            continue
        _montage(savenames, '5x2',
                 os.path.join(args.outdir,f'det{detid}-{get_stripname(ch)}'))

    # also combine all the figures for a specific peaking time in
    # one panel with imagemagick
    for pktime, savenames in pktimes_savenames.items():
        if not savenames:
            continue
        _montage(savenames, '4x2',
                 os.path.join(args.outdir,f'det{detid}-pktime{pktime}'))

    # let's write a textfile as well
    # FIXME - this fails in case we have more than one
    # detector, in that case however the whole script will blow up
    with open(os.path.join(args.outdir, f'summary-{detid}.dat'), 'w') as summaryfile:
        summaryfile.write(f'# detector {detid}\n')
        summaryfile.write(f'# strip - peaking time  - fwhm - fwhm_err\n')
        for ds in datasets:
            summaryfile.write(f'{detid}\t{get_stripname(ds.channel)}\t{ds.ptime:4.2f}\t{ds.fwhm:4.2f}\t{ds.fwhm_err:4.2f}\n')

    print (f'Folder {args.indir} analyzed')
