#!/usr/bin/env python

import logging
import sys
import eutils.clientx

logger = logging.getLogger(__name__)


def format_snp_rows(rs):
    """Yield the tab-separated output rows for a single SNP record.

    ``rs`` must expose ``rs_id``, ``hgvs_genome_tags``,
    ``hgvs_transcript_tags`` and ``hgvs_protein_tags``.  Each yielded string
    has four tab-separated fields: id, HGVSg, HGVSc, HGVSp.

    A record yields nothing (and a warning is logged) when it does not have
    exactly one genome tag, or when it has protein tags whose count differs
    from the number of ``NM_`` transcript tags.
    """
    genome_tags = rs.hgvs_genome_tags
    if len(genome_tags) != 1:
        logger.warning("%s: %d genome tags; skipping",
                       rs.rs_id, len(genome_tags))
        return

    # Only transcript tags whose text starts with 'NM_' are reported.
    transcript_tags = [t for t in rs.hgvs_transcript_tags
                       if t.startswith('NM_')]
    protein_tags = rs.hgvs_protein_tags

    if len(protein_tags) == 0:
        # No protein tags: emit one row per transcript, empty HGVSp column.
        for t in transcript_tags:
            yield '\t'.join([rs.rs_id, genome_tags[0], t, ''])
        return

    if len(transcript_tags) != len(protein_tags):
        logger.warning("%s: %d transcript tags != %d protein tags; skipping",
                       rs.rs_id, len(transcript_tags), len(protein_tags))
        return

    # Transcript and protein tags are paired by position.
    # NOTE(review): assumes both lists are reported in corresponding
    # order -- confirm against the eutils data model.
    for t, p in zip(transcript_tags, protein_tags):
        yield '\t'.join([rs.rs_id, genome_tags[0], t, p])


def main(argv=None):
    """Print a TSV table of HGVS tags for the SNPs of one gene.

    ``argv`` defaults to ``sys.argv``; ``argv[1]`` is the gene identifier
    handed to ``ClientX.fetch_snps_for_gene`` (presumably an HGNC gene
    symbol).  Returns the process exit status: 0 on success, 2 when the
    gene argument is missing.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        prog = argv[0] if argv else 'script'
        sys.stderr.write("usage: %s GENE\n" % prog)
        return 2

    logging.basicConfig(level=logging.INFO)

    hgnc = argv[1]
    ec = eutils.clientx.ClientX()
    results = ec.fetch_snps_for_gene(hgnc)

    print('\t'.join(['id', 'HGVSg', 'HGVSc', 'HGVSp']))
    for rs in results:
        for row in format_snp_rows(rs):
            print(row)
    return 0


if __name__ == '__main__':
    sys.exit(main())
