#!/usr/bin/env python3

import argparse
import logging
import os
import pickle

import numpy as np

import medicc

logger = logging.getLogger('medicc-main')

# Command-line interface definition. Two positionals (input file, output
# folder) followed by optional switches grouped as: input parsing, tree
# handling, sample/segment filtering, bootstrapping, output, runtime,
# annotation BED files, logging verbosity and expert options.
parser = argparse.ArgumentParser()
parser.add_argument("input_file",
                    help="a path to the input file")
parser.add_argument("output_dir",
                    help="a path to the output folder")

# --- Input parsing ---
parser.add_argument("--input-type", "-i", type=str, dest="input_type", default="t",
                    choices=["f", "t", "fasta", "tsv"], required=False,
                    help="Choose the type of input: f for FASTA, t for TSV (default: TSV)")
parser.add_argument("--input-allele-columns", "-a",
                    type=str,
                    dest='input_allele_columns',
                    default='cn_a, cn_b',
                    required=False,
                    help="""Name of the CN columns (comma separated) if using TSV input format (default: 'cn_a, cn_b').
                    This also adjusts the number of alleles considered (min. 1, max. 2).""")
parser.add_argument("--input-chr-separator",
                    type=str,
                    dest='input_chr_separator',
                    default='X',
                    required=False,
                    help='Character used to separate chromosomes in the input data (condensed FASTA only, default: \"X\").')

# --- Tree handling ---
parser.add_argument("--tree",
                    action="store",
                    dest="user_tree",
                    help="Do not reconstruct tree, use provided tree instead (in newick format) and only perform ancestral reconstruction (default: None).",
                    required=False)
parser.add_argument("--topology-only", "-s",
                    action="store_true",
                    dest="topology_only",
                    help="Output only tree topology, without reconstructing ancestors (default: False).",
                    required=False)
parser.add_argument("--normal-name", "-n",
                    default="diploid",
                    type=str,
                    dest="normal_name",
                    help="""ID of the sample to be treated as the normal sample. 
                    Trees are rooted at this sample for ancestral reconstruction (default: \"diploid\").
                    If the sample ID is not found, an artificial normal sample of the same name is created 
                    with CN states = 1 for each allele.""",
                    required=False)

# --- Sample / segment filtering ---
parser.add_argument("--exclude-samples", "-x",
                    default=None,
                    type=str,
                    help="Comma separated list of sample IDs to exclude.",
                    required=False)
# Kept as a string: the raw value is handed on unchanged to the filtering
# routine and only converted to a number for the log message.
parser.add_argument("--filter-segment-length",
                    type=str,
                    dest='filter_segment_length',
                    default=None,
                    required=False,
                    help="""Removes segments that are smaller than specified length.""")

# --- Bootstrapping ---
# `choices` makes argparse reject an unknown method immediately, rather than
# letting the error surface only after the main reconstruction has finished.
parser.add_argument("--bootstrap-method",
                    type=str,
                    dest='bootstrap_method',
                    default='chr-wise',
                    choices=['chr-wise', 'segment-wise'],
                    required=False,
                    help="""Bootstrap method. Has to be either 'chr-wise' or 'segment-wise'""")
parser.add_argument("--bootstrap-nr",
                    type=int,
                    dest='bootstrap_nr',
                    default=None,
                    required=False,
                    help="""Number of bootstrap runs to perform""")

# --- Output and model switches ---
parser.add_argument("--prefix", '-p', type=str, dest='prefix', default=None,
                    help='Output prefix to be used (default: input filename).', required=False)
# The flag turns whole-genome doubling OFF; the help text now says so.
parser.add_argument("--no-wgd", action='store_true', default=False, dest='no_wgd',
                    help='Disable whole-genome doubling events (default: False).', required=False)
parser.add_argument("--no-plot", action='store_true', default=False,
                    help='Disable plotting (default: False).', required=False)
parser.add_argument("--total-copy-numbers",
                    dest="total_copy_numbers",
                    action='store_true',
                    default=False,
                    required=False,
                    help='Run for total copy number data instead of allele-specific data (default: False).')

# --- Runtime ---
parser.add_argument("-j", "--n-cores",
                    type=int,
                    dest='n_cores',
                    default=None,
                    required=False,
                    help="""Number of cores to run on""")

# --- Annotation BED files used when overlapping events ---
parser.add_argument("--chromosomes-bed",
                    type=str,
                    dest='chromosomes_bed',
                    default='default',
                    required=False,
                    help="""BED file for chromosome regions""")
parser.add_argument("--regions-bed",
                    type=str,
                    dest='regions_bed',
                    default='default',
                    required=False,
                    help="""BED file for regions of interests""")

# --- Logging verbosity ---
parser.add_argument("-v", "--verbose", action='store_true', default=False,
                    help='Enable verbose output (default: False).', required=False)
parser.add_argument("-vv", "--debug", action='store_true', default=False,
                    help='Enable more verbose output (default: False).', required=False)
parser.add_argument("--silent", action='store_true', default=False,
                    help='Hide all output (default: False).', required=False)

# --- Expert options ---
parser.add_argument("--maxcn", type=int, dest='maxcn', default=8,
                    help='Expert option: maximum CN at which the input is capped. Does not change FST.')
parser.add_argument("--prune-weight", type=int, dest='prune_weight', default=0,
                    help='''Expert option: Prune weight in ancestor reconstruction. Values >0 might
                            result in more accurate ancestors but will require more time and memory. Default: 0''')
parser.add_argument("--fst", type=str, dest='fst', default=None,
                    help='Expert option: path to an alternative FST.')
parser.add_argument("--fst-chr-separator", type=str, dest='fst_chr_separator', default='X',
                    help='Expert option: character used to separate chromosomes in the FST (default: \"X\").')
args = parser.parse_args()


# Choose the level of the 'medicc' package logger from the verbosity flags.
# When several flags are given the precedence is --silent, then --debug,
# then --verbose; with none of them the level is left untouched.
if args.silent:
    # --silent additionally mutes this script's own logger.
    for muted in (logging.getLogger('medicc'), logger):
        muted.setLevel(logging.CRITICAL)
elif args.debug:
    logging.getLogger('medicc').setLevel(logging.DEBUG)
elif args.verbose:
    logging.getLogger('medicc').setLevel(logging.INFO)

output_dir = args.output_dir
normal_name = args.normal_name
# The allele columns arrive as one comma-separated string (e.g. 'cn_a, cn_b');
# split it and drop the whitespace around each name.
allele_columns = [x.strip() for x in args.input_allele_columns.split(',')]

# Create the output folder if needed. `exist_ok=True` replaces the former
# exists()-then-makedirs() pair, which could fail if the folder appeared
# between the check and the creation (e.g. parallel runs sharing a folder).
# A path that exists but is not a directory now fails here with
# FileExistsError instead of at the first write.
os.makedirs(output_dir, exist_ok=True)

# Determine prefix for output files: the explicit --prefix if given,
# otherwise the input file name without directory and extension.
if args.prefix is None:
    output_prefix = os.path.basename(os.path.splitext(args.input_file)[0])
else:
    output_prefix = args.prefix

# Load the FST; --fst and --no-wgd are forwarded to the reader.
logger.info("Reading FST.")
fst = medicc.io.read_fst(user_fst=args.fst, no_wgd=args.no_wgd)

# A tree passed via --tree is imported here; otherwise input_tree stays None.
input_tree = None
if args.user_tree is not None:
    logger.info("Importing user tree.")
    input_tree = medicc.io.import_tree(tree_file=args.user_tree,
                                       normal_name=normal_name)

# Load data: collect the reader options first, then parse the input file.
logger.info("Reading and parsing input data.")
parse_options = dict(
    filename=args.input_file,
    normal_name=normal_name,
    input_type=args.input_type.strip(),
    separator=args.input_chr_separator.strip(),
    allele_columns=allele_columns,
    total_copy_numbers=args.total_copy_numbers,
    maxcn=args.maxcn,
)
input_df = medicc.io.read_and_parse_input_data(**parse_options)

# Optional segment-length filter (--filter-segment-length). The raw string
# from the command line is passed to the filter unchanged; it is converted
# to an integer only for the log message.
segment_length_cutoff = args.filter_segment_length
if segment_length_cutoff is not None:
    old_size = len(input_df)
    input_df = medicc.io.filter_by_segment_length(input_df, segment_length_cutoff)
    logger.info("Removed input segments smaller than {}bp. Old size: {} -> new size:{}".format(
        int(float(segment_length_cutoff)), old_size, len(input_df)))

# Drop the samples listed in --exclude-samples (comma separated IDs).
if args.exclude_samples is not None:
    exclude_samples = np.array([x.strip() for x in args.exclude_samples.split(',')])
    logger.info("Excluding samples {%s}." % ', '.join(exclude_samples))
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0; for the
    # 1-D 'sample_id' index level both return the same boolean mask.
    sample_ids = input_df.index.get_level_values('sample_id')
    input_df = input_df.loc[~np.isin(sample_ids, exclude_samples), :]

if args.n_cores is not None:
    logger.info("Running on {} cores.".format(args.n_cores))

# Run main method. medicc.main returns six values, unpacked below; whether
# ancestors are reconstructed is the inverse of --topology-only.
logger.info("Running main reconstruction routine.")
(sample_labels,
 pairwise_distances,
 nj_tree,
 final_tree,
 output_df,
 events_df) = medicc.main(input_df,
                          fst,
                          normal_name,
                          input_tree=input_tree,
                          ancestral_reconstruction=not args.topology_only,
                          chr_separator=args.fst_chr_separator.strip(),
                          prune_weight=args.prune_weight,
                          allele_columns=allele_columns,
                          n_cores=args.n_cores)


# Write the events table and its overlap with the chromosome / region BED
# files. Skipped entirely when the main routine returned no events table.
if events_df is not None:
    events_path = os.path.join(output_dir, output_prefix + "_copynumber_events_df.tsv")
    events_df.to_csv(events_path, sep='\t', index=True)

    logger.info("Overlapping copy-number events with chromosomes and known oncogenes.")
    overlaps = medicc.core.overlap_events(
        events_df=events_df,
        alleles=allele_columns,
        normal_name=normal_name,
        chromosome_bed=args.chromosomes_bed,
        regions_bed=args.regions_bed)
    overlaps_path = os.path.join(output_dir, output_prefix + "_events_overlap.tsv")
    overlaps.to_csv(overlaps_path, sep='\t', index=True)

# Output pairwise distance matrices
logger.info("Writing pairwise distance matrices.")
distances_base = os.path.join(output_dir, output_prefix + "_pairwise_distances")
medicc.io.write_pairwise_distances(sample_labels, pairwise_distances, distances_base)

# Write trees. Only the final tree is written; the call for `nj_tree` is
# kept below, commented out, for reference.
logger.info("Writing trees.")
# medicc.io.write_tree_files(tree=nj_tree, out_name=os.path.join(
#     output_dir, output_prefix + "_nj_tree"))
final_tree_base = os.path.join(output_dir, output_prefix + "_final_tree")
medicc.io.write_tree_files(tree=final_tree, out_name=final_tree_base)

# Write output table
profiles_path = os.path.join(output_dir, output_prefix + "_final_cn_profiles.tsv")
output_df.to_csv(profiles_path, sep='\t')


# Summarize: build the per-patient summary, log its tree length and store it
# as a header-less, tab-separated table.
logger.info("Writing patient summary.")
summary = medicc.summarize_patient(final_tree,
                                   pairwise_distances.values,
                                   sample_labels,
                                   normal_name,
                                   events_df)
logger.info("Final tree length %d", summary.tree_length)
summary_path = os.path.join(output_dir, output_prefix + "_summary.tsv")
summary.to_csv(summary_path, index=True, header=False, sep='\t')

# Bootstrap (only when --bootstrap-nr is given). `support_tree` stays None
# otherwise, which the plotting step uses to fall back to the final tree.
support_tree = None
if args.bootstrap_nr is not None:
    logger.info("Performing {} bootstrap runs (method: {})".format(args.bootstrap_nr, 
                                                                   args.bootstrap_method))
    bootstrap_trees_df, support_tree = medicc.bootstrap.run_bootstrap(
        input_df,
        final_tree,
        N_bootstrap=args.bootstrap_nr,
        method=args.bootstrap_method,
        normal_name=normal_name,
        n_cores=args.n_cores)

    logger.info('Writing bootstrap output')
    bootstrap_pickle_path = os.path.join(output_dir, output_prefix + "_bootstrap_trees_df.pickle")
    with open(bootstrap_pickle_path, 'wb') as f:
        pickle.dump(bootstrap_trees_df, f)

    # The tree files are written with plotting switched off; the support
    # tree figure is drawn separately below so branch support can be shown.
    support_tree_base = os.path.join(output_dir, output_prefix + "_support_tree")
    medicc.io.write_tree_files(tree=support_tree, out_name=support_tree_base,
                               plot_tree=False, draw_ascii=False)
    # NOTE(review): this PDF is produced even when --no-plot is set; confirm
    # whether that is intended.
    fig = medicc.plot.plot_tree(support_tree,
                                title='support tree',
                                show_branch_lengths=True,
                                show_branch_support=True)
    fig.savefig(os.path.join(output_dir, output_prefix + '_support_tree.pdf'), bbox_inches='tight')

# Plot CN tracks unless --no-plot was given. When a bootstrap support tree
# exists it is drawn (with branch support) instead of the final tree.
if not args.no_plot:
    logger.info("Plotting CN profiles.")
    has_support_tree = support_tree is not None
    tree_to_plot = support_tree if has_support_tree else final_tree
    p = medicc.plot.plot_cn_profiles(output_df,
                                     input_tree=tree_to_plot,
                                     title=output_prefix,
                                     normal_name=normal_name,
                                     allele_columns=allele_columns,
                                     show_branch_support=has_support_tree,
                                     label_func=None)
    p.savefig(os.path.join(output_dir, output_prefix + '_cn_profiles.pdf'))
