#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

import csv
import logging
import re
import sys
import traceback

from bioutils.sequences import reverse_complement, complement

from hgvs.dataproviders.uta import connect
from hgvs.edit import NARefAlt
from hgvs.location import Interval, SimplePosition, BaseOffsetPosition
from hgvs.normalizer import Normalizer
from hgvs.parser import Parser
from hgvs.posedit import PosEdit
from hgvs.variant import SequenceVariant
from hgvs.variantmapper import EasyVariantMapper
import hgvs.utils.context as huc

# Report warnings and above; the main loop logs per-exon failures through
# this logger.
logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

# When True, the main loop also writes a "## "-prefixed alignment context
# after each output row.
with_context = False

# One CIGAR operation: a decimal length (group "l") followed by a single op
# code (group "op").  A raw string is used so that "\d" reaches the regex
# engine verbatim rather than being parsed as an (invalid) string escape.
cigarop_re = re.compile(r'(?P<l>\d+)(?P<op>[=IDX])')

# Exon alignments of one transcript (bound as %(tx_ac)s) against NC_0000*
# accessions, restricted to splign alignments whose cigar contains at least
# one D, I or X op, in exon order.
query = """
SELECT * FROM tx_exon_aln_v D
WHERE alt_ac~'^NC_0000' and alt_aln_method='splign' and cigar~'[DIX]' AND D.tx_ac=%(tx_ac)s
ORDER BY D.ord
"""


class COp(object):
    """A single CIGAR operation plus slots for its coordinate spans.

    The constructor fills in only ``l`` (length, coerced to int) and ``op``
    (the operation code).  The start/end slots stay unset until a caller
    assigns them.
    """

    # Fixed attribute set: length, op code, then start/end offsets along the
    # alignment (a_*), reference (r_*) and query (q_*) axes.
    __slots__ = ["l", "op", "a_s", "a_e", "r_s", "r_e", "q_s", "q_e"]

    def __init__(self, l, op):
        self.op = op
        # Lengths arrive as regex-captured text; store them as integers.
        self.l = int(l)


def cigar_ops(cigar):
    """Parse a cigar string into a list of COp objects with coordinate spans.

    Three running offsets, all starting at 0, are tracked while walking the
    operations matched by ``cigarop_re``:

    * alignment offset: advanced by every op
    * reference offset: advanced by ops in '=MXD'
    * query offset: advanced by ops in '=MXI'

    For each op, the ``*_s`` attributes hold the offsets before the op is
    applied and the ``*_e`` attributes hold the offsets after it.
    """
    aln_pos = ref_pos = qry_pos = 0
    ops = []
    for match in cigarop_re.finditer(cigar):
        step = COp(l=match.group('l'), op=match.group('op'))
        step.a_s, step.r_s, step.q_s = aln_pos, ref_pos, qry_pos

        aln_pos += step.l
        if step.op in '=MXD':
            ref_pos += step.l
        if step.op in '=MXI':
            qry_pos += step.l

        step.a_e, step.r_e, step.q_e = aln_pos, ref_pos, qry_pos
        ops.append(step)
    return ops


def generate_variants(row):
    r"""Given a query row, yield variant pairs corresponding to mismatches.

    Each input row represents the alignment of corresponding
    transcript and genomic exons, with transcript as reference.

           r0      rs  re                  r1
           \-------lrs lre                 \
            \       \   \                   \
         >>> ACGTACGT----ACGTACGTACGTACGTACGT >>>
                8=    4D    8=    4I    8=
         >>> ACGTACGTACGTACGTACGT----ACGTACGT >>>
            /       /   /                   /
           /--------lqs lqe                /
           q0       qs  qe                q1


           r0      rs  re                  r1
           \-------lrs lre                 \
            \       \   \                   \
         >>> ACGTACGTACGTACGTACGT----ACGTACGT >>>
                8=    4D    8=    4I    8=
         <<< ACGTACGT----ACGTACGTACGTACGTACGT <<<
              \       \   \                   \
               \-------lqs lqe                 \
                q1     qe  qs                   q0 (but q0<q1!)

    In UTA, cigars are always transcript-as-reference (genome as query)

    All coordinates here are interbase coordinates, but are converted to
    1-based, closed coordinates on formatting.
    lrs, lre, lqs, lqe = local ref/query start/end
    rs, re, qs, qe = ref/query start/end (sequence "global" coordinates)
    r0, r1, q0, q1 = global ref, query coords of alignment segment (i.e., per seq accession)
    Note that local coords are measured left-to-right above; q0 < q1 for + and - strands

    Strandedness is accounted for when converting from local to global as follows:
    For + strand alignment:
      rs, re = r0+lrs, r0+lre
      qs, qe = q0+lqs, q0+lqe
    For - strand alignment:
      rs, re = r0+lrs, r0+lre   (same as +)
      qs, qe = q1-lqe, q1-lqs   (compare with + case)

    :param row: mapping for one exon alignment; the keys read here are
        cigar, tx_aseq, alt_aseq, tx_ac, tx_start_i, alt_ac, alt_start_i,
        alt_end_i and alt_strand
    :returns: generator of (var_n, var_g) SequenceVariant pairs, one pair
        per cigar op that is not '='
    :raises ValueError: if row['alt_strand'] is neither 1 nor -1

    """

    # Walk the cigar ops, each of which already carries its alignment,
    # transcript (ref) and genomic (query) offsets, and yield a variant
    # pair for each non-identity op.
    for cop in cigar_ops(row['cigar']):
        if cop.op == '=':
            continue

        lrs, lre = cop.r_s, cop.r_e
        lqs, lqe = cop.q_s, cop.q_e

        # Aligned sequences for this op with gap characters removed; an
        # empty result (the gapped side of an indel) becomes None.
        tseq = row['tx_aseq'][cop.a_s:cop.a_e].replace('-', '') or None
        gseq = row['alt_aseq'][cop.a_s:cop.a_e].replace('-', '') or None

        # _h is hgvs coord
        rs_h, re_h = row['tx_start_i'] + lrs + 1, row['tx_start_i'] + lre
        rr, ra = tseq, gseq
        if row['alt_strand'] == 1:
            qs_h, qe_h = row['alt_start_i'] + lqs + 1, row['alt_start_i'] + lqe
            qr, qa = ra, rr
        elif row['alt_strand'] == -1:
            # start/end and ref/alt intentionally swapped
            qs_h, qe_h = row['alt_end_i'] - lqe + 1, row['alt_end_i'] - lqs
            qr, qa = reverse_complement(ra), reverse_complement(rr)
        else:
            # Fail explicitly rather than falling through with the genomic
            # coordinates unset.
            raise ValueError("unsupported alt_strand {s!r} (expected 1 or -1)".format(
                s=row['alt_strand']))

        # At zero-width interbase coordinate i, an insertion is
        # hgvs coordinate [i,i+1] i.e., the surrounding nucleotides
        if cop.op == 'I':
            rs_h -= 1
            re_h += 1
        elif cop.op == 'D':
            qs_h -= 1
            qe_h += 1

        var_n = SequenceVariant(ac=row['tx_ac'],
                                type='n',
                                posedit=PosEdit(pos=Interval(start=BaseOffsetPosition(base=rs_h),
                                                             end=BaseOffsetPosition(base=re_h)),
                                                edit=NARefAlt(ref=rr,
                                                              alt=ra)))
        var_g = SequenceVariant(ac=row['alt_ac'],
                                type='g',
                                posedit=PosEdit(pos=Interval(start=SimplePosition(base=qs_h),
                                                             end=SimplePosition(base=qe_h)),
                                                edit=NARefAlt(ref=qr,
                                                              alt=qa)))

        yield var_n, var_g


# Script entry point: for each transcript accession given on the command
# line, write one tab-separated row per alignment mismatch to stdout.
if __name__ == "__main__":
    hdp = connect()
    hp = Parser()
    evm = EasyVariantMapper(hdp, normalize=False, replace_reference=False)
    # Transcript variants are normalized with shuffle_direction=3 and
    # genomic variants with shuffle_direction=5.
    hn3 = Normalizer(hdp, shuffle_direction=3)
    hn5 = Normalizer(hdp, shuffle_direction=5)

    fieldnames = 'hgvs_g hgvs_n hgvs_c tx_ac strand exon cigar comments'.split()
    # str('\t') is the native string type on both Python 2 (where
    # unicode_literals would otherwise make the literal unicode) and
    # Python 3 (where csv rejects a bytes delimiter).
    out = csv.DictWriter(sys.stdout, fieldnames=fieldnames, delimiter=str('\t'), lineterminator='\n')
    out.writeheader()

    tx_acs = sys.argv[1:]

    for tx_ac in tx_acs:
        cur = hdp._execute(query, {'tx_ac': tx_ac})
        for row in cur:
            # Best effort per exon row: a failure is logged and the loop
            # moves on to the next row.
            try:
                for var_n0, var_g0 in generate_variants(row):
                    # Record any change made by normalization.
                    comments = []
                    var_n = hn3.normalize(var_n0)
                    if var_n != var_n0:
                        comments.append(str(var_n0) + "->" + str(var_n))
                    var_g = hn5.normalize(var_g0)
                    if var_g != var_g0:
                        comments.append(str(var_g0) + "->" + str(var_g))
                    var_c = evm.n_to_c(var_n)
                    out.writerow({
                        'hgvs_g': str(var_g),
                        'hgvs_n': str(var_n),
                        'hgvs_c': str(var_c),
                        'tx_ac': row['tx_ac'],
                        'strand': row['alt_strand'],
                        'exon': row['ord'] + 1,
                        'cigar': row['cigar'],
                        'comments': "; ".join(comments)
                    })

                    if with_context:
                        c = huc.variant_context_w_alignment(evm, var_g, tx_ac=var_c.ac)
                        # Prefix every context line with "## ".
                        c = re.sub('^', '## ', c, flags=re.M) + "\n"
                        sys.stdout.write(c)
            except Exception as e:
                # logger.warning replaces the deprecated logger.warn alias.
                logger.warning("# {row[tx_ac]}, exon {row[ord]}: {e}".format(
                    row=row,
                    e=traceback.format_exception_only(type(e), e)[0]))
