#!/usr/bin/env python3
"""Group variants from hgvs4variation.txt.gz, one row per VariationID

For example:

A2M 2   18171   33210   coding  na  NM_000014.4:c.2998A>G   c.2998A>G   NP_000005.2:p.Ile1000Val    p.Ile1000Val
A2M 2   18171   33210   genomic -   NG_011717.1:g.41291A>G  g.41291A>G  -   -
A2M 2   18171   33210   genomic GRCh38  NC_000012.12:g.9079672T>C   g.9079672T>C    -   -
A2M 2   18171   33210   genomic GRCh37  NC_000012.11:g.9232268T>C   g.9232268T>C    -   -

becomes

A2M 18171   NC_000012.11:g.9232268T>C NC_000012.12:g.9079672T>C NG_011717.1:g.41291A>G NM_000014.4:c.2998A>G NP_000005.2:p.Ile1000Val

"""

import collections
import csv
import gzip
import io
import itertools
import logging
import re
import sys

ncbi_acs = ('NC', 'NG', 'NM', 'NP', 'NR')


def open_hgvs4variation(fn):
    """Open a gzipped hgvs4variation file and return a reader over its rows.

    Every line before the column header (the first line that starts with
    ``#Symbol``) is skipped.  The leading ``#`` of the header is dropped so
    that the first column is named ``Symbol``.

    Args:
        fn: path of the gzip-compressed, tab-delimited file.

    Returns:
        A ``csv.DictReader`` yielding one dict per data row, keyed by the
        column names found in the header line.

    Raises:
        ValueError: if the file contains no line starting with ``#Symbol``.
    """
    fh = gzip.open(fn, "rt")
    try:
        # Scan whole lines rather than peeking at raw bytes: peek() may
        # return fewer bytes than requested, and at EOF a peek/readline loop
        # would spin forever when the header is missing.
        for line in fh:
            if line.startswith("#Symbol"):
                # Re-feed the header (without "#") ahead of the remaining
                # lines so DictReader parses it as the field-name row.
                return csv.DictReader(
                    itertools.chain([line[1:]], fh), delimiter="\t")
        raise ValueError("{}: no '#Symbol' header line found".format(fn))
    except Exception:
        # Do not leak the handle when the header is absent or reading fails.
        fh.close()
        raise


def read_variation_groups(fn):
    """Open fn and yield lists of consecutive records sharing a VariationID.

    Records are read through open_hgvs4variation(fn).  Each yielded list
    holds a run of adjacent rows whose "VariationID" equals that of the
    first row in the run; rows with the same id that are not adjacent in
    the file end up in separate lists.
    """
    group = []
    for record in open_hgvs4variation(fn):
        if group:
            if group[0]["VariationID"] == record["VariationID"]:
                # Still inside the current run.
                group.append(record)
                continue
            # The id changed: hand the finished run to the caller.
            yield group
        group = [record]
    # Flush the final run (nothing is yielded for an empty file).
    if group:
        yield group

def collect_sequence_variants(vg):
    """Bin the HGVS expressions of a variation group by accession prefix.

    Gathers the "NucleotideExpression" values of every record in vg,
    followed by the "ProteinExpression" values, drops the "-" placeholder
    and any expression ending in "del", and returns a dict that maps each
    prefix in ncbi_acs to a list of the distinct remaining expressions
    starting with that prefix.  The order within each list is unspecified.
    """
    truncated_deletion = re.compile("del$")

    candidates = []
    for column in ("NucleotideExpression", "ProteinExpression"):
        candidates.extend(rec[column] for rec in vg)

    kept = [
        expr for expr in candidates
        if expr != "-" and not truncated_deletion.search(expr)
    ]

    bins = {}
    for prefix in ncbi_acs:
        # A set removes duplicates before converting back to a list.
        bins[prefix] = list({expr for expr in kept if expr.startswith(prefix)})
    return bins


if __name__ == "__main__":
    # Script entry point: read the hgvs4variation file named by the first
    # command-line argument and write one tab-delimited row per variation
    # group (gene, variation_id, space-joined HGVS expressions) to stdout.
    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit("usage: {} <hgvs4variation.txt.gz>".format(sys.argv[0]))

    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)

    fieldnames = "gene variation_id hgvs_variants".split()
    writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames, delimiter="\t")
    writer.writeheader()

    for vg in read_variation_groups(sys.argv[1]):
        # Skip groups whose first record carries the "-" placeholder symbol.
        if vg[0]["Symbol"] == "-":
            continue

        binned_vstrs = collect_sequence_variants(vg)

        # Only groups with at least one NM or NR expression are reported.
        n_tx_vars = len(binned_vstrs["NM"]) + len(binned_vstrs["NR"])
        if n_tx_vars == 0:
            continue

        # Warn when one NC accession carries more than one distinct
        # expression within the same group.
        nc_ac_counts = collections.Counter(
            vstr.split(":")[0] for vstr in binned_vstrs["NC"])
        dup_acs = [ac for ac, cnt in nc_ac_counts.items() if cnt > 1]
        if dup_acs:
            # logger.warn is a deprecated alias; use warning() with lazy args.
            logger.warning(
                "Gene %s, VariationID %s: multiple genomic loci on accessions %s",
                vg[0]["Symbol"], vg[0]["VariationID"], ", ".join(dup_acs))

        # Emit expressions grouped in ncbi_acs order, sorted within each bin.
        hgvs_variants = " ".join(itertools.chain.from_iterable(
            sorted(binned_vstrs[k]) for k in ncbi_acs))

        writer.writerow({
            "gene": vg[0]["Symbol"],
            "variation_id": vg[0]["VariationID"],
            "hgvs_variants": hgvs_variants})
