#!/usr/bin/env python
"""Process a single run with straxen
"""
import argparse
import datetime
import logging
import time
import os
import os.path as osp
import platform
import psutil
import sys


def parse_args(argv=None):
    """Parse the command-line options of this script.

    :param argv: Optional list of argument strings to parse instead of
        ``sys.argv[1:]``. The default (None) keeps the original behaviour
        of reading the process command line.
    :return: The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Process a single run with straxen',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'run_id',
        metavar='RUN_ID',
        type=str,
        help="ID of the run to process; usually the run name.")
    parser.add_argument(
        '--context',
        default='xenon1t_dali',
        help="Name of straxen context to use")
    parser.add_argument(
        '--target',
        default='event_info',
        help='Target final data type to produce')
    # NB: adjacent string literals are concatenated verbatim, so every
    # fragment but the last must end with a space.
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help='Start processing at raw_records, regardless of what data is available. '
             'Saving will ONLY occur to ./strax_data! If you already have the target '
             'data in ./strax_data, you need to delete it there first.')
    parser.add_argument(
        '--max_messages',
        default=4, type=int,
        help=("Size of strax's internal mailbox buffers. "
              "Lower to reduce memory usage, at increasing risk of deadlocks."))
    parser.add_argument(
        '--timeout',
        default=None, type=int,
        help="Strax' internal mailbox timeout in seconds")
    parser.add_argument(
        '--workers',
        default=1, type=int,
        help=("Number of worker threads/processes. "
              "Strax will multithread (1/plugin) even if you set this to 1."))
    parser.add_argument(
        '--notlazy',
        action='store_true',
        help='Forbid lazy single-core processing. Not recommended.')
    parser.add_argument(
        '--multiprocess',
        action='store_true',
        help="Allow multiprocessing.")
    parser.add_argument(
        '--shm',
        action='store_true',
        help="Allow passing data via /dev/shm when multiprocessing.")
    parser.add_argument(
        '--profile_to',
        default='',
        help="Filename to output profile information to. If omitted, "
             "no profiling will occur.")
    parser.add_argument(
        '--profile_ram',
        action='store_true',
        help="Use memory_profiler for a more accurate measurement of the "
             "peak RAM usage of the process.")
    parser.add_argument(
        '--diagnose_sorting',
        action='store_true',
        help="Diagnose sorting problems during processing")
    parser.add_argument(
        '--debug',
        action='store_true',
        help="Enable debug logging to stdout")
    parser.add_argument(
        '--build_lowlevel',
        action='store_true',
        help='Build low-level data even if the context forbids it.')
    parser.add_argument(
        '--only_strax_data',
        action='store_true',
        help='Only use ./strax_data (if not on dali).')
    return parser.parse_args(argv)


def main(args):
    """Process one run with straxen, printing progress for every chunk.

    Builds the straxen context named by ``args.context``, applies the
    command-line options to it, and iterates over ``args.target`` for
    ``args.run_id``. For each chunk the resident memory of this process is
    sampled and a progress line is printed.

    :param args: Parsed command-line options, as returned by parse_args().
    :return: 1 if the target data is already stored (nothing is processed);
        None once processing has finished.
    """
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(threadName)s - %(name)s - %(levelname)s - %(message)s')

    print(f"Starting processing of run {args.run_id} until {args.target}")
    print(f"\tpython {platform.python_version()} at {sys.executable}")

    # These imports take a bit longer, so it's nicer
    # to do them after argparsing (so --help is fast)
    import strax
    print(f"\tstrax {strax.__version__} at {osp.dirname(strax.__file__)}")
    import straxen
    print(f"\tstraxen {straxen.__version__} at {osp.dirname(straxen.__file__)}")

    # Look up the context factory by name and apply the command-line options
    st = getattr(straxen.contexts, args.context)()
    if args.diagnose_sorting:
        st.set_config(dict(diagnose_sorting=True))
    st.context_config['allow_multiprocess'] = args.multiprocess
    st.context_config['allow_shm'] = args.shm
    st.context_config['allow_lazy'] = not args.notlazy
    if args.timeout is not None:
        st.context_config['timeout'] = args.timeout
    st.context_config['max_messages'] = args.max_messages
    if args.build_lowlevel:
        # Empty tuple: no data type is forbidden from being created
        st.context_config['forbid_creation_of'] = tuple()

    if args.from_scratch:
        # Existing storage may only supply raw_records; everything else is
        # written to (and overwritten in) ./strax_data.
        for q in st.storage:
            q.take_only = ('raw_records',)
        st.storage.append(
            strax.DataDirectory('./strax_data',
                                overwrite='always',
                                provide_run_metadata=False))
    if args.only_strax_data:
        # Replaces all storage, including anything set up for --from_scratch
        st.storage = [
            strax.DataDirectory('./strax_data',
                                # overwrite='always',
                                provide_run_metadata=True)]
    if st.is_stored(args.run_id, args.target):
        print("This data is already available.")
        return 1

    # Run boundaries as POSIX seconds, with the metadata datetimes read as UTC
    md = st.run_metadata(args.run_id)
    t_start = md['start'].replace(tzinfo=datetime.timezone.utc).timestamp()
    t_end = md['end'].replace(tzinfo=datetime.timezone.utc).timestamp()
    run_duration = t_end - t_start

    # NOTE(review): unlike t_start above, this calls .timestamp() without
    # forcing UTC; if md['start'] is a naive datetime it is interpreted in
    # the local timezone. Confirm which value downstream code expects.
    st.config['run_start_time'] = md['start'].timestamp()
    st.context_config['free_options'] = tuple(
        list(st.context_config['free_options']) + ['run_start_time'])

    process = psutil.Process(os.getpid())
    peak_ram = 0

    def get_results():
        """Yield the target's chunks, profiled if --profile_to was given."""
        kwargs = dict(
            run_id=args.run_id,
            targets=args.target,
            max_workers=int(args.workers))

        if args.profile_to:
            with strax.profile_threaded(args.profile_to):
                yield from st.get_iter(**kwargs)
        else:
            yield from st.get_iter(**kwargs)

    # Fallback reference for the final report, used only when no non-empty
    # chunk ever arrives (clock_start would otherwise still be None there).
    loop_start = time.time()
    # Wall-clock time at which the first non-empty chunk arrived
    clock_start = None
    for i, d in enumerate(get_results()):
        mem_mb = process.memory_info().rss / 1e6
        peak_ram = max(mem_mb, peak_ram)

        if not len(d):
            print(f"Got chunk {i}, but it is empty! Using {mem_mb:.1f} MB RAM.")
            continue

        # Compute detector/data time left
        # (d.end is presumably in ns; it is compared against POSIX seconds)
        t = d.end / 1e9
        dt = t - t_start
        time_left = t_end - t

        msg = (f"Got {len(d)} items. "
               f"Now {dt:.1f} sec / {100 * dt / run_duration:.1f}% into the run. "
               f"Using {mem_mb:.1f} MB RAM. ")
        if clock_start is not None:
            # Compute processing job clock time left
            d_clock = time.time() - clock_start
            # Skip the ETA rather than divide by zero when no data time or
            # no wall-clock time has elapsed yet.
            if dt != 0 and d_clock != 0:
                clock_time_left = time_left / (dt / d_clock)
                msg += f"ETA {clock_time_left:.2f} sec."
        else:
            clock_start = time.time()

        print(msg, flush=True)

    elapsed = time.time() - (clock_start if clock_start is not None else loop_start)
    print(f"\nStraxer is done! "
          f"We took {elapsed:.1f} seconds, "
          f"peak RAM usage was around {peak_ram:.1f} MB.")


if __name__ == '__main__':
    # Script entry point: parse the command line, then run main() either
    # directly or under memory_profiler when --profile_ram was given.
    args = parse_args()
    if not args.profile_ram:
        # Plain run: main()'s return value becomes the exit status
        sys.exit(main(args))

    # Imported here so memory_profiler is only needed with --profile_ram
    from memory_profiler import memory_usage

    usage = memory_usage(proc=(main, [args], {}))
    peak_usage = max(usage)
    print(f"Memory profiler says peak RAM usage was: {peak_usage:.1f} MB")
