#!/usr/bin/env python

# Created on Fri Jun 29 14:22:41 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, logging, logging.handlers, traceback, neoloop

# Package version string; getargs() reports it through the -v/--version flag.
currentVersion = neoloop.__version__

def getargs(argv=None):
    """Build the command-line parser and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse, without the program name. Defaults to
        ``sys.argv[1:]``. An empty argument list is replaced by ``['-h']``
        so that invoking the command bare prints the help text and exits.

    Returns
    -------
    tuple
        ``(args, commands)``: the parsed ``argparse.Namespace`` and the list
        of raw arguments that was actually parsed.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Identify novel loop interactions across SV
                                     points.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Output / Input.
    # The caller cannot do anything without these three, so they are marked
    # required: a missing one now yields an argparse usage error instead of a
    # traceback buried in the log file. (-h and -v still work, because those
    # actions exit before argparse checks required options.)
    parser.add_argument('-O', '--output', required=True, help='Output path.')
    parser.add_argument('-H', '--hic', nargs='+', required=True,
                        help='''List of cooler URIs.''')
    parser.add_argument('--assembly', required=True,
                        help='''The assembled SV list output by assemble-complexSVs.''')

    # Algorithm
    parser.add_argument('-R', '--region-size', default=5000000, type=int,
                        help='''The extended genomic span of SV break points.(bp)''')
    parser.add_argument('--balance-type', default='CNV', choices=['CNV', 'ICE'],
                        help='Normalization method.')
    parser.add_argument('--protocol', default='insitu', choices=['insitu', 'dilution'],
                        help='''Experimental type of your contact matrices. insitu: insitu Hi-C; dilution: dilution Hi-C''')
    parser.add_argument('--prob', type=float, default=0.9,
                        help='Probability threshold.')
    parser.add_argument('--no-clustering', action='store_true',
                        help='No pooling will be performed if specified.')
    parser.add_argument('--min-marginal-peaks', type=int, default=1,
                        help='''Minimum marginal number of loops when detecting loop anchors.''')
    parser.add_argument('--nproc', default=1, type=int, help='Number of worker processes.')
    parser.add_argument('--logFile', default='neoloop.log', help='''Logging file name.''')

    ## Parse the command-line arguments
    # Copy, so neither sys.argv nor a caller-supplied list is ever mutated.
    commands = list(sys.argv[1:] if argv is None else argv)
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

def _configure_root_logger(log_file):
    """Attach a console handler and a rotating file handler to the root logger.

    Both handlers pass records at INFO and above and share one format. The
    root logger itself is set to DEBUG, so filtering is left to the handlers.

    Parameters
    ----------
    log_file : str
        Path of the log file. It rolls over at 100000 bytes and keeps 5 backups.

    Returns
    -------
    logging.Logger
        The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(log_file,
                                                       maxBytes=100000,
                                                       backupCount=5)
    formatter = logging.Formatter(fmt='%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt='%m/%d/%y %H:%M:%S')
    for handler in (console, filehandler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def _log_arguments(logger, args):
    """Log every command-line setting as a single multi-line INFO record."""
    arglist = ['# ARGUMENT LIST:',
               '# Output Path = {0}'.format(args.output),
               '# SV assembly = {0}'.format(args.assembly),
               '# Cooler List = {0}'.format(args.hic),
               '# Extended Genomic Span = {0}bp'.format(args.region_size),
               '# Balance Type = {0}'.format(args.balance_type),
               '# Experimental protocol = {0}'.format(args.protocol),
               '# Probability threshold = {0}'.format(args.prob),
               '# No pooling = {0}'.format(args.no_clustering),
               '# Minimum marginal peaks = {0}'.format(args.min_marginal_peaks),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Log file name = {0}'.format(args.logFile)
               ]
    logger.info('\n' + '\n'.join(arglist))


def _load_assemblies(path):
    """Read the assembled-SV file into ``{first field: remaining fields}``.

    Each line is split on whitespace. The first field becomes the key and the
    remaining fields are re-joined with tabs to form the value. Blank lines are
    skipped; if a key appears more than once, the later line wins.
    """
    lines = {}
    with open(path, 'r') as source:
        for line in source:
            parse = line.split()  # split() with no args also drops the newline
            if not parse:
                continue  # blank line: indexing parse[0] would raise IndexError
            lines[parse[0]] = '\t'.join(parse[1:])
    return lines


def _merge_loops(results):
    """Group per-assembly loop records by loop key.

    ``results`` is an iterable of the values returned by ``pipeline``: either
    ``None`` (a skipped assembly) or a dict mapping loop keys to records.
    Returns ``{loop key: [record, ...]}``, with records kept in result order.
    """
    cache = {}
    for tmp in results:
        if tmp is None:
            continue
        for key, record in tmp.items():
            cache.setdefault(key, []).append(record)
    return cache


def run():
    """Command-line entry point.

    Parses the arguments and configures logging. For every input cooler
    resolution, from coarsest to finest, it runs ``pipeline`` on each assembled
    SV in parallel and merges the loops across assemblies. The per-resolution
    results are then combined and written as tab-separated lines to
    ``args.output``. On any error, the traceback is logged to the console and
    the log file, and the process exits with status 1.
    """
    # Parse Arguments
    args, commands = getargs()
    # parse_args() already exits on -h/-v, so this guard is only a safety net.
    # It keeps help/version from configuring logging or importing heavy modules.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    logger = _configure_root_logger(args.logFile)
    _log_arguments(logger, args)

    # Heavy dependencies are imported lazily so that --help/--version stay fast.
    from joblib import Parallel, delayed
    from neoloop.callers import combine_annotations
    import cooler

    rsize = args.region_size
    protocol, prob, mmp = args.protocol, args.prob, args.min_marginal_peaks
    no_pool = args.no_clustering
    # Map the CLI normalization choice to the cooler weight-column name.
    balance = {'CNV': 'sweight', 'ICE': 'weight'}.get(args.balance_type, False)

    try:
        # load Hi-C matrices, keyed by bin size
        cools = {}
        for path in args.hic:
            lib = cooler.Cooler(path)
            res = lib.binsize
            if res in cools:
                logger.warning('Multiple coolers share resolution {0}; only {1} will be used'.format(res, path))
            cools[res] = lib

        # load structural variations
        logger.info('Load assembled SVs')
        lines = _load_assemblies(args.assembly)

        logger.info('Predict loops ...')
        byres = {}
        for res in sorted(cools, reverse=True):
            logger.info('resolution: {0} ...'.format(res))
            clr = cools[res]
            results = Parallel(n_jobs=args.nproc, verbose=10)(
                delayed(pipeline)(clr, k, lines, rsize, balance, protocol, prob, mmp, no_pool)
                for k in lines
            )

            logger.info('Merge loops across assemblies ...')
            byres[res] = _merge_loops(results)

        if len(cools) > 1:
            logger.info('Combine loops from multiple resolutions')
        loop_list = combine_annotations(byres)
        logger.info('Output ...')
        with open(args.output, 'w') as out:
            for line in loop_list:
                out.write('\t'.join(line) + '\n')

        logger.info('Done')
    except Exception:
        # Route the traceback through the configured handlers (console and
        # rotating log file) instead of reopening and leaking the log file.
        # Ctrl-C and SystemExit are no longer swallowed.
        logger.exception('Loop calling failed')
        sys.exit(1)
    
def pipeline(clr, index, assemblies, rsize, balance, protocol, prob, mmp, nopool):
    """Predict loops on the fused contact matrix of one assembled SV.

    Parameters
    ----------
    clr : cooler.Cooler
        Contact matrix at a single resolution.
    index : str
        Key of the assembly to process; ``assemblies[index]`` is passed to
        ``complexSV``.
    assemblies : dict
        Maps assembly keys to assembly descriptions.
    rsize : int
        Extended span around the breakpoints (forwarded as ``span``).
    balance : str or bool
        Weight column forwarded to ``complexSV`` as ``col``.
    protocol : str
        Experimental protocol, forwarded to ``complexSV`` and ``Peakachu``.
    prob : float
        Probability threshold forwarded to ``Peakachu.predict`` as ``thre``.
    mmp : int
        Forwarded to ``Peakachu.predict`` as ``min_count``.
    nopool : bool
        Forwarded to ``Peakachu.predict`` as ``no_pool``.

    Returns
    -------
    dict or None
        ``None`` when the reorganized fusion matrix does not match the
        assembly's map (an invalid assembly). Otherwise a dict mapping each loop
        key from ``index_to_coordinates`` to ``(index, v[0], v[1])``, where ``v``
        is that key's value.
    """
    from neoloop.callers import Peakachu
    from neoloop.assembly import complexSV

    fu = complexSV(clr, assemblies[index], span=rsize, col=balance, protocol=protocol)
    fu.reorganize_matrix()
    if len(fu.Map) != fu.fusion_matrix.shape[0]:
        return None  # invalid assembly

    fu.correct_heterozygosity()
    fu.remove_gaps()

    # lower/upper presumably bound the genomic distance (bp) of candidate
    # loops: 6 bins up to 3 Mb -- confirm against Peakachu.
    core = Peakachu(fu.gap_free, lower=6*fu.res, upper=3000000, res=fu.res, protocol=protocol)
    loop_index = core.predict(thre=prob, no_pool=nopool, min_count=mmp, index_map=fu.index_map,
                              chains=fu.chains)

    loops = fu.index_to_coordinates(loop_index)
    # Tag every loop with the assembly it came from, so run() can merge loops
    # across assemblies.
    return {k: (index, loops[k][0], loops[k][1]) for k in loops}
    

if __name__ == "__main__":
    # Executed as a script rather than imported: start the command-line tool.
    run()