#!/usr/bin/env python

# Created on Fri Mar 22 14:33:09 2019
# Author: XiaoTao Wang

import argparse, sys, logging, logging.handlers, stripecaller

# Version string of the installed stripecaller package; reported by -v/--version.
currentVersion = stripecaller.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the program is invoked without any argument, ``-h`` is substituted
    so that the help message is printed (argparse then exits).

    ``-O/--output`` and ``-p/--path`` are mandatory: the pipeline cannot run
    without an output path and a Cooler URI, so a missing value is reported
    by argparse as a usage error instead of surfacing later as a traceback.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` when no argument was given).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(usage = '%(prog)s <-O output> [options]',
                                     description = 'Automated Stripe Identification from contact matrix.',
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required = True, help = 'Output file path.')
    parser.add_argument('--logFile', default = 'stripe-caller.log', help = 'Logging file name.')

    ## Where the data come from
    group_1 = parser.add_argument_group(title = 'Data source related:')
    group_1.add_argument('-p', '--path', required = True,
                         help = 'Cooler URI.')
    group_1.add_argument('-C', '--chroms', nargs = '*', default = ['#', 'X'],
                        help = 'List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')
    group_1.add_argument('-I', '--loop-file', help = '''Path to a loop file in bedpe format.''')
    group_1.add_argument('-T', '--tad-file', help = '''Path to a TAD file in bed format.''')
    group_1.add_argument('--no-pre', action = 'store_true', help = '''Whether your chromosome
                         labels have the "chr" prefix or not.''')
    group_1.add_argument('--use-raw', action = 'store_true', help = '''If specified, use raw signals
                         instead of ICE to calculate the expected.''')

    ## About the algorithm
    group_2 = parser.add_argument_group(title = 'Algorithm Parameters:')
    group_2.add_argument('--local-num', type=int, default=2,
                         help='''Number of local surrounding bins.''')
    group_2.add_argument('--background-length', type=int, default=4,
                         help='''Length of vertical (for 3-prime stripes) and horizontal (for 5-prime
                         stripes) background lines.''')
    group_2.add_argument('--siglevel', type = float, default = 0.05, help = 'Significant Level.')
    group_2.add_argument('--fold-enrichment', type = float, default = 1.1,
                         help = '''Threshold of fold enrichment score.''')
    group_2.add_argument('--maxapart', type = int, default = 3000000, help = '''We only search bin pairs
                         separated by this size of genomic distance.''')
    group_2.add_argument('--nproc', type = int, default = 1, help = 'Number of worker processes.')
    group_2.add_argument('--min-seed-len', type = int, default = 6,
                         help = '''Minimum length of seed stretches.''')
    group_2.add_argument('--max-gap', type = int, default = 5,
                         help = '''Maximum allowed gap size when merging two stretches together.''')
    group_2.add_argument('--min-stripe-len', type = int, default = 9,
                         help = '''Minimum length of output stripes''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: show the help message instead of failing
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

def _chrom_label(name):
    """Return *name* without a leading literal ``chr`` prefix.

    ``str.lstrip('chr')`` removes any leading run of the characters ``c``,
    ``h`` and ``r`` rather than the prefix itself (``'contig1'`` would become
    ``'ontig1'``), so the prefix is tested and sliced off explicitly.
    """
    return name[3:] if name.startswith('chr') else name


def _configure_logging(log_path):
    """Attach a console handler and a rotating file handler to the root logger.

    Both handlers emit records at INFO level and above with the same format.
    The log file at *log_path* rotates at 200000 bytes, keeping 5 backups.
    Returns the root logger.
    """
    logger = logging.getLogger()
    # The root logger lets everything through; the handlers do the filtering
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(log_path,
                                                       maxBytes = 200000,
                                                       backupCount = 5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-21s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    return logger


def run():
    """Command-line entry point: call stripes and write them to the output file.

    Parses the command line, configures logging, selects the chromosomes of
    the cooler requested through ``--chroms``, runs the stripe caller on each
    of them (serially, or in a process pool when ``--nproc`` is not 1) and
    writes one tab-separated line per stripe with six fields:
    chromosome, start, end, chromosome, start, end.
    """
    # Parse Arguments
    args, commands = getargs()
    # Improve the performance if you don't want to run it
    if commands[0] in ['-h', '--help']:
        return

    ## Root Logger Configuration
    logger = _configure_logging(args.logFile)

    logger.info('Python Version: {}'.format(sys.version.split()[0]))

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output file = {}'.format(args.output),
               '# Cooler URI = {}'.format(args.path),
               '# Chromosomes = {}'.format(args.chroms),
               '# Loop file = {}'.format(args.loop_file),
               '# TAD file = {}'.format(args.tad_file),
               '# Use Raw Signals = {}'.format(args.use_raw),
               '# Number of local signal = {}'.format(args.local_num),
               '# Length of background = {}'.format(args.background_length),
               '# Significant Level = {}'.format(args.siglevel),
               '# Fold Enrichment = {}'.format(args.fold_enrichment),
               '# Maximum Genomic distance = {}'.format(args.maxapart),
               '# Minimum Seed Size = {}'.format(args.min_seed_len),
               '# Maximum Gap Size in Merging = {}'.format(args.max_gap),
               '# Minimum Stripe Length = {}'.format(args.min_stripe_len),
               '# Number of Processes = {}'.format(args.nproc)
               ]

    argtxt = '\n'.join(arglist)
    logger.info('\n'+argtxt)

    l_n = args.local_num
    b_n = args.background_length
    fold = args.fold_enrichment
    min_seed_len = args.min_seed_len
    max_gap = args.max_gap
    min_stripe_len = args.min_stripe_len
    loop_fil = args.loop_file
    tad_fil = args.tad_file

    # Heavy dependencies are imported only once we know there is work to do
    import cooler
    from multiprocess import Pool
    import numpy as np
    from scipy import sparse
    from stripecaller.callers import call_stripes, remove_compartment_stripes, \
        load_loops, remove_loop_pixels, load_TADs

    def worker(tuple_args):
        """Call stripes on one chromosome.

        *tuple_args* packs every input into a single tuple so the function
        can be used with ``map`` / ``Pool.map``. Returns
        ``(key, h_stripes, v_stripes)`` where *key* is the chromosome name.
        """
        Lib, key, l_n, b_n, siglevel, fold, maxapart, res, min_seed_len, max_gap, min_stripe_len, loop_fil, tad_fil, use_raw = tuple_args
        H = Lib.matrix(balance=False, sparse=True).fetch(key)

        chromLen = H.shape[0]
        # Keep only the upper diagonals within the search distance (plus the
        # local/background margins). The count is clamped to the matrix size:
        # a chromosome shorter than the band has no such diagonals, and
        # sparse.diags rejects offsets beyond the matrix dimensions.
        num = min(maxapart // res + l_n + b_n + 1, chromLen)
        valid_diag = np.arange(num)
        Diags = [H.diagonal(i) for i in valid_diag]
        M = sparse.diags(Diags, valid_diag, format='csr')

        del H

        if use_raw:
            cM = M
        else:
            # Same band extracted from the balanced matrix, with NaN replaced by 0
            cH = Lib.matrix(balance=True, sparse=True).fetch(key)
            cDiags = []
            for i in valid_diag:
                diag = cH.diagonal(i)
                mask = np.isnan(diag)
                diag[mask] = 0
                cDiags.append(diag)
            cM = sparse.diags(cDiags, valid_diag, format='csr')

            del cH

        h_candi, v_candi = call_stripes(key, M, cM, maxapart, res, l_n, b_n, siglevel, fold, chromLen,
                                        min_seed_len=min_seed_len, max_gap=max_gap, min_stripe_len=min_stripe_len)

        del M, cM

        # Optional post-filtering against user-supplied TADs and loops
        if tad_fil is not None:
            tads = load_TADs(tad_fil, res)
            h_candi = remove_compartment_stripes(h_candi, tads, key)
            v_candi = remove_compartment_stripes(v_candi, tads, key)

        if loop_fil is not None:
            loops = load_loops(loop_fil, res)
            h_stripes = remove_loop_pixels(h_candi, loops, key, min_stripe_len=min_stripe_len)
            v_stripes = remove_loop_pixels(v_candi, loops, key, min_stripe_len=min_stripe_len)
        else:
            h_stripes, v_stripes = h_candi, v_candi

        return key, h_stripes, v_stripes

    logger.info('Loading contact matrices from cool ...')
    Lib = cooler.Cooler(args.path)
    res = Lib.binsize

    # The output file is opened before the computation so that an unwritable
    # path fails early; the context manager closes it even if a worker raises.
    with open(args.output, 'w') as OF:
        Params = []
        for key in Lib.chromnames:
            chromlabel = _chrom_label(key)
            # An empty --chroms list selects everything; '#' selects numeric labels
            if ((not args.chroms) or (chromlabel.isdigit() and '#' in args.chroms) or (chromlabel in args.chroms)):
                Params.append((Lib, key, l_n, b_n, args.siglevel, fold, args.maxapart, res, min_seed_len,
                               max_gap, min_stripe_len, loop_fil, tad_fil, args.use_raw))

        if args.nproc == 1:
            # Lazy in Python 3: each chromosome is processed as the loop below reaches it
            results = map(worker, Params)
        else:
            pool = Pool(args.nproc)
            results = pool.map(worker, Params)
            pool.close()
            pool.join()

        for key, h_stripes, v_stripes in results:
            if args.no_pre:
                key = 'chr'+key
            # Each stripe mapping goes from an anchor bin index to a sequence of
            # intervals whose first two items are bin indices (presumably
            # start/end of the stripe extent -- confirm against call_stripes).
            for h in h_stripes:
                for d in h_stripes[h]:
                    tmp = [key, str(h*res), str(h*res+res), key, str(d[0]*res), str(d[1]*res)]
                    OF.write('\t'.join(tmp)+'\n')
            for v in v_stripes:
                for d in v_stripes[v]:
                    tmp = [key, str(d[0]*res), str(d[1]*res), key, str(v*res), str(v*res+res)]
                    OF.write('\t'.join(tmp)+'\n')

    logger.info('Done!')

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    run()

