#!/usr/bin/env python

# Created on Mon Jul 29 17:35:18 2019
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, stripecaller
import numpy as np
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so that
# figures are rendered straight to file and no display is needed.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Version string of the installed stripecaller package; shown by -v/--version.
currentVersion = stripecaller.__version__

def getargs():
    """
    Build the command-line parser and parse ``sys.argv[1:]``.

    The options the plotting code cannot work without (-O, -p, -C, -S, -E) are
    declared as required, so a missing one is reported by argparse as a usage
    error instead of surfacing later as a TypeError/AttributeError traceback.
    -h/--help and -v/--version still work on their own because argparse handles
    them before the required-option check.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed. When the program is invoked
        without any argument this is ``['-h']``, i.e. the help text is printed.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Visualize stripe calls on heatmap.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required=True, help='Output png file name.')
    parser.add_argument('--dpi', default=200, type=int,
                        help='''The resolution in dots per inch of the output figure.''')

    # Input
    parser.add_argument('-p', '--path', required=True,
                        help = 'Cooler URI.')
    parser.add_argument('-I', '--stripe-file', help='Stripe file in bedpe format.')
    parser.add_argument('-C', '--chrom', required=True,
                        help='Chromosome label of your anticipated region.')
    parser.add_argument('-S', '--start', type=int, required=True,
                        help='Start site (bp) of the region.')
    parser.add_argument('-E', '--end', type=int, required=True,
                        help='End site (bp) of the region.')
    parser.add_argument('--skip-rows', default=0, type=int,
                        help='''Number of leading lines in the stripe file to skip.''')
    parser.add_argument('--correct', action='store_true',
                        help='''Whether or not plot ICE-corrected heatmap.''')
    parser.add_argument('--triangle', action='store_true',
                        help='''If specified, heatmap will be drawn in triangle shape.''')
    parser.add_argument('--vmin', type=float,
                        help='''The minimum value that the colorbar covers.''')
    parser.add_argument('--vmax', type=float,
                        help='''The maximum value that the colorbar covers.''')
    parser.add_argument('--tick-num', type=int, default=4,
                        help='''Number of ticks for genomic coordinates''')
    parser.add_argument('--nolabel', action='store_true',
                        help='''Whether or not add genomic coordinates.''')
    parser.add_argument('--label-size', type=float, default=15, help='''Font size of the coordinate text.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: behave as if help had been requested.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


def properU(pos):
    """
    Format a genomic position (bp) as a compact label such as '1M500K'.

    The position is split into whole megabases and the whole kilobases of the
    remainder; anything below one kilobase is dropped.
    """
    mega, remainder = divmod(int(pos), 1000000)
    kilo = remainder // 1000

    # Below one megabase: kilobases only (this includes '0K').
    if mega == 0:
        return '{0}K'.format(kilo)

    # Both parts are present.
    if mega > 0 and kilo > 0:
        return '{0}M{1}K'.format(mega, kilo)

    # Megabases only.
    return '{0}M'.format(mega)

def caxis_H(ax):
    """
    Apply the tick layout used for heatmaps to *ax*.

    Ticks are placed on the left (y) and bottom (x) sides only, then a common
    label size, tick length and padding are set for both axes.
    """
    tick_sides = ((ax.yaxis, 'left'), (ax.xaxis, 'bottom'))
    for axis, side in tick_sides:
        axis.set_ticks_position(side)

    tick_style = {'labelsize': 12, 'length': 5, 'pad': 7}
    ax.tick_params(axis='both', **tick_style)

def find_chrom_pre(chromlabels):
    """
    Guess the chromosome-label prefix from the first label.

    Returns 'chr' when the first label starts with 'chr', otherwise ''.
    An empty sequence yields '' instead of raising IndexError.
    """
    if chromlabels and chromlabels[0].startswith('chr'):
        return 'chr'

    return ''

def _parse_peakfile(filpath, skip=1):
    """
    Generate a stripe annotation table.

    Parameters
    ----------
    filpath : str
        Path to a whitespace-delimited stripe file. Columns 1, 2, 3, 5 and 6
        (1-based) are read: chromosome label and four integer coordinates.
    skip : int
        Number of leading lines to ignore.

    Returns
    -------
    dict
        Maps a normalised chromosome label to a list of 4-tuples of ints
        (columns 2, 3, 5, 6). The dict is empty when the file holds no records.

    Blank lines are ignored. Labels that normalise to the same key have their
    records merged rather than one silently replacing the other.
    """
    D = {}
    with open(filpath, 'r') as source:
        for i, line in enumerate(source):
            if i < skip:
                continue
            parse = line.split()
            if not parse:
                # Blank (or whitespace-only) line, e.g. a trailing newline.
                continue
            chrom = parse[0]
            info = (int(parse[1]), int(parse[2]), int(parse[4]), int(parse[5]))
            D.setdefault(chrom, []).append(info)

    # consistent chrom labels
    pre = find_chrom_pre(list(D))
    new = {}
    for chrom, records in D.items():
        # NOTE(review): str.lstrip removes any leading characters contained in
        # `pre` ('c', 'h', 'r'), not the literal prefix. This is kept on purpose:
        # run() derives its lookup key with the same lstrip('chr') call, so the
        # two normalisations have to be changed together or lookups would break.
        new.setdefault(chrom.lstrip(pre), []).extend(records)

    return new

def Triangle(M, stripe_boxes, chrom, ticks, labels, cmap, vmin=None, vmax=None, label_size=9, nolabel=False):
    """
    Draw a contact matrix as a tilted ("triangle") heatmap with stripe outlines.

    Parameters
    ----------
    M : 2-D numpy array
        Matrix to plot. Only ``M.shape[0]`` is used to size the mesh, so the
        matrix is presumably square -- TODO confirm against callers.
    stripe_boxes : iterable of (Xi, Yi)
        Polygon corners of each stripe. ``Xi`` values are turned into column
        indices with ``int(i+0.5)`` and ``Yi`` values into (flipped) row
        indices with ``int(n-j-1.5)``.
    chrom : str
        Chromosome label, used for the x-axis title.
    ticks, labels : sequences
        Tick positions and their text for the coordinate bar.
    cmap : matplotlib colormap
    vmin, vmax : float, optional
        Colour limits. Default to ``M.min()`` and the 95th percentile of the
        non-zero entries of ``M`` respectively.
    label_size : float
        Font size of the tick labels; the axis title uses ``label_size+2``.
    nolabel : bool
        If True, the coordinate bar is not drawn.

    Returns
    -------
    matplotlib Figure
    """

    import itertools
    
    fig = plt.figure(figsize = (11,7))
    # h_ax holds the heatmap, c_ax the small colorbar.
    h_ax, c_ax = fig.add_axes([0.15, 0, 0.7, 1.08]), fig.add_axes([0.15, 0.72, 0.02, 0.13])
    
    n = M.shape[0]
    # Linear map applied to every grid corner (col, row):
    #   (col, row) . t  ->  (col - row, (col + row) / 2)
    t = np.array([[1,0.5], [-1,0.5]])
    # Corners are enumerated with row running from n down to 0 (outer loop) and
    # col running from 0 up to n (inner loop), giving (n+1)*(n+1) points.
    A = np.dot(np.array([(i[1],i[0]) for i in itertools.product(range(n,-1,-1),range(0,n+1,1))]),t)

    # Plot the Heatmap ...
    # x is the horizontal position (col + row) / 2, y the vertical position
    # col - row, each reshaped to the (n+1, n+1) corner grid.
    x = A[:,1].reshape(n+1, n+1)
    y = A[:,0].reshape(n+1, n+1)
    if vmax is None:
        vmax = np.percentile(M[M.nonzero()], 95)
    if vmin is None:
        vmin = M.min()
    # M is flipped vertically because the mesh rows run from row n down to row 0.
    sc = h_ax.pcolormesh(x, y, np.flipud(M), vmin=vmin, vmax=vmax, cmap=cmap,
                        edgecolors='none')
    
    # Plot the stripes
    for Xi, Yi in stripe_boxes:
        # Convert the polygon corners to integer column / flipped-row indices ...
        xi = list(map(int, [i+0.5 for i in Xi]))
        yi = list(map(int, [n-j-1.5 for j in Yi]))
        # ... look up the transformed coordinates of those cells, apply fixed
        # offsets, and draw the outline in black without a fill colour.
        lx = x[:-1, :-1][yi, xi] - 0.5
        ly = y[:-1, :-1][yi, xi] + 1
        h_ax.fill(lx, ly, facecolor = 'none', edgecolor='k')
    
    # Clamp the view to the extent of the mesh.
    h_ax.set_xlim(x.min(), x.max())
    h_ax.set_ylim(y.min(), y.max())
    
    # chromosome arrowbar
    if not nolabel:
        # A thin separate axes that only carries the bottom ticks and labels.
        ax = fig.add_axes([0.15, 0.525, 0.7, 0.01])
        ax.tick_params(axis='both', bottom=True, top=False, left=False, right=False,
                       labelbottom=True, labeltop=False, labelleft=False, labelright=False)
        ax.set_xticks(ticks)
        # With 7 or more ticks the labels are rotated and right-aligned.
        if len(ticks) < 7:
            ax.set_xticklabels(labels, fontsize=label_size)
        else:
            ax.set_xticklabels(labels, fontsize=label_size, rotation=15, ha='right')
        # NOTE(review): lstrip('chr') removes any leading 'c'/'h'/'r' characters,
        # not just a literal 'chr' prefix.
        ax.set_xlabel('chr'+chrom.lstrip('chr'), fontsize=label_size+2)
    
    # Hide the bottom part
    # A white rectangle spanning the full x range from the lowest y up to
    # y = 0, i.e. the region where col - row <= 0.
    xmin = A[:,1].min()
    xmax = A[:,1].max()
    ymin = A[:,0].min()
    ymax = 0
    h_ax.fill([xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax], 'w', ec='none')
    h_ax.axis('off')
    
    # Colorbar
    # Only two ticks are shown: 0 and vmax.
    fig.colorbar(sc, cax=c_ax, ticks=[0,vmax], format='%.2g')

    return fig

def _default_color_range(M, vmin, vmax):
    """
    Fill in missing colour limits from the non-zero entries of *M*.

    vmin defaults to the smallest non-zero value and vmax to the 95th
    percentile of the non-zero values. When *M* has no non-zero entry the
    limits fall back to 0 and vmin + 1, so an empty region is drawn as a blank
    map instead of raising inside numpy.
    """
    nonzero = M[np.nonzero(M)]
    if vmin is None:
        vmin = nonzero.min() if nonzero.size else 0
    if vmax is None:
        vmax = np.percentile(nonzero, 95) if nonzero.size else vmin + 1

    return vmin, vmax

def _build_stripe_boxes(stripes, start, res, size):
    """
    Convert stripe records into polygon corner lists in matrix pixel units.

    stripes : iterable of (xs, xe, ys, ye) coordinates in bp, as returned by
        _parse_peakfile.
    start   : bp position of the first bin of the plotted region.
    res     : bin size in bp.
    size    : number of bins along one side of the plotted matrix.

    Returns a list of [Xi, Yi] pairs; Xi is built from the (ys, ye) interval
    and Yi from the (xs, xe) interval, both shifted by -0.5 so that the outline
    runs along pixel edges.
    """
    offset = start//res
    boxes = []
    for xs, xe, ys, ye in stripes:
        xs_i = xs//res - offset
        xe_i = xe//res - offset
        ys_i = ys//res - offset
        ye_i = ye//res - offset

        # First interval is one bin wide: drop the stripe when that bin lies
        # before the region or the second interval starts at/after the last
        # bin; otherwise clip its far end to the matrix.
        if xe_i - xs_i == 1:
            if (xs_i < 0) or (ys_i + 1 >= size):
                continue
            ye_i = min(ye_i, size-1)

        # Second interval is one bin wide (possibly only after the clipping
        # above): drop the stripe when the first interval ends before bin 2 or
        # the second lies beyond the matrix; otherwise clip the near end to 0.
        if ye_i - ys_i == 1:
            if (xe_i < 2) or (ys_i + 1 > size):
                continue
            xs_i = max(0, xs_i)

        Xi = [ys_i-0.5, ye_i-0.5, ye_i-0.5, ys_i-0.5]
        Yi = [xs_i-0.5, xs_i-0.5, xe_i-0.5, xe_i-0.5]
        boxes.append([Xi, Yi])

    return boxes

def run():
    """
    Command-line entry point: plot a region of a cooler matrix with stripe calls.

    Reads the options from getargs(), fetches the requested region from the
    cooler file, optionally overlays the stripes listed in the stripe file, and
    writes the figure (square or triangle layout) to the output path.

    A stripe file without any record on the requested chromosome simply yields
    a heatmap without outlines, and a region without any non-zero contact is
    drawn as a blank map.
    """
    # Parse Arguments
    args, commands = getargs()
    # Nothing to plot for a pure help/version request; this also avoids the
    # heavy imports below.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import cooler
    from matplotlib.colors import LinearSegmentedColormap

    cmap = LinearSegmentedColormap.from_list('interaction',
                                             ['#FFFFFF','#FFDFDF','#FF7575','#FF2626','#F70000'])

    # Load Cooler
    Lib = cooler.Cooler(args.path)
    chrom, start, end = args.chrom, args.start, args.end

    # Extract matrix; the region is snapped down to bin boundaries first
    res = Lib.binsize
    start = start//res * res
    end = end//res * res
    M = Lib.matrix(balance=args.correct, sparse=False).fetch((chrom,start,end))
    M[np.isnan(M)] = 0

    vmin, vmax = _default_color_range(M, args.vmin, args.vmax)

    # Same label normalisation as the keys of the stripe table produced by
    # _parse_peakfile when the stripe file uses 'chr'-prefixed labels.
    raw_chrom = chrom
    chrom = chrom.lstrip('chr')
    Len = M.shape[0]
    stripe_boxes = []
    if args.stripe_file is not None:
        # Read stripe data
        table = _parse_peakfile(args.stripe_file, skip=args.skip_rows)
        stripes = table.get(chrom)
        if stripes is None:
            # Stripe files without a 'chr' prefix keep their labels untouched,
            # so try the label exactly as given; no match means no stripes.
            stripes = table.get(raw_chrom, [])

        # generate a rectangle for each stripe
        stripe_boxes = _build_stripe_boxes(stripes, start, res, Len)

    interval = (end - start) // res
    ticks = list(np.linspace(0, interval, args.tick_num).astype(int))
    pos = list(np.linspace(start, end, args.tick_num).astype(int))
    labels = [properU(p) for p in pos]

    ## draw regular heatmap
    if not args.triangle:
        size = (8, 7.3)
        width = 0.7; Left = 0.1
        # Height chosen so that the heatmap axes come out square on the page.
        HB = 0.1; HH = width * size[0] / size[1]

        fig = plt.figure(figsize=size)
        ax = fig.add_axes([Left, HB, width, HH])
        sc = ax.imshow(M, cmap = cmap, aspect = 'auto', interpolation = 'none',
                       vmax = vmax, vmin = vmin)
        # drawing stripes
        for Xi, Yi in stripe_boxes:
            ax.fill(Xi, Yi, facecolor = 'none', edgecolor='k')

        if args.nolabel:
            ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                           labelbottom=False, labeltop=False, labelleft=False, labelright=False)
        else:
            axis_title = 'chr' + chrom
            ax.set_xticks(ticks)
            ax.set_xticklabels(labels, fontsize=args.label_size)
            ax.set_yticks(ticks)
            ax.set_yticklabels(labels, fontsize=args.label_size)
            ax.set_xlabel(axis_title, fontsize=args.label_size+2)
            ax.set_ylabel(axis_title, fontsize=args.label_size+2)
            caxis_H(ax)

        ## Colorbar
        ax = fig.add_axes([Left+width+0.03, HB, 0.03, HH])
        fig.colorbar(sc, cax=ax)
    else:
        fig = Triangle(M, stripe_boxes, chrom, ticks, labels, cmap, vmin=vmin, vmax=vmax,
                       label_size=args.label_size, nolabel=args.nolabel)

    # Save and release the figure that was just built.
    fig.savefig(args.output, bbox_inches='tight', dpi=args.dpi)
    plt.close(fig)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    run()