#!/usr/bin/env python
"""
Bootstrax: XENONnT online processing manager
=============================================
First draft: Jelle Aalbers, 2018
With additional input from: Joran Angevaare, 2020

This script watches for new runs to appear from the DAQ, then starts a strax process to process them. If a run fails, it will retry it with exponential backoff.

You can run more than one bootstrax instance, but only one per machine. If you start a second one on the same machine, it will try to kill the first one.


Philosophy
-------------
Bootstrax has a crash-only / recovery-first philosophy. Any error in the core code causes a crash; there is no nice exit or mandatory cleanup. Bootstrax focuses on recovery after restarts: before starting work, we look for and fix any mess left by crashes.

This ensures that hangs and hard crashes do not require expert tinkering to repair databases. Plus, you can just stop the program with ctrl-c (or, in principle, pulling the machine's power plug) at any time.

Errors during run processing are assumed to be retry-able. We track the number of failures per run to decide how long to wait until we retry; only if a user marks a run as 'abandoned' (using an external system, e.g. the website) do we stop retrying.


Mongo documents
----------------
Bootstrax records its status in a document in the 'bootstrax' collection in the runs db. These documents contain:
  - **host**: socket.getfqdn()
  - **time**: last time this bootstrax showed lifesigns
  - **state**: one of the following:
    - **busy**: doing something
    - **idle**: NOT doing something; available for processing new runs

Additionally, bootstrax tracks information with each run in the 'bootstrax' field of the run doc. We could also put this elsewhere, but it seemed convenient. This field contains the following subfields:
  - **state**: one of the following:
    - **considering**: a bootstrax is deciding what to do with it
    - **busy**: a strax process is working on it
    - **failed**: something is wrong, but we will retry after some amount of time
    - **abandoned**: bootstrax will ignore this run
    - **done**: processing finished and passed the basic data quality checks
  - **reason**: reason for last failure, if there ever was one (otherwise this field does not exist). Thus, it's quite possible for this field to exist (and show an exception) when the state is 'done': that just means it failed at least once but succeeded later. Tracking failure history is primarily the DAQ log's responsibility; this message is only provided for convenience.
  - **n_failures**: number of failures on this run, if there ever was one (otherwise this field does not exist).
  - **next_retry**: time after which bootstrax might retry processing this run. Like 'reason', this will refer to the last failure.

Finally, bootstrax reports the load on the eventbuilder machine it is running on to the capped collection 'eb_monitor' in the DAQ database. Documents in this collection contain the following fields:
  - **run**: run number,
  - **host**: the name of the eventbuilder,
  - **time**: the time whereon the state was reported,
  - **pid**: the pid of this bootstrax (sub)process,
  - **cpu_pid**: cpu usage of the bootstrax subprocess (in percent),
  - **cpu_tot**: cpu usage on the eventbuilder (in percent),
  - **mem_pid**: memory used by this pid of bootstrax (in percent),
  - **mem_tot**: memory used on this eventbuilder (in percent),
  - **disk_size**: size of the disk whereto this bootstrax instance is writing to (in TB),
  - **disk_used**: used part of the disk whereto this bootstrax instance is writing to (in percent)
"""

import argparse
from collections import Counter
from datetime import datetime, timedelta
import logging
import multiprocessing
import npshmex
import os
import os.path as osp
import signal
import socket
import shutil
import time
import traceback
from tqdm import tqdm

import numpy as np
import pymongo
from psutil import pid_exists, Process, virtual_memory, cpu_percent, disk_usage
import pytz
import strax
import straxen

# Configure the root logger as early as possible, so that messages emitted
# while the rest of the module is being set up are already formatted.
logging.basicConfig(level=logging.INFO,
                    format='%(relativeCreated)6d %(threadName)s %(name)s %(message)s')

# Command line interface. The arguments are parsed at import time and the
# resulting 'args' namespace is read as a global throughout this file.
parser = argparse.ArgumentParser(
    description="XENONnT online processing manager")
parser.add_argument('--debug', action='store_true',
                    help="Start strax processes with debug logging.")
# Note that --profile is a string, not a flag: the literal 'false' disables it
parser.add_argument('--profile', type=str, default='false',
                    help="Option to run strax in profiling mode. "
                         "argument specifies the name of the profile if not 'false'. Use e.g. 'date'.prof")
parser.add_argument('--cores', type=int, default=25,
                    help="Maximum number of workers to use in a strax process. "
                         "Set to -1 for all available cores")
parser.add_argument('--target', default='event_info',
                    help="Strax data type name that should be produced")
parser.add_argument('--infer_run_mode', action='store_true',
                    help="Determine best number max-messages and cores for each run automatically. "
                         "Overrides --cores and --max_messages")
# TODO change to action
# Until then, this is a string that is interpreted further below
# ('y', 'yes' and 'true' enable deletion, anything else disables it).
parser.add_argument('--delete_live', type=str,
                    # TODO -> set this to yes after testing
                    default='no',
                    help="Delete live_data after succesful processing of the run [yes/no]")
parser.add_argument('--max_messages', type=int,
                    default=4,
                    help="number of max mailbox messages")

# One-off actions. These are mutually exclusive; if none is given, bootstrax
# runs its main loop.
actions = parser.add_mutually_exclusive_group()
actions.add_argument('--process', type=int, metavar='NUMBER',
                     help="Process a single run, regardless of its status.")
# --fail and --abandon take the run number first, optionally followed by a reason
actions.add_argument('--fail', nargs='+',
                     metavar=('NUMBER', 'REASON'),
                     help="Fail run number, optionally with reason")
actions.add_argument('--abandon', nargs='+',
                     metavar=('NUMBER', 'REASON'),
                     help="Abandon run number, optionally with reason")

args = parser.parse_args()

##
# Configuration
##

# Startup banner identifying this (test) version of the script
print(f'---\nTEST VERSION 0.3.0\n---')

# The event builders write to different directories on the respective machines.
# Keys are fully qualified hostnames; the entry of the current host is looked
# up when the output folder is selected.
eb_directories = {
    'eb0.xenon.local': '/data2/xenonnt_processed/',
    'eb1.xenon.local': '/data1/xenonnt_processed/',
    'eb2.xenon.local': '/nfs/eb0_data1/xenonnt_processed/',
    'eb3.xenon.local': '/data/xenonnt_processed/',
    'eb4.xenon.local': '/data/xenonnt_processed/',
    'eb5.xenon.local': '/data/xenonnt_processed/',
}

# Timeouts in seconds
timeouts = {
    # Waiting between escalating SIGTERM -> SIGKILL -> crashing bootstrax
    # when trying to kill another process (usually child strax)
    'signal_escalate': 3,
    # Minimum waiting time to retry a failed run
    # Escalates exponentially on repeated failures: 1x, 5x, 25x, 125x, 125x, 125x, ...
    # Some jitter is applied: actual delays will randomly be 0.5 - 1.5x as long
    'retry_run': 60,
    # Maximum time for strax to complete a processing
    # if exceeded, strax will be killed by bootstrax
    # TODO
    #  Do we want to have a higher number for e.g. calibration more runs?
    'max_processing_time': 7200,
    # Sleep between checking whether a strax process is alive
    'check_on_strax': 10,
    # Maximum time a run is 'busy' without a further update from
    # its responsible bootstrax. Bootstrax normally updates every
    # check_on_strax seconds, so make sure this is substantially
    # larger than check_on_strax.
    'max_busy_time': 120,
    # Maximum time a run is in the 'considering' state
    # if exceeded, will be labeled as an untracked failure
    'max_considering_time': 60,
    # Minimum time to wait between database cleanup operations
    'cleanup_spacing': 60,
    # Sleep time when there is nothing to do
    'idle_nap': 10,
    # If we don't hear from a bootstrax on another host for this long,
    # remove its entry from the bootstrax status collection
    # Must be much longer than idle_nap and check_on_strax!
    'bootstrax_presumed_dead': 300
}

# The disk that the eb is writing to may fill up at some point. The data should
# be written to datamanager at some point. This may clean up data on the disk,
# hence, we can check if there is sufficient diskspace and if not, wait a while.
# Below are the minimum free space required, the max number of times and the
# number of seconds bootstrax will wait.
wait_diskspace_min_space = 500  # GB
wait_diskspace_n_max = 100  # times
wait_diskspace_dt = 60  # seconds
# A single wait must be shorter than the presumed-dead timeout, since the
# heartbeat is only sent in between waits.
assert timeouts['bootstrax_presumed_dead'] > wait_diskspace_dt, "wait_diskspace_dt too large"

# Fields in the run docs that bootstrax uses
# (a list of field names, passed as the projection of run-db queries)
bootstrax_projection = f"name start end number bootstrax " \
                       f"data.host data.type data.location " \
                       f"ini.processing_threads ini.compressor ini.strax_fragment_payload_bytes " \
                       f"ini.strax_chunk_length ini.strax_chunk_overlap".split()

# Filename for temporary storage of the exception
# This is used to communicate the exception from the strax child process
# to the bootstrax main process
# The path is relative, so it is resolved against the current working directory.
exception_tempfile = 'last_bootstrax_exception.txt'

##
# Initialize globals (e.g. rundb connection)
##

# NOTE(review): logging.basicConfig is documented to do nothing when the root
# logger already has handlers. Since basicConfig is also called near the top
# of this file, this second call (and its format) probably has no effect
# -- confirm which of the two formats is the intended one.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(name)s %(levelname)-8s %(message)s',
    datefmt='%m-%d %H:%M')
# Root logger, used for all terminal logging in this file
log = logging.getLogger()
# Fully qualified hostname; used to select the output folder and to tag
# the documents bootstrax writes to the databases
hostname = socket.getfqdn()
# Mongo _id of this instance's document in the bootstrax status collection
state_doc_id = None  # Set in main loop

# Set the output folder
# (raises KeyError if this host is not listed in eb_directories)
output_folder = eb_directories[hostname]
if os.access(output_folder, os.W_OK) is not True:
    raise IOError(f'No writing access to {output_folder}')

# Translate the --delete_live string into a boolean.
# Only 'y', 'yes' and 'true' (case-insensitive) enable deletion.
if args.delete_live.lower() in ['y', 'yes', 'true']:
    delete_live = True
else:
    delete_live = False
    print(f'Not deleting live data because --delete_live = {args.delete_live}')


def new_context(cores=args.cores, max_messages=args.max_messages):
    """Build the straxen online context that bootstrax works with.

    The context comes straight from straxen's ``xenonnt_online``, so the
    runs-db access logic is not duplicated here: if bootstrax can reach the
    runs db through this context, so can strax.

    :param cores: number of cores strax may use. Multiprocessing and shared
        memory are only allowed when this is larger than one.
    :param max_messages: maximum number of mailbox messages, passed on to
        the context.
    :return: the strax context, writing to ``output_folder``.
    """
    use_multiple_cores = cores > 1
    context_options = dict(
        output_folder=output_folder,
        we_are_the_daq=True,
        allow_multiprocess=use_multiple_cores,
        allow_shm=use_multiple_cores,
        max_messages=max_messages)
    return straxen.contexts.xenonnt_online(**context_options)


# Module-level strax context. Besides reading back metadata after processing,
# it provides the mongo client for the runs database (see below).
st = new_context()

# DAQ database
# Credentials are read from the environment; a KeyError is raised here if
# either variable is not set.
daq_db_name = 'daq'
daq_db_password = os.environ['MONGO_DAQ_PASSWORD']
daq_db_username = os.environ['MONGO_DAQ_USERNAME']
daq_client = pymongo.MongoClient(f"mongodb://{daq_db_username}:{daq_db_password}@xenon1t-daq:27020,old-gw:27020/daq")
daq_db = daq_client[daq_db_name]
# Collection that receives the resource usage reports of this host
usage_coll = daq_db['eb_monitor']

# Runs database
run_dbname = 'run'
run_collname = 'runs'
# Reuse the mongo client of the context's first storage frontend
# NOTE(review): presumably storage[0] is the runs-db frontend -- confirm
run_db = st.storage[0].client[run_dbname]
run_coll = run_db[run_collname]    # the run documents
bs_coll = run_db['bootstrax']      # status documents of bootstrax instances
log_coll = run_db['log']           # DAQ log messages

# Ping the databases to ensure the mongo connections are working
run_db.command('ping')
daq_db.command('ping')


def main():
    """Entry point: perform the action selected on the command line.

    Exactly one of the following happens:
      - ``--fail NUMBER [REASON ...]``: mark the run as failed;
      - ``--abandon NUMBER [REASON ...]``: mark the run as abandoned (if a
        reason is given, it is first recorded through a manual failure);
      - ``--process NUMBER``: process a single run, regardless of its status;
      - otherwise: start the infinite main loop.

    :raises ValueError: if ``--process`` refers to a run that does not exist.
    """
    if args.cores == -1:
        # Use all of the available cores on this machine
        args.cores = multiprocessing.cpu_count()
        print(f'Set cores to n_tot, using {args.cores} cores')

    if args.fail:
        # Everything after the run number is the reason (empty if none given)
        manual_fail(number=int(args.fail[0]),
                    reason=' '.join(args.fail[1:]))

    elif args.abandon:
        number = int(args.abandon[0])
        if len(args.abandon) > 1:
            manual_fail(number=number, reason=' '.join(args.abandon[1:]))
        abandon(number=number)

    elif args.process is not None:
        # Compare against None explicitly: run number 0 is falsy, but valid
        t_start = now()
        number = args.process
        rd = consider_run({'number': number})
        if rd is None:
            raise ValueError(f"No run numbered {number} exists")
        process_run(rd)
        # total_seconds() includes days, unlike the .seconds attribute
        elapsed = int((now() - t_start).total_seconds())
        print(f'bootstrax ({hostname}) finished run {number} in {elapsed} seconds')
    else:
        # Start processing
        main_loop()


##
# Main loop
##

def main_loop():
    """Infinite loop looking for runs to process.

    On startup, any other bootstrax registered for this host is killed (if
    its process still exists) and its status document removed; then this
    instance registers itself in the bootstrax collection.

    Each iteration, in order of priority:
      1. wait until there is sufficient disk space;
      2. process a run that no bootstrax has touched yet;
      3. clean up the database (at most once per ``cleanup_spacing``);
      4. retry one failed run whose retry time has passed;
      5. otherwise, take a nap.

    Never returns; errors propagate and crash the program (by design).
    """
    global state_doc_id

    # Ensure we're the only bootstrax on this host
    any_running = list(bs_coll.find({'host': hostname}))
    for x in any_running:
        if pid_exists(x['pid']):
            log.warning(f'Bootstrax already running with PID {x["pid"]}, trying to kill it')
            kill_process(x['pid'])
        bs_coll.delete_one({'_id': x['_id']})

    # Register ourselves
    state_doc_id = bs_coll.insert_one(dict(
        host=hostname,
        started=now(),
        pid=os.getpid())).inserted_id
    set_state('starting')
    t_start = now()

    next_cleanup_time = now()
    while True:
        # total_seconds() keeps counting past 24 hours; the .seconds
        # attribute would wrap around to zero every day.
        uptime = int((now() - t_start).total_seconds())
        print(f'bootstrax running for {uptime} seconds')
        sufficient_diskspace()
        log.info("Looking for work")

        set_state('busy')

        # Process new runs
        rd = consider_run({"bootstrax": {"$exists": False}})
        if rd is not None:
            process_run(rd)
            continue

        # Scan DB for runs with unusual problems
        if now() > next_cleanup_time:
            cleanup_db()
            next_cleanup_time = now(plus=timeouts['cleanup_spacing'])

        # Any failed runs to retry?
        # Only try one run, we want to be back for new runs quickly
        rd = consider_run({
            "bootstrax.state": 'failed',
            "bootstrax.next_retry": {
                '$lt': now()}})
        if rd is not None:
            process_run(rd)
            continue

        log.info("No work to do, waiting for new runs or retry timers")
        set_state('idle')
        time.sleep(timeouts['idle_nap'])


##
# General helpers
##

def now(plus=0):
    """Return the current timezone-aware UTC time, shifted by some seconds.

    :param plus: number of seconds to add to the current time
        (negative values give a time in the past).
    """
    offset = timedelta(seconds=plus)
    return datetime.now(pytz.utc) + offset


def kill_process(pid, wait_time=None):
    """Kill process pid, or raise RuntimeError if we cannot.

    First SIGTERM is sent; if the process still exists after wait_time
    seconds, SIGKILL is sent. If the process survives that for another
    wait_time seconds, we give up.

    :param pid: id of the process to kill. Nothing happens if no such
        process exists.
    :param wait_time: time (in seconds) to wait before escalating signal
        strength. Defaults to timeouts['signal_escalate'].
    :raises RuntimeError: if the process still exists after SIGKILL.
    """
    if wait_time is None:
        wait_time = timeouts['signal_escalate']
    if not pid_exists(pid):
        return

    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            # The process disappeared between our check and the signal
            return
        # Give the process some time to act on the signal
        time.sleep(wait_time)
        if not pid_exists(pid):
            return

    # Even SIGKILL did not make the process go away
    raise RuntimeError(f"Could not kill process {pid}")


def update_usage(pid_process, run_number):
    """Write one resource usage report to the 'eb_monitor' collection.

    The report holds the CPU and memory usage (in percent) of both the
    monitored process and the host as a whole, plus the size (in TB) and the
    used fraction (in percent) of the disk holding the output folder.

    :param pid_process: psutil Process handle of the process to report on.
    :param run_number: number of the run being processed (converted to int).
    """
    host_memory = virtual_memory()
    output_disk = disk_usage(output_folder)
    bytes_per_terabyte = 1024.0 ** 4
    percent = 100

    report = dict(
        run=int(run_number),
        host=hostname,
        time=now(),
        pid=pid_process.pid,
        cpu_pid=pid_process.cpu_percent(),
        cpu_tot=cpu_percent(),
        mem_pid=pid_process.memory_percent(),
        mem_tot=percent * host_memory.used / host_memory.total,
        disk_size=output_disk.total / bytes_per_terabyte,
        disk_used=output_disk.percent)
    usage_coll.insert_one(report)


def set_state(state):
    """Update this instance's document in the bootstrax collection.

    The heartbeat time is always refreshed. The state itself is only
    changed if state is not None.

    Does nothing (apart from a debug message) if this instance has not
    registered a status document.
    """
    if state_doc_id is None:
        log.debug("Not updating bootstrax doc: none selected. "
                  "You are probably running a rogue bootstrax operation...")
        return

    changes = dict(time=now())
    if state is not None:
        changes.update(state=state)
    bs_coll.update_one({'_id': state_doc_id}, {'$set': changes})


def send_heartbeat():
    """Refresh our lifesign in the bootstrax collection.

    Meant for long-running tasks during which the state stays the same.
    """
    # A state of None makes set_state update only the heartbeat time
    set_state(state=None)


def log_warning(message, priority='warning'):
    """Send a message to the terminal (via the logging module) and to the
    DAQ log DB.

    :param message: text to report.
    :param priority: name of the logging method to use, e.g. 'info',
        'warning', 'error' or 'fatal'. In the DAQ log DB, 'info' is stored as
        priority 1, 'warning' as 2 and anything else as 3.
    """
    db_priorities = {'info': 1, 'warning': 2}

    emit = getattr(log, priority)
    emit(message)

    entry = {
        'message': message,
        'user': f'bootstrax_{hostname}',
        'priority': db_priorities.get(priority, 3)}
    log_coll.insert_one(entry)


def infer_run_mode(rd, input_dir):
    '''
    Infer a safe operating mode of running bootstrax based on the size of the first chunk.
    Estimating save parameters for running bootstrax from:
      # TODO update with new strax version
      https://xe1t-wiki.lngs.infn.it/doku.php?id=xenon:xenonnt:dsg:daq:eb_speed_tests

    :param rd: run doc; ini.strax_chunk_length and ini.strax_chunk_overlap
        are used to convert the size of the first chunk into a data rate.
    :param input_dir: directory holding the live data of this run.
    :return: dictionary of how many cores and max_messages should be used based on an
        estimated data rate. If too few files show up in input_dir within the waiting
        time, the values from the command line arguments are returned instead.
    '''
    nap_time, i = 10, 0
    while len(os.listdir(input_dir)) < 3 * 3:
        if i > 10:
            # To few files to be sure the first chunk is completely written
            return dict(cores=args.cores, max_messages=args.max_messages)
        time.sleep(nap_time)
        # Count the naps, otherwise we would wait forever on a run
        # that never gets enough files
        i += 1

    # Estimate data rate on the bases of the size of the first chunk
    first_chunk = osp.join(input_dir, '000000')
    chunk_size = sum(osp.getsize(osp.join(first_chunk, f))
                     for f in os.listdir(first_chunk))
    chunk_time = rd['ini']['strax_chunk_length'] + rd['ini']['strax_chunk_overlap']
    compression_factor = 3  # this is an estimate
    data_rate = compression_factor * (chunk_size / 1e6) / chunk_time  # in MB/s

    # Find out if eb is new (eb3-eb5).
    # Note that str.strip removes a *set of characters* from both ends rather
    # than a prefix/suffix; for hostnames like 'eb3.xenon.local' this leaves
    # just the number.
    eb_number = int(hostname.strip('eb.xenon.local'))
    is_new_eb = eb_number >= 3
    print(f"infer_run_mode::\teb number {eb_number}")

    # Return a run-mode that is empirically found to provide stable processing.
    print(f'infer_run_mode::\tWorking with an estimated data rate of {data_rate} MB/s')
    if data_rate < 50:
        result = dict(cores=12 if is_new_eb else 8,
                      max_messages=8 if is_new_eb else 6)
    elif data_rate < 100:
        result = dict(cores=10 if is_new_eb else 6,
                      max_messages=8 if is_new_eb else 6)
    elif data_rate < 300:
        result = dict(cores=8 if is_new_eb else 5,
                      max_messages=6 if is_new_eb else 4)
    elif data_rate < 700:
        result = dict(cores=6 if is_new_eb else 4,
                      max_messages=4)
    else:
        result = dict(cores=1,
                      max_messages=4)
    print(f'Override processing mode for run {rd["number"]} changing to:\t {result}')
    return result


##
# Host interactions
##

def sufficient_diskspace():
    '''Check if there is sufficient space available on the local disk to write to.

    If less than wait_diskspace_min_space GB is free on the disk holding the
    output folder, wait wait_diskspace_dt seconds and check again, up to
    wait_diskspace_n_max times. A heartbeat is sent after every wait.

    :raises RuntimeError: if there is still insufficient space after the
        maximum number of checks.
    '''
    gigabyte_to_byte = 1024.0 ** 3
    for i in range(wait_diskspace_n_max):
        # Measure again on every iteration: space may have been freed
        # while we were waiting
        disk_free = disk_usage(output_folder).free / gigabyte_to_byte
        print(f'Check diskspace. {int(disk_free)} GB free')
        if disk_free > wait_diskspace_min_space:
            # Good we can write, lets continue
            return
        print(f'Insufficient free disk space ({int(disk_free)} GB) '
              f'on {hostname}. Waiting {wait_diskspace_dt} s ({i}th iteration)')
        time.sleep(wait_diskspace_dt)
        send_heartbeat()
    raise RuntimeError("No diskspace to write to")


def delete_live_data(live_data_path):
    '''After completing the processing and updating the runsDB, remove the
    live_data.

    Nothing is deleted if deleting live data is disabled (see the
    --delete_live argument).

    :param live_data_path: folder holding the live data of the run.
    :raises FileNotFoundError: if there is nothing at live_data_path.
    :raises RuntimeError: if the folder still exists after deleting it.
    '''
    if not delete_live:
        # Disabled deletion is not an error, and certainly not a missing folder
        print(f'Not deleting {live_data_path}: deleting live data is disabled')
        return

    if not os.path.exists(live_data_path):
        raise FileNotFoundError(f'Something is off, no folder found '
                                f'at {live_data_path} but there was '
                                f'an attempt to delete data here')

    print(f'Deleting data at {live_data_path}')
    shutil.rmtree(live_data_path)
    if os.path.exists(live_data_path):
        # Raise explicitly; an assert would be skipped when running python -O
        raise RuntimeError(f'Failed to delete {live_data_path}: '
                           f'we wanted to delete this path!')


def clear_shm():
    '''Remove every file in /dev/shm/ that has 'npshmex' in its name.

    Used on startup of a strax process to get rid of shared memory files
    that npshmex left behind.
    '''
    shm_dir = '/dev/shm/'
    leftovers = [name for name in os.listdir(shm_dir) if 'npshmex' in name]
    if not leftovers:
        # Nothing to clean up
        return

    print(f'clear_shm:: clearing {len(leftovers)} files')
    for name in tqdm(leftovers):
        os.remove(shm_dir + name)


##
# Run DB interaction
##


def get_run(*, mongo_id=None, number=None, full_doc=False):
    """Look up a single run doc by run number or by mongo id.

    The bootstrax state of the run is not touched. If both number and
    mongo_id are given, the number is used.

    :param mongo_id: _id of the run doc.
    :param number: run number.
    :param full_doc: If true (default is False), return the full run doc
        rather than just fields used by bootstrax.
    :return: the result of find_one on the runs collection.
    :raises ValueError: if neither mongo_id nor number is given.
    """
    if number is None and mongo_id is None:
        raise ValueError("Please give mongo_id or number")

    query = {'number': number} if number is not None else {'_id': mongo_id}
    projection = None if full_doc else bootstrax_projection

    try:
        return run_coll.find_one(query, projection=projection)
    except pymongo.errors.AutoReconnect as e:
        # mongo-db seems to be overloaded, wait a while and try once more
        auto_reconnect_nap = 60
        print(f'ran into {e}, take a nap for {auto_reconnect_nap} seconds and try again')
        time.sleep(auto_reconnect_nap)
        return run_coll.find_one(query, projection=projection)


def set_run_state(rd, state, return_new_doc=True, **kwargs):
    """Set the bootstrax state of run doc rd to state.

    Besides the state, the hostname and the current time are stored in the
    'bootstrax' field of the run doc. If the new state is 'failed', the
    failure counter (n_failures) is incremented as well.

    Note that the 'bootstrax' field of the passed rd is updated in place
    (and created if rd does not have one yet).

    :param rd: run doc; must at least contain '_id'.
    :param state: new bootstrax state of the run.
    :param return_new_doc: if True (default), returns new document.
        if False, instead returns the original (un-updated) doc.
    :param kwargs: any additional kwargs will be added to the bootstrax field.
    :return: the run doc (restricted to the fields used by bootstrax),
        as it is after or before the update, see return_new_doc.
    """
    # Runs that bootstrax never touched have no 'bootstrax' field yet
    bd = rd.setdefault('bootstrax', {})
    bd.update({
        'state': state,
        'host': hostname,
        'time': now(),
        **kwargs})

    if state == 'failed':
        bd['n_failures'] = bd.get('n_failures', 0) + 1

    if return_new_doc:
        which_doc = pymongo.ReturnDocument.AFTER
    else:
        which_doc = pymongo.ReturnDocument.BEFORE

    return run_coll.find_one_and_update(
        {'_id': rd['_id']},
        {'$set': {'bootstrax': bd}},
        return_document=which_doc,
        projection=bootstrax_projection)


def abandon(*, mongo_id=None, number=None):
    """Mark a run as abandoned, so bootstrax will ignore it from now on.

    :param mongo_id: _id of the run doc.
    :param number: run number.
    :raises ValueError: if no identifier is given, or if no such run exists.
    """
    rd = get_run(mongo_id=mongo_id, number=number)
    if rd is None:
        raise ValueError(f"Cannot abandon run (number={number}, "
                         f"mongo_id={mongo_id}): no such run exists")
    # Runs that bootstrax never touched have no 'bootstrax' field yet
    rd.setdefault('bootstrax', {})
    set_run_state(rd, 'abandoned')


def consider_run(query, return_new_doc=True):
    """Claim one run matching query by setting its bootstrax state to
    'considering', and return its run doc.

    If several runs match, the one with the latest start is taken.

    :param query: mongo query on the runs collection.
    :param return_new_doc: passed on to set_run_state: if True (default),
        return the run doc after the full bootstrax entry was written,
        otherwise the run doc as it was just before that.
    :return: the run doc, or None if no run matches query.
    """
    # The claim must be a single atomic find-and-update. Otherwise,
    # a bootstrax on another host could pick up the same run.
    claimed = run_coll.find_one_and_update(
        query,
        {"$set": {'bootstrax.state': 'considering'}},
        projection=bootstrax_projection,
        return_document=True,
        sort=[('start', pymongo.DESCENDING)])

    if claimed is not None:
        # The run is ours; now complete the bootstrax entry
        # (hostname, time, etc.)
        return set_run_state(claimed, 'considering',
                             return_new_doc=return_new_doc)
    return None


def fail_run(rd, reason):
    """Mark the run represented by run doc rd as failed with reason.

    This removes the data of the run registered on this host, stores the
    failure (with the time of the next retry) in the run db, and reports
    it to the terminal and the DAQ log.

    :param rd: run doc; must contain '_id'.
    :param reason: description of what went wrong.
    """
    run_label = rd['number'] if 'number' in rd else '<no run number!!?>'
    long_run_id = f"run {run_label}:{rd['_id']}"

    # No bootstrax info is present when manually failing a run with args.fail
    if 'bootstrax' not in rd:
        rd['bootstrax'] = {'n_failures': 0}

    # A first failure deserves more attention than a repeated one
    previous_failures = rd['bootstrax'].get('n_failures', 0)
    if previous_failures > 0:
        fail_name, failure_message_level = 'Repeated failure', 'info'
    else:
        fail_name, failure_message_level = 'New failure', 'warning'

    # Cleanup any data associated with the run
    # TODO: This should become optional, or just not happen at all,
    # after we're done testing (however, then we need some other
    # pruning mechanism)
    clean_run(mongo_id=rd['_id'])

    # Exponential backoff with jitter
    jitter = np.random.uniform(0.5, 1.5)
    backoff = 5 ** min(previous_failures, 3)
    retry_delay = timeouts['retry_run'] * jitter * backoff

    # Report to run db
    # It's best to do this after everything is done;
    # as it changes the run state back away from 'considering', so another
    # bootstrax could conceivably pick it up again.
    set_run_state(rd, 'failed',
                  reason=reason,
                  next_retry=now(plus=retry_delay))

    # Report to DAQ log and screen
    log_warning(f"{fail_name} on {long_run_id}: {reason}",
                priority=failure_message_level)


def manual_fail(*, mongo_id=None, number=None, reason=''):
    """Manually mark a run as failed based on mongo_id or run number.

    :param mongo_id: _id of the run doc.
    :param number: run number.
    :param reason: optional explanation, appended to the failure message.
    :raises ValueError: if no identifier is given, or if no such run exists.
    """
    rd = get_run(mongo_id=mongo_id, number=number)
    if rd is None:
        raise ValueError(f"Cannot fail run (number={number}, "
                         f"mongo_id={mongo_id}): no such run exists")
    fail_run(rd, "Manually set failed state. " + reason)


def get_compressor(rd, default_compressor="blosc"):
    """Return the compressor specified in the run doc (ini.compressor).

    If the run doc does not contain this information, a warning is
    reported and default_compressor is returned instead.
    """
    try:
        compressor = rd["ini"]["compressor"]
    except KeyError:
        log_warning(f"Bootstrax couldn't read the compressor form the run_doc. "
                    f"Assuming 'blosc' for now")
        compressor = default_compressor
    return compressor


##
# Processing
##

def run_strax(run_id, input_dir, target, n_readout_threads, compressor,
              run_start_time, samples_per_record, process_mode, debug=False):
    """Let strax make target for a single run.

    Any exception is written to exception_tempfile (so the bootstrax main
    process can read it) and then re-raised.

    :param run_id: run id, as used by strax.
    :param input_dir: directory with the live data of the run.
    :param target: strax data type to produce.
    :param n_readout_threads: passed to strax as config option.
    :param compressor: passed to strax as the daq_compressor option.
    :param run_start_time: passed to strax as config option.
    :param samples_per_record: passed to strax as the record_length option.
    :param process_mode: dict with 'cores' and 'max_messages'. It is used to
        create the context, 'cores' is also the maximum number of workers.
    :param debug: if True, set the log level to debug.
    """
    # Clear the swap memory used by npshmmex
    npshmex.shm_clear()
    # double check by forcefully clearing shm
    clear_shm()

    if debug:
        log.setLevel(logging.DEBUG)

    try:
        log.info(f"Starting strax to make {run_id} with input dir {input_dir}")

        context = new_context(**process_mode)
        strax_config = dict(daq_input_dir=input_dir,
                            daq_compressor=compressor,
                            run_start_time=run_start_time,
                            record_length=samples_per_record,
                            n_readout_threads=n_readout_threads)

        # Having the actual processing in a function allows to run it
        # either directly or inside the profiler
        def st_make():
            '''Run strax'''
            context.make(run_id, target,
                         config=strax_config,
                         max_workers=process_mode['cores'])

        profiling = args.profile.lower() != 'false'
        if profiling:
            prof_file = f'run{run_id}_{args.profile}'
            if '.prof' not in prof_file:
                prof_file += '.prof'
            print(f'starting with profiler, saving as {prof_file}')
            with strax.profile_threaded(prof_file):
                st_make()
        else:
            st_make()
    except Exception:
        # Write exception to file, so bootstrax can read it
        exc_info = strax.formatted_exception()
        with open(exception_tempfile, mode='w') as f:
            f.write(exc_info)
        raise


def process_run(rd, send_heartbeats=True):
    """Process a single run with strax and record the outcome in the run db.

    After some sanity checks on the run doc and the live data, strax is
    started in a child process. While it runs, the run is kept in the
    'busy' state. When strax finishes, basic checks are done on the data
    before the run is set to 'done' (and the live data is deleted, if that
    is enabled).

    Any problem with the run (including strax taking longer than
    timeouts['max_processing_time'], in which case strax is stopped) is
    reported with fail_run, after which this function simply returns.

    :param rd: run doc of the run to process, in the 'considering' state.
    :param send_heartbeats: if True (default), send a heartbeat and a
        resource usage report every time we check on strax.
    :raises RuntimeError: if rd is None.
    """
    if rd is None:
        raise RuntimeError("Pass a valid rundoc, not None!")
    # Use .get: a missing run number is reported as a run failure below
    log.info(f"Starting processing of run {rd.get('number')}")

    # Shortcuts for failing
    class RunFailed(Exception):
        pass

    def fail(reason):
        # Note this always raises: nothing after a call to fail is executed
        fail_run(rd, reason)
        raise RunFailed

    try:

        try:
            run_id = '%06d' % rd['number']
        except Exception as e:
            fail(f"Could not format run number: {str(e)}")

        # Find the live data entry; it must be the first entry
        for dd in rd['data']:
            if 'type' not in dd:
                fail("Corrupted data doc, found entry without 'type' field")
            if dd['type'] == 'live':
                break
            else:
                fail("Non-live data already registered; untracked failure?")
        else:
            fail(f"No live data entry in rundoc")

        if not osp.exists(dd['location']):
            fail(f"No access to live data folder {dd['location']}")

        thread_info = rd.get('ini', dict()).get('processing_threads', dict())
        n_readout_threads = sum(thread_info.values())

        if not n_readout_threads:
            fail(f"Run doc for {run_id} has no readout thread count info")

        loc = osp.join(dd['location'], run_id)
        if not osp.exists(loc):
            fail(f"No live data at claimed location {loc}")

        # Remove any previous processed data
        # If we do not do this, strax will just load this instead of
        # starting a new processing
        clean_run(mongo_id=rd['_id'])

        # Remove any temporary exception info from previous runs
        if osp.exists(exception_tempfile):
            os.remove(exception_tempfile)

        target = args.target
        # Compare against None explicitly: run number 0 is falsy, but valid
        if rd['bootstrax'].get('n_failures', 0) > 1 and args.process is None:
            # Failed before, and on autopilot: do just raw_records
            target = 'raw_records'

        compressor = get_compressor(rd)

        try:
            run_start_time = rd['start'].timestamp()
        except Exception as e:
            fail(f"Could not find start in datetime.datetime object: {str(e)}")

        try:
            samples_per_record = int(rd['ini']['strax_fragment_payload_bytes'] / 2)  # note that value in rd in bytes
            if not samples_per_record == 110:
                print(f'Samples_per_record = {samples_per_record}')
        except Exception as e:
            fail(f"Could not find ini.strax_fragment_payload_bytes in database: {str(e)}")

        if args.infer_run_mode:
            process_mode = infer_run_mode(rd, loc)
        else:
            process_mode = dict(cores=args.cores, max_messages=args.max_messages)

        strax_proc = multiprocessing.Process(
            target=run_strax,
            args=(run_id, loc, target, n_readout_threads, compressor,
                  run_start_time, samples_per_record, process_mode, args.debug))

        t0 = now()
        info = dict(started_processing=t0)
        strax_proc.start()
        pid_process = Process(strax_proc.pid)
        # Ignore first time of psutil.cpu_percent (returns 0), see documentation
        pid_process.cpu_percent()
        cpu_percent()

        while True:
            if send_heartbeats:
                update_usage(pid_process=pid_process, run_number=run_id)
                send_heartbeat()

            ec = strax_proc.exitcode
            if ec is None:
                if t0 < now(-timeouts['max_processing_time']):
                    # Stop strax *before* failing the run, as fail raises.
                    # Joining the process also reaps it, so it does not
                    # linger as a zombie.
                    strax_proc.terminate()
                    strax_proc.join(timeouts['signal_escalate'])
                    if strax_proc.exitcode is None:
                        # SIGTERM was not enough, escalate
                        os.kill(strax_proc.pid, signal.SIGKILL)
                        strax_proc.join(timeouts['signal_escalate'])
                    fail(f"Processing took longer than {timeouts['max_processing_time']} sec")

                # Still working, check in later
                # TODO: is there a good way to detect hangs, before max_processing_time expires?
                log.info(f"Still processing run {run_id}")
                set_run_state(rd, 'busy', **info)
                time.sleep(timeouts['check_on_strax'])
                continue

            elif ec == 0:
                log.info(f"Strax done on run {run_id}, performing basic data quality check")

                try:
                    md = st.get_meta(run_id, 'raw_records')
                except Exception:
                    fail("Processing succeeded, but metadata not readable: "
                         + traceback.format_exc())
                if not len(md['chunks']):
                    fail("Processing succeeded, but no chunks were written!")

                rd = get_run(mongo_id=rd['_id'])
                if rd.get('end') is None:
                    fail("Processing succeeded, but run hasn't yet ended!")

                # Check that the data written covers the run
                # (at least up to some fudge factor)
                # Since chunks can be empty, and we don't want to crash,
                # this has to be done with some care: the covered time is
                # validated as a plain number first, since a timedelta
                # cannot be made from an infinite number of seconds.
                covered_seconds = (
                    max(x.get('last_endtime', 0) for x in md['chunks'])
                    - min(x.get('first_time', float('inf')) for x in md['chunks'])
                ) / 1e9
                if not (0 < covered_seconds < float('inf')):
                    fail(f"Processed data covers {covered_seconds} sec")
                t_covered = timedelta(seconds=covered_seconds)
                run_duration = rd['end'] - rd['start']
                if not (timedelta(seconds=-30)
                        < (run_duration - t_covered)
                        < timedelta(seconds=30)):
                    # total_seconds() includes days, unlike .seconds
                    fail(f"Processing covered {t_covered.total_seconds()}, "
                         f"but run lasted {run_duration.total_seconds()}!")

                log.info(f"Run {run_id} processed succesfully")
                set_run_state(rd, 'done', **info)
                if delete_live:
                    delete_live_data(loc)
                break

            else:
                # This is just the info that we're starting
                # exception retrieval. The actual error comes later.
                log.info(f"Failure while procesing run {run_id}")
                if osp.exists(exception_tempfile):
                    with open(exception_tempfile, mode='r') as f:
                        exc_info = f.read()
                    if not exc_info:
                        exc_info = '[No exception info known, exception file was empty?!]'
                else:
                    exc_info = "[No exception info known, exception file not found?!]"
                fail(f"Strax exited with exit code {ec}. Exception info: {exc_info}")

            # TODO: Strax should update run db with metadata
            # currently it doesn't, even if processing succeeds...

    except RunFailed:
        # The failure is already reported by fail
        return


##
# Cleanup
##


def clean_run(*, mongo_id=None, number=None):
    """Delete this host's registered data for a run and deregister it.

    ``mongo_id`` and ``number`` are passed straight through to ``get_run``
    to select the run. Every entry in the run doc's 'data' list whose
    'host' equals this machine's ``hostname`` is dropped from the list,
    and its 'location' is removed from disk when that path exists.
    Entries belonging to other hosts, or carrying no 'host' at all,
    are kept untouched.

    Does NOT remove temporary folders,
    nor data that isn't registered to the run db.
    """
    # The complete 'data' list is written back below (no surgical
    # per-element update), hence the full run doc is fetched.
    run_doc = get_run(mongo_id=mongo_id, number=number, full_doc=True)

    kept_entries = []
    for entry in run_doc['data']:
        stored_here = 'host' in entry and entry['host'] == hostname
        if not stored_here:
            kept_entries.append(entry)
            continue
        # Entry is ours: it is deregistered regardless of whether the
        # files are still present on disk.
        path = entry['location']
        if os.path.exists(path):
            shutil.rmtree(path)

    run_coll.find_one_and_update(
        {'_id': run_doc['_id']},
        {'$set': {'data': kept_entries}})


def cleanup_db():
    """Find various pathological runs and clean them from the db.

    Also cleans the bootstrax collection for stale entries.

    The steps, in order:
      1. Delete bootstrax-collection entries whose 'time' is older than
         the 'bootstrax_presumed_dead' timeout.
      2. Fail runs that have been 'considering' or 'busy' for longer
         than the corresponding timeout.
      3. Fail runs marked 'done' whose run doc contradicts that.
      4. Fail and abandon runs that are pointless to retry ('bad' tag,
         old run doc format, non-unique run number).

    ``send_heartbeat`` is called before every query; presumably this
    keeps this instance's own status entry fresh during a long cleanup.
    """
    def claimed_runs(make_query, **consider_kwargs):
        # Hand out, one at a time, the runs that consider_run returns for
        # the query, until it returns None. The query is rebuilt on every
        # round so that time cutoffs inside it stay current.
        while True:
            send_heartbeat()
            doc = consider_run(make_query(), **consider_kwargs)
            if doc is None:
                return
            yield doc

    log.info("Checking for bad stuff in database")

    # Step 1: bootstrax instances that claim to be alive but have been
    # silent for too long. (They must be on other hosts; one on this
    # host would have been killed already.)
    while True:
        send_heartbeat()
        cutoff = now(-timeouts['bootstrax_presumed_dead'])
        dead = bs_coll.find_one_and_delete({'time': {'$lt': cutoff}})
        if dead is None:
            break
        log_warning(f"Bootstrax on host {dead['host']} presumed dead. Rest in peace")

    # Step 2: runs stuck in 'considering' or 'busy' without any progress.
    stale_limits = (
        ('considering', timeouts['max_considering_time']),
        ('busy', timeouts['max_busy_time']),
    )
    for state, timeout in stale_limits:
        def stale_query():
            return {'bootstrax.state': state,
                    'bootstrax.time': {'$lt': now(-timeout)}}

        # The old doc is requested, so the message below reports the
        # host and time that were recorded before we claimed the run.
        for stuck in claimed_runs(stale_query, return_new_doc=False):
            previous = stuck['bootstrax']
            fail_run(stuck,
                     f"Host {previous['host']} said it was {state} "
                     f"at {previous['time']}, but then didn't get further; "
                     f"perhaps it crashed on this run or is still stuck?")

    # Step 3: runs whose run doc alone shows they are in a bad state.
    done_without_end = {
        'bootstrax.state': 'done',
        'end': {'$exists': False}}
    # 'data' holds no element with a type other than 'live'
    done_without_processed_data = {
        'bootstrax.state': 'done',
        'data': {'$not': {'$elemMatch': {"type": {'$ne': 'live'}}}}}

    failure_queries = [
        (done_without_end,
         'Bootstrax state was done, but run did not yet end'),
        (done_without_processed_data,
         'Bootstrax state was done, but no processed data registered'),

        # Disabled: can't add this yet, since registering happens when
        # processing starts at the moment...
        #   ({'$not': {'bootstrax.state': 'done'},
        #     'data': {'$not': {'$elemMatch': {"type": {'$ne': 'live'}}}}},
        #    'Bootstrax state was NOT done, but live data has been registered'),

        # Disabled: for some reason this one doesn't work... probably
        # the mongo query syntax is confused
        #   ({'$and': [
        #       {'bootstrax.state': {'$exists': True}},
        #       {'bootstrax.state': {
        #           '$nin': 'considering done failed abandoned'.split()}}]},
        #    'Bootstrax state set to unrecognized state {bootstrax[state]}')
    ]

    for query, failure_message in failure_queries:
        for bad_run in claimed_runs(lambda: query):
            fail_run(bad_run, failure_message.format(**bad_run))

    # Step 4a: collect run numbers that occur more than once.
    occurrences = Counter(
        doc['number']
        for doc in run_coll.find({}, projection=['number']))
    duplicate_numbers = [run_number
                         for run_number, count in occurrences.items()
                         if count > 1]

    # Step 4b: abandon runs which we already know are so bad that
    # there is no point in retrying them.
    abandon_queries = [
        ({'tags': {'$elemMatch': {'name': 'bad'}}},
         "Run has a 'bad' tag"),
        ({'ini.processing_threads': {'$exists': False}},
         "Old run doc format without processing thread info"),
        ({'number': {'$in': duplicate_numbers}},
         "Run number is not unique"),
    ]

    for base_query, reason in abandon_queries:
        # Skip runs that were abandoned before, or this would never end
        query = {**base_query, 'bootstrax.state': {'$ne': 'abandoned'}}
        failure_message = reason + ' -- run has been abandoned'
        for hopeless in claimed_runs(lambda: query):
            fail_run(hopeless, failure_message.format(**hopeless))
            abandon(mongo_id=hopeless['_id'])


# Script entry point: call main() only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
