#!/usr/bin/env python
"""Process a single run with straxen
"""
import argparse
import datetime
import logging
import time
import os
import psutil
import json
import importlib
import sys


def parse_args():
    """Build the straxer command line interface and parse ``sys.argv``.

    :return: ``argparse.Namespace`` holding one attribute per option
        defined below (``run_id``, ``package``, ``context``, ``target``, ...).
    """
    parser = argparse.ArgumentParser(
        description='Process a single run with straxen',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'run_id',
        metavar='RUN_ID',
        type=str,
        help="ID of the run to process; usually the run name.")
    parser.add_argument(
        '--package',
        default='straxen',
        help="Where to load the context from (straxen/cutax/pema)")
    parser.add_argument(
        '--context',
        default='xenonnt_online',
        help="Name of context to use")
    parser.add_argument(
        '--target',
        default='event_info',
        nargs='*',
        help='Target final data type to produce. Can be a list for multicore mode.')
    # json.loads parses the argument itself, so these two options take a
    # JSON-formatted string rather than a path to a file.
    parser.add_argument(
        '--context_kwargs',
        type=json.loads,
        help='JSON string with keyword arguments to load the context with')
    parser.add_argument(
        '--register_from_file',
        type=str,
        help='do st.register_all from a specified file'
    )
    parser.add_argument(
        '--config_kwargs',
        type=json.loads,
        help='JSON string with config options to set the context to')
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help='Start processing at raw_records, regardless of what data is available. '
             'Saving will ONLY occur to ./strax_data! If you already have the target '
             'data in ./strax_data, you need to delete it there first.')
    parser.add_argument(
        '--max_messages',
        default=4, type=int,
        help=("Size of strax's internal mailbox buffers. "
              "Lower to reduce memory usage, at increasing risk of deadlocks."))
    parser.add_argument(
        '--timeout',
        default=None, type=int,
        help="Strax' internal mailbox timeout in seconds")
    parser.add_argument(
        '--workers',
        default=1, type=int,
        help=("Number of worker threads/processes. "
              "Strax will multithread (1/plugin) even if you set this to 1."))
    parser.add_argument(
        '--notlazy',
        action='store_true',
        help='Forbid lazy single-core processing. Not recommended.')
    parser.add_argument(
        '--multiprocess',
        action='store_true',
        help="Allow multiprocessing.")
    parser.add_argument(
        '--multi_target',
        action='store_true',
        help='Allow st.make to be called with multiple targets at once '
             '(otherwise loop over the target list)')
    parser.add_argument(
        '--shm',
        action='store_true',
        help="Allow passing data via /dev/shm when multiprocessing.")
    parser.add_argument(
        '--profile_to',
        default='',
        help="Filename to output profile information to. If omitted, "
             "no profiling will occur.")
    parser.add_argument(
        '--profile_ram',
        action='store_true',
        help="Use memory_profiler for a more accurate measurement of the "
             "peak RAM usage of the process.")
    parser.add_argument(
        '--diagnose_sorting',
        action='store_true',
        help="Diagnose sorting problems during processing")
    parser.add_argument(
        '--debug',
        action='store_true',
        help="Enable debug logging to stdout")
    parser.add_argument(
        '--build_lowlevel',
        action='store_true',
        help='Build low-level data even if the context forbids it.')
    parser.add_argument(
        '--only_strax_data',
        action='store_true',
        help='Only use ./strax_data (if not on dali).')
    parser.add_argument(
        '--add_folder',
        default='',
        help='Also add folder to st.storage')
    # Without an explicit type, a value given on the command line would stay
    # a str and break the time.sleep call of the keep-alive thread.
    parser.add_argument(
        '--print_alive',
        default=300, type=float,
        help='Print that straxer is still running every this many [seconds]')
    return parser.parse_args()


def setup_context(args):
    """Create and configure the strax context described by ``args``.

    The context is looked up by name (``args.context``) in
    ``<args.package>.contexts``. Afterwards the config options, plugin
    registration, context_config flags and storage frontends are adjusted
    according to the command line options.

    Exits the process with status 0 if the requested target is already
    stored for ``args.run_id``.

    :param args: Parsed command line arguments as returned by ``parse_args``.
    :return: The configured context.
    """
    # strax/straxen are not imported at module level, so import them here
    import strax
    import straxen

    context_module = importlib.import_module(f'{args.package}.contexts')
    context_factory = getattr(context_module, args.context)

    # Build the context exactly once; creating a default one first and
    # throwing it away is wasteful and fails for contexts that need kwargs.
    if args.context_kwargs:
        logging.info(f'set context kwargs {args.context_kwargs}')
        st = context_factory(**args.context_kwargs)
    else:
        st = context_factory()

    if args.config_kwargs:
        logging.info(f'set context options to {args.config_kwargs}')
        st.set_config(to_dict_tuple(args.config_kwargs))

    if args.register_from_file:
        register_to_context(st, args.register_from_file)

    if args.diagnose_sorting:
        st.set_config(dict(diagnose_sorting=True))

    st.context_config['allow_multiprocess'] = args.multiprocess
    st.context_config['allow_shm'] = args.shm
    st.context_config['allow_lazy'] = args.notlazy is not True

    if args.timeout is not None:
        st.context_config['timeout'] = args.timeout
    st.context_config['max_messages'] = args.max_messages

    if args.build_lowlevel:
        st.context_config['forbid_creation_of'] = tuple()
    else:
        st.context_config['forbid_creation_of'] = straxen.DAQReader.provides

    if args.from_scratch:
        # Existing frontends may only provide raw_records; everything else
        # is (re)written to ./strax_data
        for q in st.storage:
            q.take_only = ('raw_records',)
        st.storage.append(
            strax.DataDirectory('./strax_data',
                                overwrite='always',
                                provide_run_metadata=False))

    if args.only_strax_data:
        for sf in st.storage:
            # Set all others to read only
            sf.readonly = True
        # Add ./strax_data unless a frontend with that path is already there
        has_strax_data = any(
            getattr(sf, 'path', None) == './strax_data'
            for sf in st.storage)
        if not has_strax_data:
            st.storage += [
                strax.DataDirectory('./strax_data')]

    if args.add_folder != '':
        for sf in st.storage:
            # Set all others to read only
            sf.readonly = True
        if os.path.exists(args.add_folder):
            st.storage += [strax.DataDirectory(args.add_folder)]
        else:
            # Previously ignored silently, leaving only read-only storage
            logging.warning(
                f'Folder {args.add_folder} does not exist, not adding it '
                f'to st.storage')

    if st.is_stored(args.run_id, args.target):
        logging.warning("This data is already available. Straxer is done")
        sys.exit(0)
    return st


def main(args):
    """Process a single run up to the requested target(s), logging progress.

    Configures logging, builds the context via ``setup_context`` (which may
    exit the process if the data is already stored), infers the start and
    end time of the run, starts the keep-alive logging thread and finally
    iterates over ``st.get_iter`` for the target(s). For every chunk the
    position in the run, the RAM usage and an ETA are logged.

    :param args: Parsed command line arguments as returned by ``parse_args``.
    """
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(threadName)s - %(name)s - %(levelname)s - %(message)s')

    logging.info(f"Starting processing of run {args.run_id} until {args.target}")

    # These imports take a bit longer, so it's nicer
    # to do them after argparsing (so --help is fast)
    import strax
    import straxen
    import pandas as pd
    # For showing data availability below
    pd.options.display.max_rows = 999
    logging.info(straxen.print_versions(
        tuple({'strax', 'straxen', args.package}),
        return_string=True,
        print_output=False)
        ,)

    logging.info('Starting context')
    st = setup_context(args)
    logging.info('Context started')

    # Reactivate after https://github.com/XENONnT/straxen/issues/586
    logging.info('Checking availabilty')
    logging.info(f'Available\n{str(st.available_for_run(args.run_id))}')

    logging.info('Infer start/end')
    try:
        md = st.run_metadata(args.run_id)
        t_start = md['start'].replace(tzinfo=datetime.timezone.utc).timestamp()
        t_end = md['end'].replace(tzinfo=datetime.timezone.utc).timestamp()
        # NOTE(review): unlike t_start above, no UTC tzinfo is attached here,
        # so a naive datetime would be interpreted as local time -- confirm
        # that this is intended.
        st.config['run_start_time'] = md['start'].timestamp()
        st.context_config['free_options'] = tuple(
            list(st.context_config['free_options']) + ['run_start_time'])
    except strax.RunMetadataNotAvailable:
        # No run metadata: estimate from the first stored low-level data
        # type. If none is stored, t is left at the last entry ('peaklets').
        for t in ('raw_records', 'records', 'peaklets'):
            if st.is_stored(args.run_id, t):
                break
        t_start, t_end = st.estimate_run_start_and_end(args.run_id, t)
        # ns -> s
        t_start, t_end = t_start/1e9, t_end/1e9
    run_duration = t_end - t_start
    logging.info(f'Infer start/end: run is {run_duration:.1f} s')

    start_alive_thread(logging, args.print_alive)

    process = psutil.Process(os.getpid())
    peak_ram = 0

    def run_make(targets):
        """For target or a list of targets, run st.make and print the progress"""
        nonlocal peak_ram
        if not args.multi_target:
            logging.info(f'Checking if {targets} is stored')
            if st.is_stored(args.run_id, targets):
                logging.info(f'{args.run_id}:{targets} is stored')
                return
        logging.info(f'Start processing {args.run_id}:{targets}')

        def get_results():
            """Yield the chunks of st.get_iter, optionally under the profiler"""
            kwargs = dict(
                run_id=args.run_id,
                targets=targets,
                max_workers=int(args.workers),
                allow_multiple=args.multi_target,
                progress_bar=False,
                save=strax.to_str_tuple(targets),
            )

            if args.profile_to:
                with strax.profile_threaded(args.profile_to):
                    yield from st.get_iter(**kwargs)
            else:
                yield from st.get_iter(**kwargs)

        # Wall-clock time at which the first non-empty chunk arrived
        clock_start = None
        for i, d in enumerate(get_results()):
            mem_mb = process.memory_info().rss / 1e6
            peak_ram = max(mem_mb, peak_ram)

            if not len(d):
                logging.info(f"Got chunk {i}, but it is empty! Using {mem_mb:.1f} MB RAM.")
                continue

            # Compute detector/data time left
            time_end = d.end / 1e9
            dt = time_end - t_start
            time_left = t_end - time_end

            # Avoid a ZeroDivisionError for a zero-length run
            percent_done = 100 * dt / run_duration if run_duration else float('nan')
            msg = (f"Got {len(d)} items. "
                   f"Now {dt:.1f} sec / {percent_done:.1f}% into the run. "
                   f"Using {mem_mb:.1f} MB RAM. ")
            if clock_start is None:
                clock_start = time.time()
            else:
                # Compute processing job clock time left
                d_clock = time.time() - clock_start
                # Skip the ETA when no processing rate can be computed yet
                if dt and d_clock:
                    clock_time_left = time_left / (dt / d_clock)
                    msg += f"ETA {clock_time_left:.2f} sec."
            logging.info(msg)

        if clock_start is None:
            # No non-empty chunk was seen, so there is no start time to
            # report a duration from (this used to raise a TypeError).
            logging.info(f"{targets} finished! "
                         f"No non-empty chunks were produced.")
        else:
            logging.info(f"{targets} finished! "
                         f"Took {time.time() - clock_start:.1f} s")

    proc_start = time.time()
    if args.multi_target:
        run_make(args.target)
    else:
        for target in strax.to_str_tuple(args.target):
            run_make(target)
    logging.info(f"Straxer is done in {time.time() - proc_start:.1f} s! "
                 f"peak RAM usage was ~{peak_ram:.1f} MB.")


def register_to_context(st, module: str):
    """Import a python file and register everything in it to the context.

    The folder containing the file is added to ``sys.path``, the file is
    imported as a module and passed to ``st.register_all``. Afterwards the
    full plugin class registry of the context is logged.

    :param st: Context to register to.
    :param module: Path to the ``.py`` file to register from.
    :raises FileNotFoundError: If ``module`` does not exist.
    :raises ValueError: If ``module`` is not a ``.py`` file.
    """
    if not os.path.exists(module):
        raise FileNotFoundError(f'No such file {module}')
    # Explicit raise instead of assert: asserts are stripped under python -O
    if not module.endswith('.py'):
        raise ValueError("only py files please!")

    folder, file = os.path.split(module)
    # Make the file importable by name, without piling up duplicate entries
    if folder not in sys.path:
        sys.path.append(folder)
    to_register = importlib.import_module(os.path.splitext(file)[0])
    st.register_all(to_register)
    logging.info(f'Successfully registered {file}. Printing plugins')

    for key, plugin in st._plugin_class_registry.items():
        logging.info(f'{key}\t{plugin}')


def to_dict_tuple(res: dict):
    """Convert list configs to tuple configs.

    Lists are converted recursively, so arbitrarily deep nested lists end
    up as nested tuples. Non-list values are kept as they are and the input
    dictionary itself is not modified.

    :param res: Dictionary of config options (e.g. as parsed from JSON).
    :return: Copy of ``res`` in which every list value is a tuple.
    """
    def _as_tuple(value):
        """Return value with all (nested) lists replaced by tuples"""
        if isinstance(value, list):
            return tuple(_as_tuple(item) for item in value)
        return value

    converted = res.copy()
    for key, value in res.items():
        if isinstance(value, list):
            converted[key] = _as_tuple(value)
    return converted


def print_is_alive(log, print_timeout):
    """Log forever, at a fixed interval, that straxer is still running.

    This function never returns; it is meant to be run in a daemon thread.

    :param log: Object with an ``info`` method used to emit the message.
    :param print_timeout: Seconds to sleep between two messages. May be a
        number or a numeric string (as obtained from the command line).
    :raises ValueError: If ``print_timeout`` cannot be converted to a float.
    """
    # Convert once up front: a str interval would make time.sleep raise a
    # TypeError inside the thread after the first message.
    interval = float(print_timeout)
    while True:
        log.info('Straxer still running')
        time.sleep(interval)


def start_alive_thread(log, print_timeout):
    """Start a daemon thread running ``print_is_alive``.

    :param log: Object with an ``info`` method, handed to ``print_is_alive``
        and also used to announce the start of the thread.
    :param print_timeout: Interval handed on to ``print_is_alive``.
    """
    import threading

    # daemon=True so that the thread does not keep the process alive
    pinger = threading.Thread(
        target=print_is_alive,
        name='Ping alive',
        args=(log, print_timeout),
        daemon=True,
    )
    log.info('Starting thread to ping that we are still running')
    pinger.start()


if __name__ == '__main__':
    args = parse_args()

    # Plain run: exit with whatever main returns
    if not args.profile_ram:
        sys.exit(main(args))

    # Profiled run: let memory_profiler execute main and sample its RAM usage
    from memory_profiler import memory_usage

    mem = memory_usage(proc=(main, [args], {}))
    print(f"Memory profiler says peak RAM usage was: {max(mem):.1f} MB")
    sys.exit()
