#!/usr/bin/env python

# Created on Fri Jun 29 19:26:12 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, stripecaller

# Version string of the stripecaller package; reported by the -v/--version option
currentVersion = stripecaller.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv[1:]``.

    When the program is invoked without any argument, ``-h`` is substituted
    so that argparse prints the help message and exits.

    The output file, the cooler URI and the stripe file are mandatory:
    the analysis cannot run without any of them, so argparse reports a
    missing one immediately instead of letting the program fail later.

    Returns
    -------
    args : argparse.Namespace
        The parsed arguments.
    commands : list of str
        The raw argument list that was handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Perform the pile-up analysis on
                                     a stripe list.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required=True,
                        help='Output png file name.')
    parser.add_argument('--dpi', default=300, type=int,
                        help='''The resolution in dots per inch of the output figure.''')

    # Input
    parser.add_argument('-p', '--path', required=True,
                        help='URI string pointing to a cooler under specific resolution.')
    parser.add_argument('-S', '--stripe-file', required=True,
                        help='Path to the stripe file in bedpe format.')
    parser.add_argument('-U', '--useICE', action='store_true',
                        help='''Whether or not to use the ICE-corrected matrix.''')
    parser.add_argument('-M', '--min-dis', default=100000, type=int,
                        help='''Stripes will be trimmed so that each dot of a stripe is
                        separated from the anchor by at least this genomic distance.''')
    parser.add_argument('--min-stripe-len', default=9, type=int,
                        help='''Stripes shorter than this number of bins (after trimming)
                        will be removed from the analysis.''')
    parser.add_argument('--stripe-type', default='horizontal', choices=['horizontal', 'vertical'],
                        help='''Type of stripes to be piled up.''')
    parser.add_argument('--vmax', type=float,
                        help='''The maximum value that the colorbar covers.''')
    parser.add_argument('--no-pre', action='store_true',
                        help='''Remove the leading "chr" from the chromosome labels of the
                        stripe file. Use it when the chromosome labels of your cooler
                        have no "chr" prefix.''')
    parser.add_argument('--colormap-name', default='Reds',
                        help='''Name of the colormap in matplotlib, or "traditional" for
                        the built-in white-to-red colormap.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: show the help message instead of an error
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

def load_stripes(fil_path, res, label='horizontal', nopre=True):
    """Load the stripes of one orientation from a bedpe-like file.

    Each record is split on whitespace. Column 1 is the chromosome, columns
    2-3 are the start/end of the first interval and columns 5-6 are the
    start/end of the second interval. All coordinates are floor-divided by
    ``res``. Blank lines and lines starting with ``#`` are skipped.

    A record is kept as a horizontal stripe when its first interval spans
    fewer bins than its second one, and as a vertical stripe when its first
    interval spans more bins than its second one.

    Parameters
    ----------
    fil_path : str
        Path to the stripe file.
    res : int
        Bin size used to convert coordinates into bin indices.
    label : str
        Either 'horizontal' or 'vertical'; any other value keeps nothing.
    nopre : bool
        If True, a leading "chr" is removed from the chromosome labels.

    Returns
    -------
    dict
        Maps each chromosome label seen in the file to a list of
        ``(anchor, [start, end])`` tuples in bin units. For horizontal
        stripes ``anchor`` is the start bin of the first interval and
        ``[start, end]`` is the second interval; for vertical stripes
        ``anchor`` is the start bin of the second interval and
        ``[start, end]`` is the first interval. A chromosome without any
        kept record maps to an empty list.
    """
    D = {}
    with open(fil_path, 'r') as source:
        for line in source:
            parse = line.split()
            if not parse or parse[0].startswith('#'):
                continue
            chrom = parse[0]
            s1, e1, s2, e2 = [int(parse[i]) // res for i in (1, 2, 4, 5)]
            # Remove the literal prefix only. str.lstrip('chr') would strip
            # any leading run of the characters 'c', 'h' and 'r' and thereby
            # corrupt labels such as "hs37d5" or "contig1".
            if nopre and chrom.startswith('chr'):
                chrom = chrom[3:]
            records = D.setdefault(chrom, [])
            if label == 'horizontal':
                if (e1 - s1) < (e2 - s2):
                    records.append((s1, [s2, e2]))
            elif label == 'vertical':
                if (e1 - s1) > (e2 - s2):
                    records.append((s2, [s1, e1]))

    return D

def run():
    """Pile up local Hi-C windows around the stripes and save the figure.

    For every stripe that passes the filters, a square window of
    ``2*width + 1`` bins per side (``width`` is fixed to 6) is cut from the
    intra-chromosomal matrix, divided by its own mean, and collected. The
    element-wise average of all collected windows is drawn with ``imshow``
    and written to ``args.output``.

    The enrichment score shown in the title is the mean of the anchor
    row (horizontal stripes) or anchor column (vertical stripes) of the
    averaged window, divided by the mean of all its other rows/columns.
    Unless ``--vmax`` is given, the color scale tops out at twice that
    background mean.

    The process exits with an error message when no window survives the
    filters, since no average can be computed in that case.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip the heavy imports and the analysis if only help/version is requested
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import cooler
    from matplotlib.colors import LinearSegmentedColormap

    # Colormap selected by "--colormap-name traditional"
    cmap = LinearSegmentedColormap.from_list('interaction',
            ['#FFFFFF','#ff9292','#ff6767','#F70000'])

    ## extract Hi-C matrix
    hic_pool = cooler.Cooler(args.path)
    res = hic_pool.binsize
    correct = args.useICE
    min_dis = args.min_dis // res  # minimum anchor-to-dot distance, in bins
    min_len = args.min_stripe_len
    width = 6
    size = 2 * width + 1  # side length of each local window
    horizontal = args.stripe_type == 'horizontal'
    # Fetched once; chromosomes of the stripe file are checked against it
    chromsizes = hic_pool.chromsizes

    stripes = load_stripes(args.stripe_file, res, label=args.stripe_type, nopre=args.no_pre)
    agg = []
    for c in stripes:
        if not c in chromsizes:
            continue
        M = hic_pool.matrix(balance=correct, sparse=True).fetch(c)
        M = M.tocsr()

        for p in stripes[c]:
            anchor = p[0]
            ls, le = p[1]
            if horizontal:
                # Trim the stripe start so that it is at least min_dis bins
                # away from the anchor
                if ls - anchor < min_dis:
                    ls = anchor + min_dis
                if le - ls < min_len:
                    continue
                if anchor < width: # local matrix width
                    continue
                # Rows centered on the anchor, columns starting at the
                # (trimmed) stripe start
                sub = M[anchor-width:anchor+width+1, ls:ls+size].toarray()
            else:
                # Trim the stripe end so that it is at least min_dis bins
                # away from the anchor
                if anchor - le + 1 < min_dis:
                    le = anchor - min_dis + 1
                if le - ls < min_len:
                    continue
                if le - size < 0:
                    continue
                # Rows ending at the (trimmed) stripe end, columns centered
                # on the anchor
                sub = M[le-size:le, anchor-width:anchor+width+1].toarray()

            # Windows truncated by the matrix border are discarded
            if sub.shape[0] != size or sub.shape[1] != size:
                continue
            if sub.mean() == 0:
                continue
            # Windows containing any NaN are discarded
            if np.isnan(sub).any():
                continue
            agg.append(sub / sub.mean())

    if not agg:
        # Averaging an empty collection would only produce an obscure
        # indexing error further down
        sys.exit('No valid stripe window could be extracted. Check that the '
                 'chromosome labels of the stripe file match those of the cooler '
                 '(see --no-pre) and that the stripes pass the --min-dis / '
                 '--min-stripe-len filters.')

    avg = np.mean(agg, axis=0)
    if horizontal:
        background = np.vstack((avg[:width,:], avg[width+1:,:]))
        score = (avg[width] / background.mean()).mean()
    else:
        background = np.hstack((avg[:,:width], avg[:,width+1:]))
        score = (avg[:,width] / background.mean()).mean()
    maxi = background.mean() * 2

    vmax = maxi if args.vmax is None else args.vmax

    if args.colormap_name == 'traditional':
        plt.imshow(avg, cmap=cmap, vmax=vmax, interpolation='none')
    else:
        plt.imshow(avg, cmap=args.colormap_name, vmax=vmax, interpolation='none')

    plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                    labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    plt.title('Enrichment Score = {0:.3g}'.format(score), fontsize=21)
    if horizontal:
        plt.ylabel('Anchor', fontsize=21)
    else:
        plt.xlabel('Anchor', fontsize=21)
    plt.colorbar()
    plt.savefig(args.output, dpi=args.dpi, bbox_inches='tight')
    plt.close()

# Run the pile-up analysis only when this file is executed as a script
if __name__ == '__main__':
    run()



