#!/usr/bin/env python2

import sys, os, math, re
import pysam as ps
from pyfaidx import Fasta
from interval import interval
import mappy as mp
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.patches import Wedge
from matplotlib.patches import Arc
from matplotlib.patches import FancyArrow
from matplotlib.patches import FancyArrowPatch
from matplotlib.patches import Rectangle
import isocirc.utils as ut
import isocirc.parse_bam as pb

from isocirc.__init__ import isoform_output_header_idx
from isocirc.__init__ import __program__
from isocirc.__init__ import __version__

# Short alias for the isoform output header index imported from the isocirc package.
idx = isoform_output_header_idx

# BED12 column names in file order. Note that this file calls the 12th column
# 'exonStarts'; it holds the block start offsets relative to chromStart.
bed_header_ele = ['chrom', 'chromStart', 'chromEnd', 'name', 'score', 'strand',
              'thickStart', 'thickEnd', 'itemRgb', 'blockCount', 'blockSizes', 'exonStarts']
# Column name -> 0-based column index, used to address fields of a split BED12 line.
bed_header = {bed_header_ele[i]: i for i in range(len(bed_header_ele))}

# Plot appearance shared by all figures; font_weight and font_color are read
# as module globals by the legend drawing code.
plt.style.use('ggplot')
# plt.rcParams["font.family"] = "Times New Roman"
# font_family = 'monospace'  # , 'serif', 'sans-serif', 'cursive', 'fantasy', 'monospace'
font_weight = 'bold'  # 'normal', 'bold', 'large', 'black'
font_color = 'black'


def get_exon_interval_from_bed12(bed12_ele=()):
    """Build the exon interval set described by one BED12 record.

    bed12_ele: sequence of BED12 fields (strings), addressed through
               `bed_header`; chromStart and the block starts are 0-based.
    return: `interval` holding one closed, 1-based [start, end] component
            per exon block (overlapping blocks are merged by the union).
    """
    chrom_start = int(bed12_ele[bed_header['chromStart']])
    # Ignore every empty field, so that trailing or doubled commas
    # (e.g. "82,96,99,") do not end up in int('').
    block_starts = [int(s) for s in bed12_ele[bed_header['exonStarts']].split(',') if s]
    block_sizes = [int(s) for s in bed12_ele[bed_header['blockSizes']].split(',') if s]

    itv = interval()
    for s, l in zip(block_starts, block_sizes):
        # 0-based, half-open block -> 1-based, closed interval
        itv |= interval([chrom_start + s + 1, chrom_start + s + l])
    return itv

# return:
# {circRNA_ID:[(chr, strand, start, end), ()...]}
def get_circRNA_from_gtf(in_gtf):
    """Group the exon features of a GTF file by transcript.

    in_gtf: GTF file handed to ut.restore_gff_db.
    return: {transcript_id: [(chrom, strand, start, end), ...]}, exons listed
            in the order produced by features_of_type(..., order_by='start').
    """
    exons_by_transcript = {}
    db = ut.restore_gff_db(in_gtf)
    for feat in db.features_of_type('exon', order_by='start'):
        transcript_id = feat.attributes['transcript_id'][0]
        record = (feat.chrom, feat.strand, feat.start, feat.end)
        # create the list on first sight of the transcript, then append
        exons_by_transcript.setdefault(transcript_id, []).append(record)
    return exons_by_transcript


# return:
# {circRNA_ID:[(chr, strand, start, end), ()...]}
def get_circRNA_from_bed12(in_bed):
    """Read the exon blocks of every circRNA in a BED12 file.

    Example record (chromStart and block starts are 0-based):
    chr1    9991948 9994918 hsa_circ_0000014        1000    -       9991948 9994918 0,0,255 3       82,96,99        0,912,2871

    in_bed: path of a tab-separated BED12 file; lines starting with '#' and
            blank lines are skipped.
    return: {circRNA_ID: [(chrom, strand, start, end), ...]} with 1-based,
            closed exon coordinates. A repeated name replaces the earlier entry.
    """
    circRNA = {}
    with open(in_bed, 'r') as bed:
        for line in bed:
            # A blank line would otherwise fail with IndexError on the column lookup.
            if line.startswith('#') or not line.strip():
                continue
            ele = line.rstrip('\r\n').split('\t')
            chrom = ele[bed_header['chrom']]
            strand = ele[bed_header['strand']]
            start = int(ele[bed_header['chromStart']])
            # Skip empty fields so a trailing comma ("82,96,99,") is accepted,
            # in line with get_exon_interval_from_bed12.
            exon_start = [int(i) for i in ele[bed_header['exonStarts']].split(',') if i.strip()]
            exon_len = [int(i) for i in ele[bed_header['blockSizes']].split(',') if i.strip()]
            circRNA[ele[bed_header['name']]] = [
                (chrom, strand, start + s + 1, start + s + l)
                for s, l in zip(exon_start, exon_len)
            ]
    return circRNA


# r.q/r_st: 0-base
# r.q/r_en: 1-base
# MISMATCH: read_pos(first), ref_pos(first), len, read_base, ref_base
# INSERTION: ins_read_pos(first), ins_ref_pos(left),  len, ins_base
# DELETION: del_read_pos(left), del_ref_pos(first), len, del_base
def get_align_detail(r, ref_seq, read_seq):
    """Summarise one alignment hit as the list consumed by the plotting code.

    r: alignment hit providing strand, q_st, q_en, r_st, r_en, cigar and MD.
    ref_seq: not used by this function.
    read_seq: sequence of the aligned read.
    return: [is_reverse, read_len, ref_start, ref_aligned_len, left_clip,
             right_clip, mis_err, ins_err, dele_err]; the three error lists
             are whatever pb.get_error_from_MD returns.
    """
    read_len = len(read_seq)
    is_reverse = r.strand == -1
    head_unaligned, tail_unaligned = r.q_st, read_len - r.q_en

    if is_reverse:
        # reverse-strand hit: work on the reverse complement, so the
        # unaligned read ends swap sides
        oriented_seq = mp.revcomp(read_seq)
        left_clip, right_clip = tail_unaligned, head_unaligned
    else:
        oriented_seq = read_seq
        left_clip, right_clip = head_unaligned, tail_unaligned

    # (op, length) cigar with the unaligned ends added as soft clips
    cigar = []
    if left_clip != 0:
        cigar.append([pb.BAM_CSOFT_CLIP, left_clip])
    cigar.extend(op[::-1] for op in r.cigar)
    if right_clip != 0:
        cigar.append([pb.BAM_CSOFT_CLIP, right_clip])

    mis_err, ins_err, dele_err = pb.get_error_from_MD(cigar, r.MD, oriented_seq, r.r_st)
    return [is_reverse, read_len, r.r_st, r.r_en - r.r_st, left_clip, right_clip, mis_err, ins_err, dele_err]


def generate_wedge(direction, init_degree, start, end, init_x, init_y, init_r, width, dis, color):
    """Cut the angular span [start, end) into half-turn wedges on growing rings.

    The span is split at every multiple of 180 degrees. Each piece becomes
    one Wedge; the ring index grows by one per piece, which enlarges the
    radius by (width + dis) / 2 and shifts the centre of odd rings
    horizontally (to the right when init_degree is in [0, 180), to the left
    otherwise).

    direction: factor applied to both wedge angles.
    return: list of Wedge patches; empty when end <= start.
    """
    starts_in_upper_half = 0 <= init_degree < 180
    if starts_in_upper_half:
        ring_idx = int(start) // 180
    else:
        ring_idx = (int(start) - 180) // 180

    degrees_left = end - start
    cursor = start % 360.0
    wedges = []

    while degrees_left > 0:
        seg_start = cursor
        # next half-turn boundary, and where the cursor continues after it
        if 0 <= cursor < 180:
            boundary, cursor_after = 180, 180
        else:
            boundary, cursor_after = 360, 0

        if cursor + degrees_left >= boundary:
            degrees_left -= (boundary - cursor)
            seg_end = boundary
            cursor = cursor_after
        else:
            seg_end = cursor + degrees_left
            degrees_left = 0

        x_shift = ring_idx % 2 * (width + dis) / 2
        ring_x = init_x + x_shift if starts_in_upper_half else init_x - x_shift
        ring_r = init_r + ring_idx * (width + dis) / 2
        ring_idx += 1

        wedges.append(Wedge((ring_x, init_y), ring_r, direction * seg_start, direction * seg_end, width=width, color=color))
    return wedges


def generate_donut_ring(ax, direction, strand, start_degree, length, x, y, radius, width, edge_width, face_color, edge_color, bs_color):
    """Draw the circRNA as a donut with one wedge per exon.

    The back-splice line is added to `ax` directly; the exon wedges and the
    inner strand arc with its arrow head are returned as patches.

    ax: axes receiving the back-splice line.
    direction: factor applied to the wedge angles; a negative value means
               the ring is laid out clockwise.
    strand: '+' or '-', selects the orientation of the inner arrow.
    start_degree: angle of the first exon boundary.
    length: exon lengths; each exon gets a share of 360 degrees
            proportional to its length.
    x, y, radius, width: centre, outer radius and thickness of the ring.
    edge_width, face_color, edge_color, bs_color: line width and colours
            (bs_color is used for the back-splice line).
    return: list of patches (Wedge per exon, then Arc, then FancyArrow).
    """
    patches = []
    PI = math.pi
    # The arrow orientation depends on the drawing direction. It used to be
    # read from a module global `clock_wise` that only exists when this file
    # is executed as a script (NameError on import); nano_circ_plot encodes
    # the same flag as direction == -1, so derive it from there.
    clock_wise = direction < 0

    # cumulative fraction of the circle at every exon boundary
    tot_len = sum(length) + 0.0
    len_frac = [0.0]
    for l in length:
        len_frac.append(len_frac[-1] + l / tot_len)

    # 1. donut ring
    for i in range(len(len_frac) - 1):
        start = start_degree + len_frac[i] * 360.0
        end = start_degree + len_frac[i + 1] * 360.0
        patches.append(Wedge((x, y), radius, direction * start, direction * end, width=width, facecolor=face_color, edgecolor=edge_color, linewidth=edge_width))

    # 2. back-splicing line across the ring on the positive x axis
    ax.add_line(mlines.Line2D([x + radius - width + width / 20, x + radius - width / 20], [y, y], linewidth=edge_width,
                      color=bs_color))

    # 3. arc and arrow
    arc_radius = radius / 3
    arc_width = 60 * radius #6.0
    arrow_length = 0.3 * width
    arrow_angle_rad = 30 * PI / 180

    if (strand == '+' and clock_wise) or (strand == '-' and not clock_wise):
        # arrow head at the start of the arc
        arc_start_angle = 30
        arc_end_angle = 360 - 15
        arc_start_angle_rad = arc_start_angle * PI / 180  # degrees to radians

        arrow = FancyArrow(x + arc_radius * math.cos(arc_start_angle_rad),
                           y + arc_radius * math.sin(arc_start_angle_rad),
                           arrow_length * math.sin(arrow_angle_rad), -arrow_length * math.cos(arrow_angle_rad),
                           head_width=arrow_length, head_length=arrow_length, length_includes_head=True,
                           color=face_color)
    else:
        # arrow head at the end of the arc
        arc_start_angle = 15
        arc_end_angle = 360 - arc_start_angle * 2
        arc_end_angle_rad = arc_end_angle * PI / 180  # degrees to radians

        arrow = FancyArrow(x + arc_radius * math.cos(arc_end_angle_rad), y + arc_radius * math.sin(arc_end_angle_rad),
                           arrow_length * math.sin(arrow_angle_rad), arrow_length * math.cos(arrow_angle_rad),
                           head_width=arrow_length, head_length=arrow_length, length_includes_head=True,
                           color=face_color)
    arc = Arc((x, y), arc_radius * 2, arc_radius * 2,  # ellipse width and height
              theta1=arc_start_angle, theta2=arc_end_angle, linewidth=arc_width, color=face_color)
    patches.append(arc)
    patches.append(arrow)

    return patches


def generate_alignment_wedge(direction, start_degree, init_degree, init_pos, read_tot_degree, align_detail, exon_tot_len, mis_color, ins_color, del_color, clip_color, read_x, read_y, read_r, read_w, read_dis):
    """Build the wedges that mark clipping and alignment errors on the read.

    align_detail: 9-element list as produced by get_align_detail.
    init_pos: reference position that corresponds to init_degree.
    exon_tot_len: reference length that corresponds to 360 degrees.
    mis_color: dict, read base -> colour of a mismatch wedge.
    return: patches in the order left clip, mismatches, insertions,
            deletions, right clip.
    """
    (_is_reverse, _read_len, _ref_start, _ref_len,
     left_clip, right_clip, mis_err, ins_err, del_err) = align_detail

    def wedges_for(from_degree, to_degree, color):
        # all wedges share the geometry of the read track
        return generate_wedge(direction, init_degree, from_degree, to_degree, read_x, read_y, read_r, read_w, read_dis, color)

    def degree_of(ref_pos):
        return start_degree + init_degree + 360.0 * (ref_pos - init_pos) / exon_tot_len

    # 3.1 left clipping, drawn from the very start of the read
    left_clip_degree = 360.0 * left_clip / exon_tot_len
    result = list(wedges_for(init_degree, init_degree + left_clip_degree, clip_color))

    # angular width of a single base, used for mismatches and insertions
    one_base_degree = 360.0 * 1 / exon_tot_len

    # 3.2 mismatch: [read_pos(first), ref_pos(first), len, read_base, ref_base]
    for _read_pos, ref_pos, mis_len, read_bases in (m[:4] for m in mis_err):
        for offset in range(mis_len):
            base_start = degree_of(ref_pos + offset)
            result.extend(wedges_for(base_start, base_start + one_base_degree, mis_color[read_bases[offset]]))

    # 3.3 insertion: [ins_read_pos(first), ins_ref_pos(left), len, ins_base]
    # an insertion is always drawn one base wide
    for ins in ins_err:
        ins_start = degree_of(ins[1])
        result.extend(wedges_for(ins_start, ins_start + one_base_degree, ins_color))

    # 3.4 deletion: [del_read_pos(left), del_ref_pos(first), len, del_base]
    for dele in del_err:
        del_width = 360.0 * dele[2] / exon_tot_len
        del_start = degree_of(dele[1])
        result.extend(wedges_for(del_start, del_start + del_width, del_color))

    # 3.5 right clipping, ending at the very end of the read
    right_clip_degree = 360.0 * right_clip / exon_tot_len
    result.extend(wedges_for(init_degree + read_tot_degree - right_clip_degree, init_degree + read_tot_degree, clip_color))
    return result


# TODO: add match/mismatch/indel/unmapped statistics
def add_legend(iso_mode, ax, xlim, ylim, circRNA, circRNA_name, read_tot_len, all_itvs, isoform_itvs, gene_itvs, exon_face_color, exon_r, exon_w, read_color, read_w, bs_width, bs_color, mis_color, ins_color, del_color, clip_color):
    """Draw the legend panel: colour key, linear gene/isoform tracks and text.

    Text and lines are added to `ax` directly; rectangles and arrows are
    returned so that the caller can add them as patches.

    iso_mode: 'ISO' draws a gene track plus a separate isoform track; any
              other value (the script passes 'SJ') draws the splice and
              back-splice junctions on top of the gene track.
    xlim, ylim: the legend is anchored at (0.6 * xlim, 0.85 * ylim).
    circRNA: list of (chrom, strand, start, end) exon tuples of the isoform.
    circRNA_name, read_tot_len: values printed in the information text.
    all_itvs: exon blocks that define the linear layout; every block of
              gene_itvs and isoform_itvs must lie inside one of them,
              otherwise ut.fatal_format_time is called.
    gene_itvs, isoform_itvs: (start, end) blocks of the gene and the isoform.
    exon_r, exon_w, read_w, bs_width: sizes used to scale the legend items.
    remaining arguments: colours of the legend items.
    return: list of matplotlib patches.
    """
    patches = []
    start_x, start_y, dis = 0.6 * xlim, 0.85 * ylim, 0.05

    # vertical slot of every legend row, in units of `dis` below start_y
    circ_exon_bs_dis_cnt, match_unmapped_dis_cnt, mis_dis_cnt, indel_dis_cnt = 0, 1, 2, 3
    if iso_mode == 'ISO':
        gene_title_dis_cnt, gene_struct_dis_cnt, iso_title_dis_cnt, iso_struct_dis_cnt = 4, 4.5, 5.5, 6
    else:
        gene_struct_dis_cnt = 4.5
    read_info_dis_cnt, circ_info_dis_cnt = 7, 8

    # 1. color legend
    exon_rect_width = exon_w / 4
    exon_rect_len = 3 * exon_rect_width
    read_rect_width = read_w
    base_rect_len = 0.7 * read_rect_width
    font_size = 14 + 20 * exon_w

    exon_rect = Rectangle((start_x, start_y), exon_rect_len, exon_rect_width, color=exon_face_color)
    patches.append(exon_rect)
    ax.text(start_x + exon_rect_len + dis / 4, start_y, 'CircRNA exon', fontsize=font_size, color=font_color, weight=font_weight)

    ax.add_line(
        mlines.Line2D([start_x + 3 * dis, start_x + 3 * dis + exon_w],
                      [start_y - circ_exon_bs_dis_cnt * dis + exon_rect_width / 2, start_y - circ_exon_bs_dis_cnt * dis + exon_rect_width / 2],
                      linewidth=bs_width, color=bs_color))
    ax.text(start_x + 3 * dis + exon_rect_len + dis / 2, start_y - circ_exon_bs_dis_cnt * dis, 'Backsplicing', fontsize=font_size, color=font_color, weight=font_weight)

    read_rect = Rectangle((start_x, start_y - match_unmapped_dis_cnt * dis), exon_rect_len, read_rect_width, color=read_color)
    patches.append(read_rect)
    ax.text(start_x + exon_rect_len + dis / 4, start_y - match_unmapped_dis_cnt * dis, 'Match', fontsize=font_size, color=font_color, weight=font_weight)

    unmaped_read_rect = Rectangle((start_x + 2 * dis, start_y - match_unmapped_dis_cnt * dis), exon_rect_len, read_rect_width, color=clip_color)
    patches.append(unmaped_read_rect)
    ax.text(start_x + 2 * dis + exon_rect_len + dis / 4, start_y - match_unmapped_dis_cnt * dis, 'Unmapped', fontsize=font_size, color=font_color, weight=font_weight)

    for i, b in enumerate('ACGT'):
        rect = Rectangle((start_x + i * dis, start_y - mis_dis_cnt * dis), base_rect_len, read_rect_width, color=mis_color[b])
        patches.append(rect)
        ax.text(start_x + i * dis + base_rect_len + dis / 4, start_y - mis_dis_cnt * dis, b, fontsize=font_size, color=font_color, weight=font_weight)

    ins_rect = Rectangle((start_x, start_y - indel_dis_cnt * dis), base_rect_len, read_rect_width, color=ins_color)
    patches.append(ins_rect)
    ax.text(start_x + base_rect_len + dis / 4, start_y - indel_dis_cnt * dis, 'Insertion', fontsize=font_size, color=font_color, weight=font_weight)

    del_rect = Rectangle((start_x + 2 * dis, start_y - indel_dis_cnt * dis), base_rect_len, read_rect_width, facecolor=del_color, edgecolor=read_color)
    patches.append(del_rect)
    ax.text(start_x + 2 * dis + base_rect_len + dis / 4, start_y - indel_dis_cnt * dis, 'Deletion', fontsize=font_size, color=font_color, weight=font_weight)

    # 2. add gene structure and isoform structure based on gene_itvs and isoform_itvs
    font_size = 16 + 20 * exon_w
    ## 2.1 collect coor_to_linear mapping dict
    # The linear track is as long as the inner circumference of the donut;
    # introns are compressed by `intron_scale` relative to exons.
    all_coor_to_linear = dict()
    tot_len = math.pi * 2 * (exon_r - exon_w)
    intron_scale = 20.0
    all_blocks = list(all_itvs)
    exon_len = [itv[1] - itv[0] + 1 for itv in all_blocks]
    # One gap per pair of neighbouring layout blocks, so that gap i always
    # sits between block i and block i + 1. (The gaps used to be measured on
    # gene_itvs with non-positive ones dropped, which misplaced blocks
    # whenever the isoform had exons that are absent from the annotation.)
    intron_len = [max(cur[0] - prev[1] - 1, 0) for prev, cur in zip(all_blocks, all_blocks[1:])]
    exon_tot_len, intron_tot_len = sum(exon_len), sum(intron_len)
    # exon_len_sum / scale + intron_len_sum / (intron_scale * scale) = tot_len
    scale = (intron_scale * exon_tot_len + intron_tot_len) / tot_len / intron_scale
    exon_linear_len = [l / scale for l in exon_len]
    intron_linear_len = [l / (intron_scale * scale) for l in intron_len]
    # (start, end) of a layout block -> (linear start coordinate, linear length)
    for i, (l, itv) in enumerate(zip(exon_linear_len, all_blocks)):
        all_coor_to_linear[(itv[0], itv[1])] = (start_x + sum(exon_linear_len[:i]) + sum(intron_linear_len[:i]), l)

    def to_linear(s, e):
        # Map the genomic block [s, e] onto the linear track through the
        # layout block that contains it; None when no layout block does.
        for (_s, _e), (linear_start_coor, linear_len) in all_coor_to_linear.items():
            if _s <= s and _e >= e:
                linear_start = (s - _s) / (_e - _s + 1.0) * linear_len + linear_start_coor
                linear_rect_len = (e - s + 1) / (_e - _s + 1.0) * linear_len
                return linear_start, linear_rect_len
        return None

    # strand arrow
    gene_line_width, end_intron_linear_len = 2, exon_rect_width
    strand = circRNA[0][1]
    arrow_line_width, arrow_dis, arrow_scale, arrow_size = 1, exon_rect_width, 20, exon_rect_width / 2
    style = '->' if strand == '+' else '<-'
    arrow_start_x = start_x - end_intron_linear_len
    arrow_len = tot_len + 2 * end_intron_linear_len

    ## 2.2 gene_itvs
    if iso_mode == 'ISO':
        ax.text(start_x + 1.5 * dis, start_y - gene_title_dis_cnt * dis, u'Gene structure', fontsize=font_size, weight=font_weight)
    ## exon block
    for (s, e) in gene_itvs:
        located = to_linear(s, e)
        if located is None:
            ut.fatal_format_time('gene_itv', 'Inconsistent isoform and gene structure.')
            continue
        start, rect_len = located
        exon_linear_rect = Rectangle((start, start_y - gene_struct_dis_cnt * dis), rect_len, exon_rect_width, color=exon_face_color)
        patches.append(exon_linear_rect)
    ## intron line
    ax.add_line(mlines.Line2D([arrow_start_x, arrow_start_x + arrow_len], [start_y - gene_struct_dis_cnt * dis + exon_rect_width / 2, start_y - gene_struct_dis_cnt * dis + exon_rect_width / 2], linewidth=gene_line_width, color=exon_face_color))
    ## gene strand arrow
    arrow_y = start_y - gene_struct_dis_cnt * dis + exon_rect_width / 2
    for j in range(int(arrow_len / arrow_dis)):
        start = arrow_start_x + (j + 1) * arrow_dis - arrow_size
        end = arrow_start_x + (j + 1) * arrow_dis
        arrow = FancyArrowPatch((start, arrow_y), (end, arrow_y), arrowstyle=style, linewidth=arrow_line_width, mutation_scale=arrow_scale, color=exon_face_color)
        patches.append(arrow)

    ## 2.3 isoform_itvs
    if iso_mode == 'ISO':
        ax.text(start_x + 1.5 * dis, start_y - iso_title_dis_cnt * dis, u'circRNA isoform structure', fontsize=font_size, weight=font_weight)
    # min_coor / max_coor: centres of the outermost isoform blocks
    max_coor, min_coor = -1, sys.maxsize
    # flat list [start0, end0, start1, end1, ...] of the isoform blocks
    linear_coor_list = []
    for (s, e) in isoform_itvs:
        located = to_linear(s, e)
        if located is None:
            ut.fatal_format_time('isoform_itv', 'Inconsistent isoform and gene structure.')
            continue
        start, rect_len = located
        linear_coor_list.extend([start, start + rect_len])
        if iso_mode == 'ISO':
            exon_linear_rect = Rectangle((start, start_y - iso_struct_dis_cnt * dis), rect_len, exon_rect_width, color=exon_face_color)
            patches.append(exon_linear_rect)
        if start + rect_len / 2 < min_coor: min_coor = start + rect_len / 2
        if start + rect_len / 2 > max_coor: max_coor = start + rect_len / 2

    # splice junction on gene structure
    if iso_mode == 'SJ':
        splice_line_width = 3
        margin_diff = 0.0005
        # forward splice junctions: a "roof" from the end of one block to the start of the next
        for i in range(len(linear_coor_list) // 2 - 1):
            start, end = linear_coor_list[i * 2 + 1], linear_coor_list[i * 2 + 2]
            mid = (start + end) / 2
            ax.add_line(mlines.Line2D([start - margin_diff, mid - margin_diff], [start_y - gene_struct_dis_cnt * dis + exon_rect_width - margin_diff, start_y - (gene_struct_dis_cnt - 0.5) * dis + exon_rect_width], linewidth=splice_line_width, color=exon_face_color))
            ax.add_line(mlines.Line2D([mid + margin_diff, end + margin_diff], [start_y - (gene_struct_dis_cnt - 0.5) * dis + exon_rect_width, start_y - gene_struct_dis_cnt * dis + exon_rect_width - margin_diff], linewidth=splice_line_width, color=exon_face_color))
        # back-splice junction: a "V" below the track joining the first start and the last end
        start, end = linear_coor_list[0], linear_coor_list[-1]
        mid = (start + end) / 2
        ax.add_line(mlines.Line2D([start + 2 * margin_diff, mid - 2 * margin_diff], [start_y - gene_struct_dis_cnt * dis + margin_diff, start_y - (gene_struct_dis_cnt + 0.5) * dis], linewidth=bs_width, color=bs_color))
        ax.add_line(mlines.Line2D([mid + 2 * margin_diff, end - 2 * margin_diff], [start_y - (gene_struct_dis_cnt + 0.5) * dis, start_y - gene_struct_dis_cnt * dis + margin_diff], linewidth=bs_width, color=bs_color))
    else:  # 'ISO'
        # isoform line
        ax.add_line(mlines.Line2D([min_coor, max_coor], [start_y - iso_struct_dis_cnt * dis + exon_rect_width / 2, start_y - iso_struct_dis_cnt * dis + exon_rect_width / 2], linewidth=gene_line_width, color=exon_face_color))
        # strand arrow, restricted to the span covered by the isoform line
        arrow_y = start_y - iso_struct_dis_cnt * dis + exon_rect_width / 2
        for j in range(int(arrow_len / arrow_dis)):
            start = arrow_start_x + (j + 1) * arrow_dis - arrow_size
            end = arrow_start_x + (j + 1) * arrow_dis
            if start > min_coor and end < max_coor:
                arrow = FancyArrowPatch((start, arrow_y), (end, arrow_y), arrowstyle=style, linewidth=arrow_line_width, mutation_scale=arrow_scale, color=exon_face_color)
                patches.append(arrow)
            elif start >= max_coor: break

    # 3. read and circRNA information
    font_size = 16 + 20 * exon_w
    ax.text(start_x, start_y - read_info_dis_cnt * dis, u'\u2022 CircRNA long read: {:,} bp'.format(read_tot_len), fontsize=font_size, color=font_color, weight=font_weight)
    # NOTE(review): the length printed next to the circRNA name is the total
    # length of the layout blocks (all_itvs), not the sum of the circRNA's
    # own exons -- confirm that this is intended.
    ax.text(start_x, start_y - circ_info_dis_cnt * dis, u'\u2022 {}: {:,} bp\n\u2022 {} {} : {:,} - {:,}'.format(circRNA_name, int(exon_tot_len), circRNA[0][1], circRNA[0][0], circRNA[0][2], circRNA[-1][3]), fontsize=font_size, color=font_color, weight=font_weight)
    return patches


def nano_circ_plot(align_detail, read_name, circRNA, circRNA_name, gene_id, all_itvs, isoform_itvs, gene_itvs, out_dir, iso_mode, clock_wise):
    """Plot one read wrapped around its circRNA and save the figure as PNG.

    The figure is written to
    '<out_dir>/<gene_id>_<circRNA_name>_<read_name>_<iso_mode>.png'
    ('/' in the read name is replaced by '_') and is closed afterwards.

    align_detail: 9-element list as produced by get_align_detail.
    circRNA: list of (chrom, strand, start, end) exon tuples of the isoform.
    all_itvs, isoform_itvs, gene_itvs: exon blocks forwarded to add_legend.
    iso_mode: legend mode forwarded to add_legend.
    clock_wise: draw the ring clockwise (direction -1) or not (direction 1).
    """
    read_name = re.sub('/', '_', read_name)
    fig_name = '{}/{}_{}_{}_{}.png'.format(out_dir, gene_id, circRNA_name, read_name, iso_mode)
    ut.err_format_time('IsoCircPlot', 'Plotting ' + fig_name)
    mpl.style.use('ggplot')
    patches = []
    start_degree = 0.0
    direction = -1 if clock_wise else 1
    # rectangle
    fig = plt.figure(figsize=(20, 10))
    # This function runs once per read: without an explicit close every
    # figure would stay registered in pyplot for the rest of the run.
    try:
        x_width, xlim, y_width, ylim = 1.0, 1.0, 2.0, 0.5
        ax = fig.add_axes([0.0, 0.0, x_width, y_width], frameon=True, aspect=1)
        default_x, default_y, default_r, default_w, default_edge_width = 0.3, 0.25, 0.1, 0.05, 4

        # 0. init plot
        exon_len = [l[3] - l[2] + 1 for l in circRNA]
        exon_tot_len = sum(exon_len) + 0.0
        [read_is_reverse, read_tot_len, ref_align_start, ref_align_len, left_clip, right_clip, mis_err, ins_err, del_err] = align_detail
        read_cov_tot_len = ref_align_len + left_clip + right_clip
        # number of (started) turns the read makes around the circRNA
        cov = int((read_cov_tot_len + 0.0) / exon_tot_len) + 1

        # shrink the donut when the stacked read rings would not fit
        x, y, r, w, edge_width = default_x, default_y, default_r, default_w, default_edge_width
        if r * (1 + 0.25 + cov * 13 / 120.0) > 0.25:
            r = 0.25 / (1 + 0.25 + cov * 13 / 120.0)
            w = r / 2

        # 1. donut ring of circRNA exons
        exon_x, exon_y, exon_r, exon_w, exon_edge_width = x, y, r, w, edge_width
        exon_edge_color, exon_face_color, exon_bs_color = 'white', 'black', 'red'
        strand = circRNA[0][1]
        patches.extend(generate_donut_ring(ax, direction, strand, start_degree, exon_len, exon_x, exon_y, exon_r, exon_w, exon_edge_width, exon_face_color, exon_edge_color, exon_bs_color))

        # 2. wedge of read
        read_x, read_y, read_r, read_w, read_dis, read_color = x, y, r + exon_w / 2, exon_w / 6, exon_w / 20, 'C1'  # '#1874CD' #'dodgerblue3', #'#009ACD' #'deepskyblue3'

        # reference position of the first read base, clipped part included
        init_pos = ref_align_start - left_clip
        ref_start = init_pos % exon_tot_len

        read_tot_degree = 360.0 * read_cov_tot_len / exon_tot_len
        init_degree = (start_degree + 360.0 * ref_start / exon_tot_len) % 360.0

        patches.extend(generate_wedge(direction, init_degree, init_degree, init_degree + read_tot_degree, read_x, read_y, read_r, read_w, read_dis, read_color))
        # 3. wedge of alignment errors
        ins_color, del_color, clip_color = 'black', 'white', 'C3'
        mis_color = {'A': 'C2', 'a': 'C2', 'C': 'C0', 'c': 'C0', 'G': 'C5', 'g': 'C5', 'T': 'C4', 't': 'C4'}
        patches.extend(generate_alignment_wedge(direction, start_degree, init_degree, init_pos, read_tot_degree, align_detail, exon_tot_len, mis_color, ins_color, del_color, clip_color, read_x, read_y, read_r, read_w, read_dis))
        # 4. legend, always drawn with the default (unshrunk) sizes
        exon_r, exon_w, read_w = default_r, default_w, default_w / 6
        patches.extend(add_legend(iso_mode, ax, xlim, ylim, circRNA, circRNA_name, read_tot_len, all_itvs, isoform_itvs, gene_itvs, exon_face_color, exon_r, exon_w, read_color, read_w, edge_width, exon_bs_color, mis_color, ins_color, del_color, clip_color))
        for i in patches:
            ax.add_patch(i)
        ax.axis('off')
        fig.savefig(fig_name, format='png', dpi=500)
    finally:
        plt.close(fig)
    ut.err_format_time('IsoCircPlot', 'Plotting done!')
    return


def plot_main(isocirc_bed, align_details, circRNA_fa, all_struct, isoform_struct, gene_struct, out_dir, mode, clock_wise):
    """Create one plot per aligned read.

    isocirc_bed: BED12 file with the exon blocks of the circRNA isoforms.
    align_details: {read_name: (read_seq, ref_name, ref_seq, gene_id, hit)}.
    circRNA_fa: not used by this function.
    all_struct, gene_struct: exon blocks looked up by gene_id.
    isoform_struct: exon blocks looked up by ref_name.
    out_dir, mode, clock_wise: forwarded to nano_circ_plot.
    """
    # exon blocks of every circRNA, keyed by the BED name column
    exons_by_circ = get_circRNA_from_bed12(isocirc_bed)

    for read_name, record in align_details.items():
        read_seq, ref_name, ref_seq, gene_id, hit = record
        # clipping and error positions of this read
        detail = get_align_detail(hit, ref_seq, read_seq)
        # donut of the circRNA with the read wrapped around it
        nano_circ_plot(detail, read_name, exons_by_circ[ref_name], ref_name, gene_id,
                       all_struct[gene_id], isoform_struct[ref_name], gene_struct[gene_id],
                       out_dir, mode, clock_wise)
    return

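# 1. get the exon structure of each isoform
# 2. collect the annotated gene exons that overlap the isoforms, and merge
#    isoform exons and gene exons per gene
# ==> isoform_struct[isoform_id] = interval()
# ==> gene_struct[gene_id] = interval()
# ==> all_struct[gene_id] = interval()  (isoform exons + gene exons)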
def align_to_circRNA_fa(ref_fa, anno_gtf, read_fa, read_len_fn, isocirc_bed, isocirc_read_list, out_dir, circRNA_ref_fa):
    """Align every listed read to tandem copies of its circRNA sequence.

    Temporary files are created in out_dir; circRNA_ref_fa and
    '<out_dir>/gene.gtf' are left behind, the other ones are removed.
    External commands (fxtools, bedtools, grep, rm) are run through
    ut.exec_cmd.

    ref_fa: genome FASTA, input of 'bedtools getfasta'.
    anno_gtf: annotation; lines matching a gene id are scanned for exons.
    read_fa: FASTA with the long reads.
    read_len_fn: text file, read name in column 1 and read length in column 2.
    isocirc_bed: BED12 file of the circRNA isoforms.
    isocirc_read_list: text file with the columns isoform name,
                       comma-separated read names, gene id.
    circRNA_ref_fa: output path of the circRNA reference sequences.
    return: (align_details, all_struct, isoform_struct, gene_struct)
        align_details[read_name] = (read_seq, isoform_name,
                                    tandem reference sequence, gene_id, primary hit)
        isoform_struct[isoform_name]: exon blocks of the isoform
        gene_struct[gene_id]: annotated exons inside the span of the isoforms
        all_struct[gene_id]: union of the two above
    """
    circRNA_read_fa = out_dir + '/circRNA_read.fa'
    gene_gtf = out_dir + '/gene.gtf'
    isocirc_list, read_list = out_dir + '/isocirc.list', out_dir + '/read.list'

    # remove leftovers of a previous run
    for old_fn in (gene_gtf, circRNA_ref_fa, circRNA_read_fa):
        if os.path.exists(old_fn): ut.exec_cmd(sys.stderr, 'Old files', 'rm {}'.format(old_fn))

    # 1. generate circRNA.fa
    copy_n, circRNA_blocks, read_len = dict(), dict(), dict()
    iso_id_list, read_iso_id_list = dict(), dict()
    read_gene_list = dict()
    all_struct, isoform_struct, gene_struct = dict(), dict(), dict()
    with open(isocirc_read_list, 'r') as name_fp, open(isocirc_bed) as isocirc_fp, open(read_len_fn) as read_len_fp, open(isocirc_list, 'w') as isocirc_list_fp, open(read_list, 'w') as read_list_fp:
        # isoform -> gene, read -> isoform / gene
        for line in name_fp:
            if line.startswith('#') or not line.strip(): continue
            ele = line.split()
            isocirc_name, read_names, gene_id = ele[0], ele[1], ele[2]
            all_struct[gene_id] = interval()
            gene_struct[gene_id] = interval()
            iso_id_list[isocirc_name] = gene_id
            for read_name in read_names.split(','):
                read_list_fp.write('{}\n'.format(read_name))
                read_iso_id_list[read_name] = isocirc_name
                read_gene_list[read_name] = gene_id

        # BED records of the listed isoforms
        for line in isocirc_fp:
            if line.startswith('#') or not line.strip(): continue
            ele = line.split()
            if ele[3] in iso_id_list:
                # A real list (not a one-shot map iterator) so that the sizes
                # can be summed more than once; empty fields caused by a
                # trailing comma are ignored.
                circRNA_blocks[ele[3]] = [int(size) for size in ele[10].split(',') if size]
                isocirc_list_fp.write(line)
                isoform_struct[ele[3]] = get_exon_interval_from_bed12(ele)
                gene_id = iso_id_list[ele[3]]
                all_struct[gene_id] |= isoform_struct[ele[3]]

        for line in read_len_fp:
            if line.startswith('#') or not line.strip(): continue
            ele = line.split()
            read_len[ele[0]] = int(ele[1])

    # number of tandem circRNA copies needed to cover each read
    with open(isocirc_read_list, 'r') as name_fp:
        for line in name_fp:
            if line.startswith('#') or not line.strip(): continue
            ele = line.split()
            isocirc_name, read_names = ele[0], ele[1]
            block_len = sum(circRNA_blocks[isocirc_name])
            for read_name in read_names.split(','):
                copy_n[read_name] = read_len[read_name] // block_len + 2

    # read.fa
    cmd = 'fxtools fn -l -n {} {} > {}'.format(read_list, read_fa, circRNA_read_fa)
    ut.exec_cmd(sys.stderr, 'circRNA read sequence', cmd)
    # ref.fa: one spliced sequence per isoform, repeated copy_n times further below
    cmd = 'bedtools getfasta -split -nameOnly -fi {} -bed {} -fo {}'.format(ref_fa, isocirc_list, circRNA_ref_fa)
    ut.exec_cmd(sys.stderr, 'circRNA reference sequence', cmd)
    ut.exec_cmd(sys.stderr, 'Temporary files', 'rm {} {}'.format(isocirc_list, read_list))

    # exon structure from annotation file
    for gene_id in all_struct:
        # NOTE(review): gene_id is pasted unquoted into a shell command and
        # is matched by grep as a pattern anywhere in the line -- confirm
        # that the list file is trusted and that the ids are specific enough.
        cmd = 'grep {} {} > {}'.format(gene_id, anno_gtf, gene_gtf)
        ut.exec_cmd(sys.stderr, 'Gene structure from annotation file', cmd)
        with open(gene_gtf) as gene_fp:
            for line in gene_fp:
                if line.startswith('#'): continue
                ele = line.split()
                if ele[2] != 'exon': continue
                # keep only exons that overlap the span covered by the isoforms
                if int(ele[3]) > all_struct[gene_id][-1][-1] or int(ele[4]) < all_struct[gene_id][0][0]: continue
                gene_struct[gene_id] |= interval([int(ele[3]), int(ele[4])])
        all_struct[gene_id] |= gene_struct[gene_id]

    # 2. align read to circRNA_ref.fa
    ref_seqs = dict()
    align_details = dict() # read_name:(alignment_record)
    with Fasta(circRNA_read_fa) as read_fp, Fasta(circRNA_ref_fa) as ref_fp:
        for record in ref_fp:
            ref_seqs[record.name] = str(record)
        for record in read_fp:
            read_seq = str(record)
            ref_name = read_iso_id_list[record.name]
            ref_seq = ref_seqs[ref_name]
            # tandem copies of the circRNA, so that a read spanning the
            # back-splice junction aligns end to end
            tandem_ref_seq = ref_seq * int(copy_n[record.name])
            a = mp.Aligner(seq=tandem_ref_seq)
            for h in a.map(read_seq, MD=True):
                if h.is_primary:
                    align_details[record.name] = (read_seq, ref_name, tandem_ref_seq, read_gene_list[record.name], h)

    ut.exec_cmd(sys.stderr, 'Temporary files', 'rm {}'.format(circRNA_read_fa))
    return align_details, all_struct, isoform_struct, gene_struct


"""
long_read.fa: isocirc long reads in FASTA format
ref.fa: genome sequence in FASTA format
anno.gtf: gene annotation in GTF format
isocirc.bed: detailed information of each circRNA isoform in BED12 format
isocirc_read.list: three columns: isocirc_name  read1,read2,...  gene_id
# basically, each gene has one list file
type: plotting mode, SJ or ISO
output_dir: directory for the plots of the circRNA / long-read sequence alignments
"""
# Command-line entry point: prepares the inputs, aligns the listed reads to
# their circRNA sequences and writes one PNG per read into output_dir.
if __name__ == '__main__':
    # exactly seven positional arguments are required
    if len(sys.argv) != 8:
        print('Usage:')
        print('{} long_read.fa ref.fa anno.gtf isocirc.bed isocirc_read.list type(SJ/ISO) output_dir'.format(sys.argv[0]))
        print('\t\tisocirc_read.list consists of 3 columns:')
        print('\t\t\tisocirc ID')
        print('\t\t\treadID(s), multiple reads can be separated by \',\'')
        print('\t\t\tgene name')
        sys.exit(1)

    [read_fa, ref_fa, anno_gtf, isocirc_bed, isocirc_read_list, mode, out_dir] = sys.argv[1:]

    # the plotting mode is case-insensitive on the command line
    if mode.upper() != 'SJ' and mode.upper() != 'ISO':
        ut.fatal_format_time('argv', 'Unknown plotting mode: {}.'.format(mode))
    mode = mode.upper()

    if not os.path.exists(out_dir):
        cmd = 'mkdir {}'.format(out_dir)
        ut.exec_cmd(sys.stderr, 'output_dir', cmd)
    circRNA_ref_fa = out_dir + '/circRNA_ref.fa'
    # read-length table; an existing file from an earlier run is reused
    read_fa_len = out_dir + '/' + os.path.basename(read_fa) + '.len'
    if not os.path.exists(read_fa_len):
        cmd = 'fxtools lp {} > {}'.format(read_fa, read_fa_len)
        ut.exec_cmd(sys.stderr, 'Read length', cmd)
    align_details, all_struct, isoform_struct, gene_struct = align_to_circRNA_fa(ref_fa, anno_gtf, read_fa, read_fa_len, isocirc_bed, isocirc_read_list, out_dir, circRNA_ref_fa)
    # NOTE: this assignment creates a module-level name; the drawing direction
    # is fixed here and not configurable from the command line.
    clock_wise = False #True
    plot_main(isocirc_bed, align_details, circRNA_ref_fa, all_struct, isoform_struct, gene_struct, out_dir, mode, clock_wise)
    # the circRNA reference FASTA is only needed while plotting
    ut.exec_cmd(sys.stderr, 'Temporary files', 'rm {}'.format(circRNA_ref_fa))
