#!/usr/bin/env python

import argparse
from ast import literal_eval
import json
import os
import os.path as osp

import numpy as np
from tqdm import tqdm

import strax
import straxen


def main():
    """Convert one run's old-format raw_records in place to the current format.

    Parses the command line, then for every chunk listed in the
    raw_records metadata: loads the chunk file, checks it does not overlap
    the previous chunk, copies the records into ``strax.raw_record_dtype``
    (undoing baselining if a 'baseline' field is present) and overwrites
    the chunk file. Finally the metadata JSON is overwritten with the
    updated per-chunk and per-run information.

    Raises:
        FileNotFoundError: the metadata JSON is not in the data folder.
        ValueError: the data has no chunks, was already converted,
            contains an empty chunk, or has overlapping chunks
            (unless --procrustean is given, in which case overlapping
            records are dropped; a chunk left empty by that is an error).
    """
    parser = argparse.ArgumentParser(
        description='Refresh strax raw_records created with < v0.9.0')
    parser.add_argument(
        '--parent_folder',
        default='.',
        help="Folder with raw_records folders.")
    parser.add_argument(
        '--no_run_metadata',
        action='store_true',
        help='Refresh even if you have lost the run metadata. '
             'Some useful sanity checks will be disabled.')
    parser.add_argument(
        '--procrustean',
        action='store_true',
        help='Delete records that fall across chunk boundaries. '
             'Use as a last resort for triggerless data you do not want '
             'to reconvert from pax properly.')
    parser.add_argument(
        '--xnt',
        action='store_true',
        help="Set if records came from the DAQReader, not RecordsFromPax")
    parser.add_argument(
        'run_id',
        help="run_id to convert")
    args = parser.parse_args()
    run_id = args.run_id
    parent_folder = args.parent_folder

    ##
    # Prepare context
    ##
    st = strax.Context(
        storage=[strax.DataDirectory(
            parent_folder,
            # We WILL overwrite your data
            # just not through the usual means:
            readonly=True)],
        **straxen.contexts.common_opts)
    if args.xnt:
        st.register(straxen.DAQReader)
    else:
        st.register(straxen.RecordsFromPax)

    ##
    # Get metadata
    ##
    folder = st.storage[0].find(st.key_for(run_id, 'raw_records'),
                                fuzzy_for='raw_records')[1]
    md = st.get_metadata(run_id, 'raw_records')
    metadata_fn = osp.join(
        folder,
        strax.dirname_to_prefix(folder) + '-metadata.json')
    # Explicit check rather than assert: asserts vanish under `python -O`,
    # and this file is overwritten at the end of the conversion.
    if not osp.exists(metadata_fn):
        raise FileNotFoundError(
            f"Metadata file {metadata_fn} does not exist")
    dtype = np.dtype(literal_eval(md['dtype']))
    record_length = strax.record_length_from_dtype(dtype)

    if not args.no_run_metadata:
        run_md = st.run_metadata(run_id)
        # Whole seconds since the epoch, scaled by 1e9
        # (presumably start/end are datetimes and strax times are in ns).
        run_start, run_end = [
            int(x.timestamp()) * int(1e9)
            for x in [run_md['start'], run_md['end']]]
    else:
        run_start, run_end = None, None

    if not len(md['chunks']):
        raise ValueError("Cannot convert data: no chunks!")
    if 'start' in md['chunks'][0]:
        raise ValueError("This data was already converted")

    ##
    # Convert data
    ##
    new_dtype = strax.raw_record_dtype(record_length)
    last_endtime = 0
    for i, c in enumerate(tqdm(md['chunks'],
                               desc=f'Converting raw_records for {run_id}')):
        filename = osp.join(folder, c['filename'])
        rr = strax.load_file(
            filename,
            dtype=dtype,
            compressor=md['compressor'])

        if not len(rr):
            raise ValueError("Cannot convert data with empty chunks")

        if rr[0]['time'] < last_endtime:
            if args.procrustean:
                to_cut = rr['time'] < last_endtime
                print(f"[!!] Removing {to_cut.sum()} records "
                      f"from chunk {i} to remove overlaps!")
                rr = rr[~to_cut]
                if not len(rr):
                    # Every record started before the previous chunk ended;
                    # fail clearly before touching the chunk file.
                    raise ValueError(
                        f"Cannot convert data: removing overlaps left "
                        f"chunk {i} without any records")
            else:
                raise ValueError(
                    f"Cannot convert data: chunk {i}'s data starts "
                    f"at {rr[0]['time']} while the previous chunk's data "
                    f"ended at {last_endtime}")

        # Chunks are made contiguous: each chunk starts where the previous
        # one ended; the first starts at the run start if that is known.
        if i == 0:
            if run_start is not None:
                c['start'] = run_start
            else:
                c['start'] = rr[0]['time']
        else:
            c['start'] = last_endtime
        c['end'] = last_endtime = strax.endtime(rr).max()

        new_rr = np.zeros(len(rr), dtype=new_dtype)
        strax.copy_raw_records(rr, new_rr)
        if 'baseline' in rr.dtype.fields:
            # Undo baselining
            new_rr['data'] = rr['baseline'][:, np.newaxis] - rr['data']
            strax.zero_out_of_bounds(new_rr)

        c['run_id'] = run_id
        c['nbytes'] = new_rr.nbytes
        c['filesize'] = strax.save_file(
            filename,
            new_rr,
            compressor=md['compressor'])
        # We must rewrite these too, the chunk count could have changed
        # if args.procrustean
        c['n'] = len(rr)
        c['first_time'] = rr[0]['time']
        c['first_endtime'] = strax.endtime(rr[0])
        c['last_time'] = rr[-1]['time']
        c['last_endtime'] = strax.endtime(rr[-1])

    if run_start is None:
        # No run metadata: the run spans from the start of the first chunk
        # to the END of the last chunk.
        run_start = md['chunks'][0]['start']
        run_end = md['chunks'][-1]['end']

    ##
    # Set and write out new metadata
    ##
    md['start'] = run_start
    md['end'] = run_end
    md['run_id'] = run_id
    md['data_kind'] = 'raw_records'
    md['converted_from_old_strax'] = md['strax_version']
    md['strax_version'] = strax.__version__
    md['dtype'] = repr(np.dtype(new_dtype).descr)

    with open(metadata_fn, mode='w') as f:
        f.write(json.dumps(md,
                           sort_keys=True,
                           indent=4,
                           cls=strax.NumpyJSONEncoder))


# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
