#!/usr/bin/env python
import os

import click

@click.group()
def cli():
    # Root command group: it performs no work of its own and only serves as
    # the dispatcher for the sub-commands registered on it.
    return None
    
@click.command(name='check_input')
@click.option('-i', '--input_path', required=True, help='path to text file containing names of input samples')
@click.option('-f', '--formats', default=None, help='show derived samples with these derivation formats (separated by commas)')
def check_input(input_path, formats=None):
    """Check which input samples are available and list their derived samples."""
    # input_path: path to the text file listing the input sample names.
    # formats:    optional comma-separated derivation formats; None means the
    #             value is forwarded untouched to get_daods_from_aods.
    #
    # `formats` must appear in the signature: click passes every declared
    # option to the callback as a keyword argument.
    if formats is not None:
        # Split the comma-separated string, dropping blanks such as "A,,B" or "A, ".
        formats = [fmt.strip() for fmt in formats.split(",") if fmt.strip()]
    # Imported lazily so the project package is only loaded when the command runs.
    from atlas_derivation.utils.did_utils import check_dids_available, get_daods_from_aods
    # `samples` is indexed below with the keys 'available' and 'missing'.
    samples = check_dids_available(input_path)
    # Mapping of input sample -> iterable of derived sample names (iterated via .items()).
    daod_samples = get_daods_from_aods(samples, formats=formats)
    print("----------------------------------------- DAOD INPUT CHECK -----------------------------------------")
    print("INFO: Avaliable Samples:")
    print("\n".join(samples['available']))
    print("")
    print("INFO: Missing Samples:")
    print("\n".join(samples['missing']))
    print("")
    print("INFO: The following samples have corresponding derived samples:")
    for aod, daods in daod_samples.items():
        print(f"{aod}:")
        # One tab-indented line per derived sample of this input sample.
        print("\n".join(f"\t{daod}" for daod in daods))
    print("----------------------------------------------------------------------------------------------------")
    
    
@click.command(name='list_samples')
@click.option('-n', '--name', required=True, show_default=True,
              help='Expression used to filter the DAODs. Wild card is accepted.')
@click.option('-t', '--did-type', default='container', show_default=True,
              help='Filter by DID type.')
@click.option('--not-empty/--allow-empty', default=True, show_default=True,
              help='Whether to filter empty DAODs.')
@click.option('--latest-ptag/--any-ptag', default=True, show_default=True,
              help='Wheter to show only the DAODs with the latest ptags.')
@click.option('--single-rtag/--allow-multi-rtag', default=True, show_default=True,
              help='Wheter to show only the DAODs with a single rtag.')
@click.option('--single-ptag/--allow-multi-ptag', default=True, show_default=True,
              help='Wheter to show only the DAODs with a single ptag.')
@click.option('--esrp-tags-only/--allow-any-tags', default=True, show_default=True,
              help='Wheter to show only the DAODs with e, s, r and p tags only.')
@click.option('--data-type', default="ALL", show_default=True,
              type=click.Choice(["DAOD", "NTUP_PILEUP", "ALL"], case_sensitive=False),
              help='Filter samples by data type. Choose from "DAOD", "NTUP_PILEUP", "ALL".')
@click.option('--display-columns', default="name,nevent,size", show_default=True,
              help='Columns to display separated by commas. Common columns include "name", "type"'
              ', "ptag", "nevent", "size", etc.')
@click.option('--show-table/--show-text', default=True, show_default=True,
              help='Whether to display table or just plain text of dids.')
@click.option('--outname', '-o', default=None, show_default=True,
              help='If specified, save the output with the given filename.')
def list_samples(**kwargs):
    """List the DIDs matching a name expression, filter them and optionally save the result."""
    # kwargs holds one entry per click option declared above (name, did_type,
    # not_empty, latest_ptag, single_rtag, single_ptag, esrp_tags_only,
    # data_type, display_columns, show_table, outname).
    #
    # Imported lazily so the project package is only loaded when the command runs.
    from atlas_derivation.utils.did_utils import list_dids
    from atlas_derivation.components import DIDCollection, DID
    dids = list_dids(kwargs['name'], kwargs['did_type'])
    # Parse the comma-separated column list, ignoring blank entries.
    display_cols = [col.strip() for col in kwargs['display_columns'].split(',') if col.strip()]
    # Request metadata / file data only when a filter or a displayed column needs it.
    # NOTE(review): this tests `did_type` against 'all', while 'ALL' is a choice of
    # `--data-type`; possibly `data_type` was intended here -- confirm before changing.
    metadata = ((kwargs['did_type'].lower() != 'all') or
                any(col in DID.METADATA_COLS for col in display_cols))
    filedata = any(col in DID.FILEDATA_COLS for col in display_cols)
    collection = DIDCollection(dids, metadata=metadata, filedata=filedata)
    collection.filter_derived_samples(single_rtag=kwargs['single_rtag'],
                                      single_ptag=kwargs['single_ptag'],
                                      latest_ptag=kwargs['latest_ptag'],
                                      esrp_tags_only=kwargs['esrp_tags_only'],
                                      did_type=kwargs['did_type'],
                                      data_type=kwargs['data_type'],
                                      not_empty=kwargs['not_empty'],
                                      inplace=True)
    if kwargs['show_table']:
        collection.print_table(attributes=display_cols)
    else:
        for did in collection.get_dids():
            print(did)
    outname = kwargs['outname']
    if outname is None:
        return
    # The output format is chosen from the file extension: csv, json, or
    # (for anything else) plain text with one DID per line.
    export_format = os.path.splitext(outname)[1].strip(".")
    if export_format == 'csv':
        collection.df.to_csv(outname, index=False)
    elif export_format == 'json':
        import json
        data = collection.df.to_dict('records')
        # Use a context manager so the file is flushed and closed deterministically.
        with open(outname, 'w') as f:
            json.dump(data, f, indent=2)
    else:
        with open(outname, 'w') as f:
            for did in collection.get_dids():
                f.write(f"{did}\n")

if __name__ == "__main__":
    # Attach the sub-commands to the root group, then dispatch on the CLI arguments.
    for subcommand in (list_samples, check_input):
        cli.add_command(subcommand)
    cli()