#!/usr/bin/env python
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_hooks()

import sys
import familyanalyzer as fa

def handle_args(argv=None):
    """Build the FamilyAnalyzer command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the command-line tool.

    Returns:
        The ``argparse.Namespace`` with the attributes ``xreftag``,
        ``show_levels``, ``taxonomy``, ``propagate_top``, ``show_taxonomy``,
        ``store_augmented_xml``, ``add_singletons``, ``compare_second_level``,
        ``gene_trees``, ``orthoxml``, ``level`` and ``species`` (a list with
        at least one entry).

    Raises:
        SystemExit: raised by argparse on ``--help`` or on invalid arguments.
    """
    # Imported locally: argparse is only needed while parsing arguments.
    import argparse

    description = 'Analyze Hierarchical OrthoXML families.'

    # Help texts are kept in one place so that the add_argument calls below
    # stay short and easy to scan.
    help_messages = {
        'xreftag': ("xref tag of genes to report. OrthoXML allows to "
                    "store multiple ids and xref annotations per gene "
                    "as attributes in the species section. If not set, "
                    "the internal (purely numerical) ids are reported."),

        'show_levels': ("print the levels and species found in the orthoXML"
                        " file and quit"),

        'taxonomy': ("Taxonomy used to reconstruct intermediate levels. "
                     "Has to be either 'implicit' (default) or a path to "
                     "a file in Newick format. The taxonomy might be "
                     "multifurcating. If set to 'implicit', the "
                     "taxonomy is extracted from the input OrthoXML file. "
                     "The orthoXML levels do not have to cover all the "
                     "levels for all families. In order to infer gene losses "
                     "Family-Analyzer needs to infer these skipped levels "
                     "and reconcile each family with the complete taxonomy."),

        'propagate_top': ("propagate taxonomy levels up to the toplevel. As "
                          "an illustration, consider a gene family in a "
                          "eukaryotic analysis that has only mammalian genes. "
                          "Its topmost taxonomic level will therefore be "
                          "'Mammalia' and an ancestral gene was gained at that"
                          " level. However, if '--propagate_top' is set, the "
                          "family is assumed to have already been present in "
                          "the topmost taxonomic level, i.e. Eukaryota in this "
                          "example, and non-mammalian species have all lost"
                          " this gene."),

        'show_taxonomy': 'write the taxonomy used to standard output.',

        'store_augmented_xml': ("filename to which the input orthoxml file "
                                "with augmented annotations is written. The "
                                "augmented annotations include for example the "
                                "additional taxonomic levels of orthologGroup "
                                "and unique HOG IDs."),

        'add_singletons': ("Take singletons - genes from the input set that "
                           "weren't assigned to orthologous groups and "
                           "appear as novel gains in the taxonomy leaves - "
                           "and add them to the xml as single-member "
                           "orthologGroups"),

        'compare_second_level': ("Compare secondary level with primary one, "
                                 "i.e. report what happened between the "
                                 "secondary and primary level to the individual"
                                 " histories. Note that the Second level needs"
                                 " to be younger than the primary."),

        'orthoxml': "path to orthoxml file to be analyzed",

        'level': "taxonomic level at which analysis should be done",

        'species': ("(list of) species to be analyzed. "
                    "Note that only genes of the selected species are "
                    "reported. In order for the output to make sense, "
                    "the selected species all must be part of the lineages "
                    "specified in 'level' (and --compare_second_level)."),

        'gene_trees': ('Output extra gene family tree information, including '
                       'a gene tree for each family, gene coverage at '
                       'each node, and an annotated species tree'),
    }

    parser = argparse.ArgumentParser(prog='FamilyAnalyzer',
                                     description=description)

    # Optional arguments.
    parser.add_argument('--xreftag', default=None,
                        help=help_messages['xreftag'])
    parser.add_argument('--show_levels', action='store_true',
                        help=help_messages['show_levels'])
    parser.add_argument('--taxonomy', default='implicit',
                        help=help_messages['taxonomy'])
    parser.add_argument('--propagate_top', action='store_true',
                        help=help_messages['propagate_top'])
    parser.add_argument('--show_taxonomy', action='store_true',
                        help=help_messages['show_taxonomy'])
    parser.add_argument('--store_augmented_xml', default=None,
                        help=help_messages['store_augmented_xml'])
    parser.add_argument('--add_singletons', action='store_true',
                        help=help_messages['add_singletons'])
    parser.add_argument('--compare_second_level', default=None,
                        help=help_messages['compare_second_level'])
    parser.add_argument('--gene_trees', action='store_true',
                        help=help_messages['gene_trees'])

    # Positional arguments, in the order they must appear on the command line.
    parser.add_argument('orthoxml', help=help_messages['orthoxml'])
    parser.add_argument('level', help=help_messages['level'])
    parser.add_argument('species', nargs="+", help=help_messages['species'])

    return parser.parse_args(argv)

def _load_taxonomy(args, op):
    """Return the taxonomy selected by ``--taxonomy`` for the parser ``op``."""
    if args.taxonomy == "implicit":
        return fa.TaxonomyFactory.newTaxonomy(op)

    from familyanalyzer.taxonomy import NewickTaxonomy
    tax = fa.TaxonomyFactory.newTaxonomy(args.taxonomy)
    if isinstance(tax, NewickTaxonomy):
        tax.annotate_from_orthoxml(op)
        # NOTE(review): the method name is spelled 'finialize_init' here;
        # presumably this matches the library's spelling -- confirm.
        tax.finialize_init()
    return tax


def _collect_singletons(op, species):
    """Return the singleton families found at each species-level analysis."""
    singletons = []
    for spec in species:
        spec_hist = op.getFamHistory()
        spec_hist.analyzeLevel(spec)
        singletons.extend(fam for fam in spec_hist if fam.is_singleton())
    return singletons


def _print_gene_trees(gtt, newtax, species):
    """Print every traced gene family tree and the annotated species tree."""
    print('Gene family trees')
    for t in gtt.trees:
        header = 'Family {}'.format(t.root.fam_id)
        print('{1}\n{0}\n{1}\n'.format(header, '='*len(header)))
        print('Tree:\n{}\n'.format(t))
        print('Gene coverage:')
        for node in t:
            print('{}\t{}'.format(node.fam_id, node.taxonomic_level), node.genes)
        print()

    print('Annotated species tree covering {}:'.format(species))
    print(newtax.newick())


def main():
    """Run the FamilyAnalyzer command-line workflow.

    Parses the command line, loads the OrthoXML file, builds the taxonomy,
    augments the parser with it, analyses the requested level (and optionally
    a second level, singletons and gene trees) and writes all results to
    standard output.

    Returns:
        None, so that ``sys.exit(main())`` exits with status 0.

    Raises:
        SystemExit: after printing species and levels when ``--show_levels``
            is given, or from argument parsing.
    """
    args = handle_args()

    op = fa.OrthoXMLParser(args.orthoxml)
    if args.show_levels:
        print("Species:\n{0}\n\nLevels:\n{1}".format(
              '\n'.join(sorted(op.getSpeciesSet())),
              '\n'.join(sorted(op.getLevels()))))
        sys.exit()
    print("Analyzing {} on taxlevel {}".format(args.orthoxml, args.level))
    print("Species found:")
    print("; ".join(op.getSpeciesSet()))
    print("--> analyzing " + "; ".join(args.species))

    ##################
    # DO CALCULATIONS
    ##################

    tax = _load_taxonomy(args, op)

    if args.show_taxonomy:
        print("Use following taxonomy")
        print(tax)

    # add taxonomy to parser
    op.augmentTaxonomyInfo(tax, args.propagate_top)

    if args.store_augmented_xml is not None:
        op.write(args.store_augmented_xml)

    singletons = []
    if args.add_singletons:
        op.augmentSingletons()
        singletons = _collect_singletons(op, args.species)

    hist = op.getFamHistory()
    hist.analyzeLevel(args.level)
    hist2 = None
    comp = None
    if args.compare_second_level is not None:
        hist2 = op.getFamHistory()
        hist2.analyzeLevel(args.compare_second_level)
        comp = hist.compare(hist2)
        # Both histories are written below, so the requested xref tag is
        # applied to the second one as well. It is set after the comparison
        # so that the comparison itself is computed exactly as before.
        hist2.setXRefTag(args.xreftag)
    # Apply the requested xref tag in every mode; previously it was only set
    # when no second level was compared although hist is always written.
    hist.setXRefTag(args.xreftag)

    if args.gene_trees:
        newtax = tax.retain(args.species)
        newtax.get_comparisons(op)
        newtax.root = args.level
        gtt = fa.GeneTreeTracer(op, newtax)
        gtt.trace_gene_families(add_genelists=True)

    ###############
    # WRITE RESULTS
    ###############

    hist.write(sys.stdout, speciesFilter=args.species)
    print()

    if hist2 is not None:
        hist2.write(sys.stdout, speciesFilter=args.species)
        comp.write(sys.stdout)
        print()

    if args.add_singletons:
        # Format ids with the tag given on the command line rather than with
        # the attribute of a leftover loop variable on which no tag was set.
        def format_gene_id(gid):
            return op.mapGeneToXRef(gid, args.xreftag)

        print('Singletons (unmatched genes):')
        for fam in singletons:
            fam.write(sys.stdout, None, idFormatter=format_gene_id)
        print()

    if args.gene_trees:
        _print_gene_trees(gtt, newtax, args.species)

# Script entry point: run the tool only when this file is executed directly
# (not when imported) and use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
