#!/usr/bin/env python
import argparse
import pymongo
import hug
import strax
import straxen
import os
import socket
from datetime import datetime, timedelta
import pytz
import time
import threading

# Module-wide strax context used by the request handlers. It is only declared
# here; load_context() assigns the actual instance.
st: strax.Context
# Maximum amount of data (in MB) get_data is willing to load for one request.
# Overwritten by the --max_load_mb command line argument when run as a script.
max_load_mb = 1000
# Maximum binary size (in MB) of the array get_data will convert to json.
# Overwritten by the --max_return_mb command line argument when run as a script.
max_return_mb = 10


def load_context(name, extra_dirs=None):
    """
    Build the global strax context ``st`` used to serve data.

    The context is created by calling the factory called ``name`` in
    ``straxen.contexts``. Every storage frontend of the context is switched
    to read-only, and the context is configured so that it never creates
    new data.

    :param name: name of a context factory in ``straxen.contexts``.
    :param extra_dirs: optional iterable of directories that are added to
        the context as additional read-only ``strax.DataDirectory`` storage.
    :raises IOError: if one of ``extra_dirs`` cannot be read.
    """
    global st
    st = getattr(straxen.contexts, name)()
    for sf in st.storage:
        sf.readonly = True

    if extra_dirs is not None:
        for d in extra_dirs:
            # The directory is only ever attached read-only, so read access
            # is all that is needed; requiring write access would reject
            # perfectly usable read-only mounts.
            if not os.access(d, os.R_OK):
                raise IOError(f'No read access to {d}. Check mountpoints ')
            st.storage.append(strax.DataDirectory(d, readonly=True))

    st.context_config['allow_incomplete'] = True
    st.context_config['forbid_creation_of'] = (
            '*',
            # In case someone has an old strax that doesn't support * here:
            'raw_records', 'records', 'peaklets',
            'peaks', 'event_basics', 'event_info')


@hug.exception(Exception)
def handle_exception(exception):
    """
    Turn an exception into the response body for the client.

    :param exception: the exception that was raised.
    :return: dict with the exception's message stored under 'error'.
    """
    message = str(exception)
    return dict(error=message)


def _estimate_load_bytes(chunks, time_range, max_n):
    """
    Estimate how much data must be read to obtain ``max_n`` entries.

    Walks over the chunk metadata in order, summing the size of every chunk
    that overlaps ``time_range``, and stops as soon as the chunks that lie
    completely inside the range hold more than ``max_n`` entries.

    :param chunks: list of chunk metadata dicts with the keys 'start',
        'end', 'nbytes' and 'n'.
    :param time_range: (start, end) of the requested range.
    :param max_n: number of entries that is requested (positive integer).
    :return: tuple (bytes to load, end of the range that has to be loaded).
        The returned end equals ``time_range[1]`` unless the walk stopped
        early, in which case it is the end of the last chunk needed.
    """
    range_start, range_end = time_range
    load_bytes = 0
    load_at_least_n = 0
    for chunk_info in chunks:
        if (chunk_info['start'] > range_end
                or chunk_info['end'] < range_start):
            # None of this chunk will be loaded
            continue
        load_bytes += chunk_info['nbytes']

        # Inclusive bounds: a chunk that starts/ends exactly on the range
        # boundary is still loaded completely. With strict comparisons the
        # first and last chunk of a full-run range would never be counted.
        if (chunk_info['start'] >= range_start
                and chunk_info['end'] <= range_end):
            # All of this chunk will be loaded
            load_at_least_n += chunk_info['n']

        if load_at_least_n > max_n:
            # No point loading any more, we already exceeded max_n
            return load_bytes, chunk_info['end']
    return load_bytes, range_end


@hug.get('/get_data')
def get_data(
        run_id: hug.types.text,
        target: hug.types.text,
        max_n: hug.types.number = 1000,
        start: hug.types.float_number = None,
        end: hug.types.float_number = None,
        selection_str: hug.types.text = None):
    """
    Return (part of) the data of one run as a list of dicts, one per entry.

    :param run_id: run identifier. Purely numeric ids that are not six
        characters long are zero-padded to six digits.
    :param target: name of the data type to load.
    :param max_n: maximum number of entries to return. Zero or a negative
        number means no limit.
    :param start: start of the range to load, as an offset to the estimated
        start of the run. The offset is multiplied by 1e9 before it is
        added (presumably seconds converted to ns). Ignored unless ``end``
        is given as well.
    :param end: end of the range to load, same convention as ``start``.
    :param selection_str: selection passed on to ``st.get_array``.
    :return: list with one {field name: value} dict per entry, or a plain
        string if the context has not been loaded.
    :raises ValueError: if the data has no chunks yet, if loading would
        take more than ``max_load_mb``, or if the result is larger than
        ``max_return_mb``.
    """
    try:
        st
    except NameError:
        return "Context not loaded???"

    if len(run_id) != 6 and run_id.isdecimal():
        run_id = '%06d' % int(run_id)

    if max_n <= 0:
        max_n = None

    if start is None or end is None:
        time_range = None
    else:
        # The run start is only needed to turn the offsets into times
        t0, _ = st.estimate_run_start_and_end(run_id, target)
        time_range = [t0 + int(start * int(1e9)),
                      t0 + int(end * int(1e9))]

    md = st.get_metadata(run_id, target)   # don't catch exception, already nice
    if not md['chunks']:
        raise ValueError("No chunks available -- either the first chunk has "
                         "yet to be written, or something is wrong.")

    if time_range is None:
        # Time range is the full run. Find the range corresponding
        # to that (as best we can). There is at least one chunk here.
        time_range = [md['chunks'][0]['start'],
                      md['chunks'][-1]['end']]

    if not max_n:
        # No limit on the number of entries: assume (conservatively) that
        # we have to load all the data of the run
        load_bytes = st.size_mb(run_id, target) * int(1e6)
    else:
        # Let's see how much we really need to load, given max_n.
        load_bytes, time_range[1] = _estimate_load_bytes(
            md['chunks'], time_range, max_n)

    if load_bytes > max_load_mb * int(1e6):
        raise ValueError(f"Cannot load {target} for run {run_id}: it would "
                         f"take more than {max_load_mb} MB RAM to load.")

    x = st.get_array(run_id, target,
                     time_range=time_range,
                     selection_str=selection_str)
    if max_n:
        x = x[:max_n]
    size_mb = x.nbytes / int(1e6)
    if size_mb > max_return_mb:
        raise ValueError(f"Not converting {target} for run {run_id} "
                         f"to json since the binary data is {size_mb:.1f} MB. "
                         f"Try lowering max_n.")

    return [dict(zip(row.dtype.names, row)) for row in x]


def now(plus=0):
    """
    Return the current timezone-aware UTC time, shifted by an offset.

    :param plus: number of seconds to add to the current time.
    :return: datetime in UTC.
    """
    current = datetime.now(pytz.utc)
    offset = timedelta(seconds=plus)
    return current + offset


def _set_state():
    """
    Inform the bootstrax collection we're running microstrax

    Connects to the 'daq' database and then loops forever: every iteration
    a document describing this process (host, pid, time, limits, context
    and number of extra directories) is inserted into the 'eb_monitor'
    collection, followed by a 9 second sleep. Database errors are printed
    and the update is retried in the next iteration.

    Reads the parsed command line arguments from the module-level ``args``.
    This function never returns; run it in a (daemon) thread.
    """
    # DAQ database
    daq_db_name = 'daq'
    daq_uri = straxen.get_mongo_uri(header='rundb_admin',
                                    user_key='mongo_daq_username',
                                    pwd_key='mongo_daq_password',
                                    url_key='mongo_daq_url')
    daq_client = pymongo.MongoClient(daq_uri)
    daq_db = daq_client[daq_db_name]
    bs_coll = daq_db['eb_monitor']
    while True:
        microstrax_state = dict(
            host=socket.getfqdn(),
            pid=os.getpid(),
            time=now(),
            state='hosting microstrax',
            max_load_mb=args.max_load_mb,
            max_return_mb=args.max_return_mb,
            context=args.context,
            # extra_dirs is None when --extra_dirs is not specified
            n_dirs=len(args.extra_dirs or [])
        )
        try:
            daq_db.command('ping')
            # Only write when the database responded. The insert is inside
            # the try so a failure cannot terminate the update thread.
            bs_coll.insert_one(microstrax_state)
        except Exception as timeout_error:
            # Best effort: report the problem and try again after the nap
            print(f"Ran into {timeout_error}. Can be bad. Let's take a nap")
        time.sleep(9)


def set_state():
    """
    Open thread to set the state of microstrax

    Starts ``_set_state`` in a background thread named 'Update_db'. The
    thread is a daemon thread, so it does not keep the process alive once
    the main thread exits.
    """
    # Pass daemon=True to the constructor; Thread.setDaemon is deprecated.
    update_thread = threading.Thread(name='Update_db',
                                     target=_set_state,
                                     daemon=True)
    update_thread.start()


if __name__ == "__main__":
    # Command line interface
    cli = argparse.ArgumentParser(
        description='Start a microservice to return strax data as json')
    cli.add_argument(
        '--context',
        default='xenonnt_online',
        help='Name of straxen context to use')
    cli.add_argument(
        '--port',
        default=8000,
        type=int,
        help='HTTP port to serve on')
    cli.add_argument(
        '--max_load_mb',
        default=1000,
        type=int)
    cli.add_argument(
        '--max_return_mb',
        default=10,
        type=int)
    cli.add_argument(
        '--extra_dirs',
        nargs='*',
        help="Extra directories to look for data")
    # Keep the name ``args``: the state-update thread reads it as a global
    args = cli.parse_args()

    # Replace the module-level default limits by the requested ones
    max_load_mb, max_return_mb = args.max_load_mb, args.max_return_mb

    # Prepare the strax context, then start reporting our state to the
    # database in the background
    load_context(args.context, extra_dirs=args.extra_dirs)
    set_state()

    # Serve forever; if the server stops because of a NotMasterError,
    # wait a minute and start serving again
    while True:
        try:
            hug.API(__name__).http.serve(port=args.port)
        except pymongo.errors.NotMasterError:
            print('Ran into the infamous pymongo.errors.NotMasterError. Sleep for a sec')
            time.sleep(60)
