#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals
"""compare hgvs and mutalyzer outputs

eg$ ./bin/hgvs-mutalyzer-compare -vv -eid -H acmg-g-current-hgvs.tsv -M acmg-g-current-mzr.tsv -B bermuda.tsv

The input files must be formatted as TSV, following Mutalyzer's positionConverter format:

Input Variant	Errors	Chromosomal Variant	Coding Variant(s)
NC_000001.10:g.100586496G>C		NC_000001.10:g.100586496G>C	NM_194292.1:c.483+438C>G	XM_005270551.1:c.-19+438C>G

For each matching input (g.) variant, compare the set of c. variants
computed by each tool.  The cases are:

empty -- both tools return empty sets
complete match -- both tools return identical sets
intersection match -- all intersection keys match
subset match -- some (not all) intersection keys match
failed -- no intersection keys match

"""

import argparse
import codecs
import collections
import csv
import gzip
import logging
import os
import re
import sys


def parse_args(argv):
    """Build the command-line parser and parse ``argv``.

    ``argv`` is the argument list without the program name.  Returns the
    argparse namespace with attributes ``bermuda_filename``,
    ``hgvs_filename``, ``mutalyzer_filename`` (all required),
    the boolean switches ``match_eq_delins_on_location``,
    ``rewrite_identity_substitutions`` and ``strip_delins``, and the
    ``verbose`` count.
    """
    # NOTE(review): the module string literal follows the `from __future__`
    # import, so it is not the module docstring and `__doc__` is None here.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    required_file = dict(required=True)
    switch = dict(default=False, action='store_true')

    # Registration order is kept as-is because it determines --help ordering.
    option_table = [
        ('--bermuda-filename', '-B', required_file),
        ('--hgvs-filename', '-H', required_file),
        ('--match-eq-delins-on-location', '-e', switch),
        ('--mutalyzer-filename', '-M', required_file),
        ('--rewrite-identity-substitutions', '-i', switch),
        ('--strip-delins', '-d', switch),
        ('--verbose', '-v', dict(default=0, action='count')),
    ]
    for long_name, short_name, kwargs in option_table:
        parser.add_argument(long_name, short_name, **kwargs)

    return parser.parse_args(argv)


def gzopen(fn, mode='rb'):
    """Open ``fn`` with ``mode``, using gzip when the name ends in '.gz'.

    Returns a file object from ``gzip.open`` for '.gz' names and from the
    builtin ``open`` otherwise.  A mode without 'b' (e.g. 'r') yields text
    for both kinds of file.
    """
    if not fn.endswith('.gz'):
        return open(fn, mode)
    # On Python 3, gzip.open() treats a bare 'r'/'w'/'a' as *binary*, unlike
    # the builtin open().  Ask for text explicitly so callers that request
    # 'r' get str lines from compressed and uncompressed files alike.
    # Python 2 is left untouched: its gzip module has no text mode and
    # already returns str.
    if sys.version_info[0] >= 3 and 'b' not in mode and 't' not in mode:
        mode += 't'
    return gzip.open(fn, mode)


def read_pc_format(fn):
    """Open ``fn`` and yield one dict per error-free positionConverter row.

    The first line must be exactly the positionConverter header.  Lines
    starting with '#' are skipped, as are rows whose Errors column is
    non-empty.  Each yielded dict has the keys 'input', 'errors', 'g_var'
    and 'c_vars'; 'c_vars' is the list of coding-variant columns whose
    value starts with 'NM_' (other columns are dropped).

    The underlying file is closed when the generator is exhausted or closed.

    Raises AssertionError if the header line does not match.
    """
    expected_hdr = 'Input Variant\tErrors\tChromosomal Variant\tCoding Variant(s)\n'
    with gzopen(fn, 'r') as fh:
        hdr = fh.readline()
        assert hdr == expected_hdr, "unexpected header in {fn}: {hdr!r}".format(fn=fn, hdr=hdr)
        for line in fh:
            if line.startswith("#"):
                continue
            vals = line.strip('\n\r').split('\t')
            if vals[1]:
                # Row carries an error: skip it before indexing the later
                # columns, so a short error row cannot raise IndexError.
                continue
            yield {
                'input': vals[0],
                'errors': vals[1],
                'g_var': vals[2],
                'c_vars': [c_var for c_var in vals[3:] if c_var.startswith('NM_')],
            }


def read_pc(fn):
    """Read positionConverter file ``fn`` into a dict keyed by input variant.

    Each value is the record dict produced by ``read_pc_format``.  If the
    same input variant appears more than once, the last record wins.
    """
    by_input = {}
    for rec in read_pc_format(fn):
        by_input[rec['input']] = rec
    return by_input


def strip_delins(recs):
    """Drop the deleted sequence/length from delins variants, in place.

    For every record in ``recs`` (a mapping whose values are dicts with a
    'c_vars' list), rewrite each variant so that ``del<digits or ACGT>ins``
    becomes ``delins``; e.g. ``c.5_6delAGinsT`` -> ``c.5_6delinsT``.
    Variants without that pattern are left unchanged.  Returns None.
    """
    # .values() works on both Python 2 and 3 (itervalues() is Python 2 only).
    for rec in recs.values():
        # Raw string: '\d' in a plain literal is an invalid escape on Python 3.
        rec['c_vars'] = [re.sub(r'del[\dACGT]+ins', 'delins', cv) for cv in rec['c_vars']]


def rewrite_identity_substitutions(recs):
    """Rewrite no-op substitutions (e.g. ``A>A``) as ``=``, in place.

    For every record in ``recs`` (a mapping whose values are dicts with a
    'c_vars' list), replace a substitution whose reference and alternate
    base are the same single letter of ACGT with ``=``; e.g.
    ``c.5A>A`` -> ``c.5=``.  Real substitutions are left unchanged.
    Returns None.
    """
    # .values() works on both Python 2 and 3 (itervalues() is Python 2 only).
    for rec in recs.values():
        rec['c_vars'] = [re.sub(r'([ACGT])>\1', '=', cv) for cv in rec['c_vars']]


def vars_match(opts, hv, mv):
    """Return True if hgvs variant ``hv`` and Mutalyzer variant ``mv`` agree.

    Identical strings always match.  Otherwise, when
    ``opts.match_eq_delins_on_location`` is set, the two are compared after
    stripping trailing '=' characters from ``hv`` and removing 'delins' plus
    everything after it from ``mv``.  Without that option, differing
    strings never match.
    """
    if hv != mv and opts.match_eq_delins_on_location:
        h_location = hv.rstrip('=')
        m_location = re.sub('delins.+', '', mv)
        return h_location == m_location
    return hv == mv


if __name__ == '__main__':
    # Script entry point: load the bermuda table and the two positionConverter
    # files, classify every genomic variant present in both by how well the
    # per-transcript c. variants agree, print one report line per variant,
    # then print summary counts per category.
    logging.basicConfig(level=logging.WARN)
    logger = logging.getLogger(__name__)

    opts = parse_args(sys.argv[1:])
    if opts.verbose:
        logger.setLevel(logging.INFO if opts.verbose == 1 else logging.DEBUG)

    # Index the bermuda table by transcript accession; the file is closed as
    # soon as it has been read.  str('\t') keeps the delimiter a native str
    # under unicode_literals on Python 2.
    with gzopen(opts.bermuda_filename, 'r') as bermuda_fh:
        bermuda = {r['tx_ac']: r for r in csv.DictReader(bermuda_fh, delimiter=str('\t'))}
    # "Dirty" transcripts: s_status contains 'D' or 'I'
    # (presumably deletion/insertion flags -- TODO confirm against bermuda docs).
    dirty_acs = set(k for k, r in bermuda.items() if ('D' in r['s_status'] or 'I' in r['s_status']))

    h_recs = read_pc(opts.hgvs_filename)
    logger.info("read {n} genomic variants from {opts.hgvs_filename}".format(n=len(h_recs), opts=opts))

    m_recs = read_pc(opts.mutalyzer_filename)
    logger.info("read {n} genomic variants from {opts.mutalyzer_filename}".format(n=len(m_recs), opts=opts))

    # Optional normalizations, applied to both sides before comparing.
    if opts.strip_delins:
        strip_delins(h_recs)
        strip_delins(m_recs)
    if opts.rewrite_identity_substitutions:
        rewrite_identity_substitutions(h_recs)
        rewrite_identity_substitutions(m_recs)

    # Only genomic variants present in both inputs are compared.
    g_hgvs_keys = set(h_recs) & set(m_recs)
    logger.info("{n} genomic variants in common".format(n=len(g_hgvs_keys)))

    match_bins = collections.defaultdict(set)    # category -> genomic variants
    match_counts = collections.Counter()         # category -> # shared transcripts
    hmk_tot = 0
    h_missing = set()
    m_missing = set()

    for g_hgvs in sorted(g_hgvs_keys):
        # Both lookups succeed: g_hgvs comes from the key intersection.
        h_rec = h_recs[g_hgvs]
        m_rec = m_recs[g_hgvs]

        h_cs_d = {v.partition(':')[0]: str(v) for v in h_rec['c_vars']}
        m_cs_d = {v.partition(':')[0]: str(v) for v in m_rec['c_vars']}

        # "keys" are transcript accessions for most of the following code
        hk = set(h_cs_d)    # hgvs keys
        mk = set(m_cs_d)    # mzr keys
        hok = hk - mk    # hgvs only keys
        mok = mk - hk    # mzr only keys
        hmk = hk & mk    # intersection keys

        eqk = set(k for k in hmk if vars_match(opts, h_cs_d[k], m_cs_d[k]))    # match keys
        nek = hmk - eqk    # mismatch keys
        nek_dirty = nek & dirty_acs
        nek_clean = nek - dirty_acs

        hmk_tot += len(hmk)

        if not hmk:
            # no transcript accessions in common
            if not hk and not mk:
                match = 'two-sided empty'
            elif not hk:
                match = 'one-sided empty (hgvs)'
                h_missing |= mk
            elif not mk:
                match = 'one-sided empty (mutalyzer)'
                m_missing |= hk
            else:
                # Both tools returned transcripts but none are shared.
                # Without this bin, `match` was left unset and the run aborted.
                match = 'disjoint transcripts'
        elif not nek:
            # All intersection keys are equal.  The 'perfect match' bin is
            # never assigned; it still appears in the summary with zero counts.
            match = 'intersection match'
        elif not nek_dirty:
            match = 'all-clean mismatch'
        elif not nek_clean:
            match = 'all-dirty mismatch'
        else:
            match = 'semi-dirty mismatch'

        match_bins[match].add(g_hgvs)
        match_counts[match] += len(hmk)

        msg = "{g_hgvs}: {match}; keys: {ho}/{eqk}/{hmk}/{mo}".format(g_hgvs=g_hgvs,
                                                                      match=match,
                                                                      ho=len(hok),
                                                                      eqk=len(eqk),
                                                                      hmk=len(hmk),
                                                                      mo=len(mok))
        # Keys are sorted so the report is reproducible from run to run.
        if eqk:
            matches = [(h_cs_d[k], m_cs_d[k]) for k in sorted(eqk)]
            msg += "\n  matches:" + '; '.join(str(mm) for mm in matches)
        if nek_clean:
            mismatches = [(h_cs_d[k], m_cs_d[k]) for k in sorted(nek_clean)]
            msg += "\n  clean mismatches:" + '; '.join(str(mm) for mm in mismatches)
        if nek_dirty:
            mismatches = [(h_cs_d[k], m_cs_d[k]) for k in sorted(nek_dirty)]
            msg += "\n  dirty mismatches:" + '; '.join(str(mm) for mm in mismatches)
        print(msg)

    # Summary rows, in fixed display order.
    match_keys = [
        'two-sided empty',
        'one-sided empty (hgvs)',
        'one-sided empty (mutalyzer)',
        'disjoint transcripts',
        'perfect match',
        'intersection match',
        'all-clean mismatch',
        'all-dirty mismatch',
        'semi-dirty mismatch',
    ]

    for match in match_keys:
        print("{n:6d} {mc:6d} {match}".format(n=len(match_bins[match]), mc=match_counts[match], match=match))
    print("{n:6d} genomic variants compared".format(n=len(g_hgvs_keys)))
    print("{n:6d} transcript variants compared".format(n=hmk_tot))

    print("{n:6d} transcripts not in hgvs ({acs})".format(n=len(h_missing), acs=','.join(sorted(h_missing))))
    print("{n:6d} transcripts not in mutalyzer ({acs})".format(n=len(m_missing), acs=','.join(sorted(m_missing))))
