#!/usr/bin/env python3

"""
NanoVar

This is the main executable file of the program NanoVar.

Copyright (C) 2021 Tham Cheng Yong

NanoVar is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

NanoVar is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NanoVar.  If not, see <https://www.gnu.org/licenses/>.
"""


__author__ = 'CY Tham'
import os
import sys
import time
import pysam
import logging
import nanovar
import random
import pickle
import threading
from datetime import datetime
from nanovar import __version__, input_parser, gzip_check, fastx_valid, bed_valid, check_exe, check_index
from progress.spinner import Spinner


def _start_progress(spin_switch, msg):
    """Announce the start of a pipeline stage.

    When spin_switch is true, a daemon thread animates a spinner labelled with
    msg; otherwise msg is printed once as a plain line.

    Returns:
        A (task, thread_spin) tuple holding the TaskProgress controller and the
        thread running it, or (None, None) when no spinner was started.
    """
    if not spin_switch:
        print(msg + ' -')
        return None, None
    task = TaskProgress()
    spinner = Spinner(msg + ' - ')
    thread_spin = threading.Thread(target=task.run, args=(spinner,), daemon=True)
    thread_spin.start()
    return task, thread_spin


def _stop_progress(task, thread_spin):
    """Stop a spinner started by _start_progress and terminate its output line.

    Does nothing when task is None (no spinner was running).
    """
    if task is None:
        return
    task.endspin()
    thread_spin.join()
    print('')


def main():
    """Run the NanoVar pipeline from the command line.

    Stages, in order: parse arguments, validate inputs (model, thresholds,
    executables, read/BAM file, reference, filter BED), set up the working
    directory and log file, align reads with minimap2 (skipped for BAM input),
    detect and cluster SV breakends, re-evaluate INS/INV reads with hs-blastn,
    run the TE analysis, optionally pickle intermediate objects, and finally
    write the VCF and report.

    Raises:
        Exception: on any failed validation (invalid data type without a custom
            model, --homo <= --hetero, missing CytoCAD or unsupported build in
            CNV mode, corrupted or unrecognised input file, gzipped reference,
            contig mismatches, missing filter BED).
    """
    # Parse arguments
    args = input_parser()
    file_path = args.input
    ref_path = args.ref
    ref_name = os.path.basename(ref_path).rsplit('.', 1)[0]
    wk_dir = args.dir
    data_type = args.data_type
    genome_filter = args.filter_bed
    cnv = args.cnv
    minlen = args.minlen
    splitpct = args.splitpct
    minalign = args.minalign
    mincov = args.mincov
    buff = args.buffer
    score_threshold = args.score
    homo_t = args.homo
    het_t = args.hetero
    threads = args.threads
    debug = args.debug
    quiet = args.quiet
    force = args.force
    model_path = args.model
    package_dir = os.path.dirname(nanovar.__file__)
    filter_bed_dir = os.path.join(package_dir, 'gaps')
    ref_dir = os.path.join(package_dir, 'ref')
    mm = args.mm
    st = args.st
    mdb = args.mdb
    wmk = args.wmk
    hsb = args.hsb
    pickle_bool = args.pickle
    archivefasta = args.archivefasta
    blastout = args.blastout

    # Messages produced before the log file exists. Calling logging.info() before
    # logging.basicConfig() would implicitly attach a default handler to the root
    # logger and turn the later basicConfig(filename=...) into a no-op, so these
    # notes are held back and emitted once logging is configured.
    deferred_info = []

    # Check model file
    if model_path is None:
        model_files = {
            'ont': 'ANN.E100B400L3N12-5D0.4-0.2SGDsee11_het_gup_v1.h5',
            'pacbio-clr': 'ANN.E100B400L3N12-5D0.4-0.2SGDsee11_het_clr_v1.h5',
            'pacbio-ccs': 'ANN.E100B400L3N12-5D0.4-0.2SGDsee11_het_ccs_v1.h5',
        }
        if data_type not in model_files:
            logging.critical("Error: Invalid data type given '%s'" % data_type)
            raise Exception("Error: Invalid data type given '%s'" % data_type)
        model_path = os.path.join(package_dir, 'model', model_files[data_type])
    elif data_type not in ['ont', 'pacbio-clr', 'pacbio-ccs']:
        deferred_info.append("Invalid data type given '%s', but irrelevant due to custom-built model use." % data_type)

    # Check homo_t > het_t
    if homo_t <= het_t:
        raise Exception("Error: --homo threshold %s is less than or equal to --hetero threshold %s" % (str(homo_t), str(het_t)))

    # Check for required executables
    mm = check_exe(mm, 'minimap2')
    st = check_exe(st, 'samtools')
    mdb = check_exe(mdb, 'makeblastdb')
    wmk = check_exe(wmk, 'windowmasker')
    hsb = check_exe(hsb, 'hs-blastn')

    # If additional CNV mode
    hg38_size_dict = {}
    if cnv:
        # Check CytoCAD
        try:
            import cytocad
        except ImportError:
            raise Exception("Error: cytocad module not found, please install CytoCAD - https://github.com/cytham/cytocad")
        data_dir = os.path.join(os.path.dirname(cytocad.__file__), 'data')
        hg38_size_path = os.path.join(data_dir, 'hg38_sizes_main.bed')
        with open(hg38_size_path) as f:
            for bed in f:
                line = bed.split('\t')
                hg38_size_dict[line[0]] = int(line[2])
        # Check tagore and rsvg
        _ = check_exe(None, 'tagore')
        _ = check_exe(None, 'rsvg-convert')
        # Check hg38 build input
        if cnv != 'hg38':
            raise Exception("Error: Reference genome build %s is not recognised. CytoCAD only supports build hg38." % cnv)

    # Observe verbosity
    if quiet:
        # Deliberately left open: stdout stays redirected for the rest of the process
        sys.stdout = open(os.devnull, 'w')

    # Assign threads
    spin_switch, threads_mm, threads_bt, threads_index = assign_threads(force, threads, quiet, ref_path, wk_dir, ref_name)

    # Print initiation message
    now = datetime.now()
    now_str = now.strftime("[%d/%m/%Y %H:%M:%S]")
    print(now_str, "- NanoVar started")
    task, thread_spin = _start_progress(spin_switch, 'Checking integrity of input files')

    # Setup working directory (also creates wk_dir itself when missing)
    os.makedirs(wk_dir, exist_ok=True)
    os.makedirs(os.path.join(wk_dir, 'fig'), exist_ok=True)

    # Set up logging
    log_file = os.path.join(wk_dir, 'NanoVar-{:%d%m%y-%H%M}.log'.format(datetime.now()))
    logging.basicConfig(filename=log_file, level=logging.DEBUG, format='[%(asctime)s] - %(levelname)s - %(message)s',
                        datefmt='%d/%m/%Y %H:%M:%S')
    logging.info('Initialize NanoVar log file')
    logging.info('Version: NanoVar-%s' % __version__)
    logging.info('Command: %s' % ' '.join(sys.argv))
    for note in deferred_info:
        logging.info(note)

    # Detect read or map file
    filename = os.path.basename(file_path)
    read_suffix = ['.fa', '.fq', '.fasta', '.fastq', '.fa.gzip', '.fq.gzip', '.fa.gz', '.fq.gz', '.fasta.gz', '.fastq.gz']
    bam_suffix = '.bam'
    contig_list = []
    if any(filename.lower().endswith(s) for s in read_suffix):
        input_name = os.path.basename(file_path).rsplit('.f', 1)[0]
        input_type = 'raw'
        # Test gzip compression and validates read file
        fastx_check = fastx_valid(file_path, "gz" if gzip_check(file_path) else "txt")
        if fastx_check[0] == "Fail":
            logging.critical("Error: Input FASTQ/FASTA file is corrupted around line %s +/- 4" % str(fastx_check[1]))
            raise Exception("Error: Input FASTQ/FASTA file is corrupted around line %s +/- 4" % str(fastx_check[1]))
        elif fastx_check[0] == "Pass":
            logging.debug("Input FASTQ/FASTA file passed")
    elif filename.lower().endswith(bam_suffix):
        save = pysam.set_verbosity(0)  # Suppress BAM index missing warning
        try:
            sam = pysam.AlignmentFile(file_path, "rb")
        finally:
            pysam.set_verbosity(save)  # Revert verbosity level even if opening fails
        try:
            # Explicit check instead of assert, which is stripped under "python -O"
            if not sam.is_bam:
                logging.critical("Error: Input BAM file is not a BAM file.")
                raise Exception("Error: Input BAM file is not a BAM file.")
            # Get BAM contigs from header
            header = sam.header.to_dict()
            for h in header['SQ']:
                contig_list.append(h['SN'])
        finally:
            sam.close()
        input_name = os.path.basename(file_path).rsplit('.bam', 1)[0]
        input_type = 'bam'
        fastx_check = []
    else:
        logging.critical("Error: Input file is not recognised, please ensure file suffix has '.fa' or '.fq' or '.bam'")
        raise Exception("Error: Input file is not recognised, please ensure file suffix has '.fa' or '.fq' or '.bam'")

    if gzip_check(ref_path):
        logging.critical("Error: Input reference file is gzipped, please unzip it")
        raise Exception("Error: Input reference file is gzipped, please unzip it")

    # Logging config info
    logging.info('Input file: %s' % file_path)
    logging.info('Read type: %s' % data_type)
    logging.info('Reference genome: %s' % ref_path)
    logging.info('Working directory: %s' % wk_dir)
    logging.info('Model: %s' % model_path)
    logging.info('Filter file: %s' % genome_filter)
    logging.info('Minimum number of reads for calling a breakend: %s' % str(mincov))
    logging.info('Minimum SV len: %s' % str(minlen))
    logging.info('Mapping percent for split-read: %s' % str(splitpct))
    logging.info('Length buffer for clustering: %s' % str(buff))
    logging.info('Score threshold: %s' % str(score_threshold))
    logging.info('Homozygous read ratio threshold: %s' % str(homo_t))
    logging.info('Heterozygous read ratio threshold: %s' % str(het_t))
    logging.info('Number of threads: %s\n' % str(threads))
    if input_type == 'raw':
        logging.info('Total number of reads in FASTQ/FASTA: %s\n' % str(fastx_check[1]))
    elif input_type == 'bam':
        logging.info('Total number of reads in FASTQ/FASTA: -\n')
    logging.info('NanoVar started')

    from Bio import SeqIO
    from collections import OrderedDict
    # Process reference genome: record every contig length and the total genome size
    contig_len_dict = OrderedDict()
    total_gsize = 0
    for seq_record in SeqIO.parse(ref_path, "fasta"):
        contig_len_dict[seq_record.id] = len(seq_record)
        total_gsize += len(seq_record)

    # Check that contig sizes match hg38 if CNV mode
    # NOTE(review): the lengths compared here come from the reference FASTA, although
    # the messages below say "BAM"; wording kept unchanged.
    if cnv:
        for chrom in hg38_size_dict:
            if chrom in contig_len_dict:
                if contig_len_dict[chrom] != hg38_size_dict[chrom]:
                    raise Exception("Error: Contig %s in BAM (%i) has different length in hg38 genome assembly (%i)" %
                                    (chrom, contig_len_dict[chrom], hg38_size_dict[chrom]))
            else:
                raise Exception("Error: hg38 contig %s is absent in BAM" % chrom)

    # Check BAM contigs in reference genome
    if input_type == 'bam':
        for c in contig_list:
            if c not in contig_len_dict:
                logging.critical("Error: Contig %s in BAM is absent in reference genome" % c)
                raise Exception("Error: Contig %s in BAM is absent in reference genome" % c)

    # Check contig id for invalid symbols
    contig_omit = checkcontignames(contig_len_dict)

    # Validate filter BED file
    if genome_filter is not None:
        if genome_filter == 'hg38':
            filter_path = os.path.join(filter_bed_dir, genome_filter + '_curated_filter_main.bed')
        elif genome_filter in ('hg19', 'mm10'):
            filter_path = os.path.join(filter_bed_dir, genome_filter + '_filter-B400.bed')
        else:
            filter_path = genome_filter
        if os.path.isfile(filter_path):
            if bed_valid(filter_path, contig_len_dict):
                logging.debug("Genome filter BED passed")
        else:
            logging.critical("Error: Genome filter BED %s is not found" % filter_path)
            raise Exception("Error: Genome filter BED %s is not found" % filter_path)
    else:
        filter_path = genome_filter

    # Update progress
    _stop_progress(task, thread_spin)
    if contig_omit:
        omit_msg = ('Warning: The following contig id(s) contains invalid symbols [ ~ : - ] %s. Reads mapping to these '
                    'contig(s) will be ignored.' % ', '.join(['>' + x for x in contig_omit]))
        print(omit_msg)
        logging.warning(omit_msg)
    task, thread_spin = _start_progress(spin_switch, 'Indexing genome and aligning reads')

    # Pre-indexing: build the index in the background while reads are aligned
    from nanovar.nv_align import make_index, align_mm, align_hsb
    indexing = None
    if threads_index == 1:
        indexing = threading.Thread(target=make_index, args=(force, ref_path, wk_dir, ref_name, mdb, wmk, hsb),
                                    daemon=True)
        indexing.start()

    # Aligning using minimap2
    if input_type == 'raw':
        logging.info('Read alignment using minimap2')
        mma = align_mm(ref_path, file_path, wk_dir, input_name, ref_name, threads_mm, mm, data_type, st)
        bam_path = mma[1]
    elif input_type == 'bam':
        logging.info('Input BAM file, skipping minimap2 alignment')
        mma = ['-', '']
        bam_path = file_path
    else:
        raise Exception("Error: Internal error, input_type: %s unknown" % input_type)

    # Update progress
    _stop_progress(task, thread_spin)
    task, thread_spin = _start_progress(spin_switch, 'Analyzing read alignments and detecting SVs')

    # Parse and detect SVs
    logging.getLogger("matplotlib").setLevel(logging.WARNING)
    from nanovar.nv_characterize import VariantDetect
    logging.info('Parsing BAM and detecting SVs')
    run = VariantDetect(wk_dir, bam_path, splitpct, minalign, filter_path, minlen, buff, model_path,
                        total_gsize, contig_len_dict, score_threshold, file_path, input_name, ref_path, ref_name, mma[0],
                        mincov, homo_t, het_t, debug, contig_omit, cnv)
    run.bam_parse_detect()
    run.coverage_stats()

    logging.info('Total number of mapped reads: %s\n' % str(run.maps))

    # Update progress
    _stop_progress(task, thread_spin)
    task, thread_spin = _start_progress(spin_switch, 'Clustering SV breakends')

    # SV breakend clustering and extracting INS and INV SVs
    run.cluster_extract()

    # Update progress
    _stop_progress(task, thread_spin)
    task, thread_spin = _start_progress(spin_switch, 'Re-evaluating SVs with BLAST and inferencing')

    # Wait for indexing or make index
    if threads_index == 1:
        indexing.join()
    elif threads_index == 0:
        make_index(force, ref_path, wk_dir, ref_name, mdb, wmk, hsb)

    # Run hsblastn on INS and INV reads, unless a precomputed BLAST output was supplied
    if not blastout:
        hsba = align_hsb(ref_path, wk_dir, ref_name, threads_bt, hsb, debug)
    else:
        hsba = ['blast-command', blastout]
    sub_run = VariantDetect(wk_dir, hsba[1], splitpct, minalign, filter_path, minlen, buff, model_path,
                            total_gsize, contig_len_dict, score_threshold, file_path, input_name, ref_path, ref_name, hsba[0],
                            mincov, homo_t, het_t, debug, contig_omit, cnv)

    # Parsing INS and INV SVs and clustering
    sub_run.rlendict = run.rlendict
    sub_run.seed = run.seed  # Try to prevent repeated read-indexes
    sub_run.parse_detect_hsb()
    logging.info('Parsing BAM and detecting INV and INS SVs')
    run.cluster_nn(add_out=sub_run.total_out)

    # Update progress
    _stop_progress(task, thread_spin)
    task, thread_spin = _start_progress(spin_switch, 'Generating VCF files and report')

    # Analyze TE
    from nanovar.nv_te_analyzer import te_analyzer
    pysam.faidx(os.path.join(wk_dir, 'temp1.fa'))
    index2te = te_analyzer(wk_dir, run.out_nn, run.total_out, sub_run.total_out, score_threshold, ref_dir, hsb, mdb, threads, debug)

    # Write parse2 and cluster intermediate files
    run.write2file(add_out=sub_run.total_out)

    # Pickle objects
    if pickle_bool:
        saveobjects = (wk_dir, bam_path, splitpct, minalign, filter_path, minlen, buff, model_path,
                       total_gsize, contig_len_dict, score_threshold, file_path, input_name, ref_path, ref_name, mma[0],
                       mincov, homo_t, het_t, debug, contig_omit, cnv, run.rlendict, run.total_subdata, run.total_out,
                       run.out_nn, sub_run.total_out)
        with open(os.path.join(wk_dir, "nv_run.pkl"), "wb") as f:
            pickle.dump(saveobjects, f)

    # Delete temporary fasta file and its index
    if not archivefasta:
        for temp_name in ('temp1.fa', 'temp1.fa.fai'):
            os.remove(os.path.join(wk_dir, temp_name))

    # Generating VCF and HTML report
    run.vcf_report(index2te)
    logging.info('NanoVar ended')
    now = datetime.now()
    if spin_switch:
        # Final stage: the closing timestamp supplies its own leading newline,
        # so the spinner is stopped here without printing an empty line.
        task.endspin()
        thread_spin.join()
        time.sleep(1)
        now_str = now.strftime("\n[%d/%m/%Y %H:%M:%S]")
    else:
        now_str = now.strftime("[%d/%m/%Y %H:%M:%S]")
    print(now_str, "- NanoVar ended")


# Check contig name
def checkcontignames(contig_len_dict):
    """Collect contigs whose id contains one of the symbols '~', ':' or '-'.

    Args:
        contig_len_dict: mapping of contig id to contig length.

    Returns:
        dict mapping each offending contig id to ``[0, length]`` with the
        length converted to int, in the iteration order of the input.
    """
    invalid_symbols = ("~", ":", "-")
    return {
        name: [0, int(length)]
        for name, length in contig_len_dict.items()
        if any(symbol in name for symbol in invalid_symbols)
    }


# Progress message
class TaskProgress:
    """Drive a spinner animation until told to stop.

    ``run`` loops for as long as ``prog_spin`` is true; ``endspin`` clears the
    flag so the loop exits at its next check.
    """

    def __init__(self):
        # Loop flag polled by run() and cleared by endspin()
        self.prog_spin = True

    def endspin(self):
        """Ask run() to stop at its next flag check."""
        self.prog_spin = False

    def run(self, spin):
        """Advance ``spin`` one frame at a time, pausing between frames."""
        while True:
            if not self.prog_spin:
                break
            spin.next()
            sleep()


# Randomize sleep timing
def sleep():
    """Pause for 0.05 seconds, randomly varied by up to 10% in either direction."""
    base_delay = 0.05
    jitter = base_delay * random.uniform(-0.1, 0.1)
    time.sleep(base_delay + jitter)


# Assign thread usage according to number of threads
def assign_threads(force, threads, quiet, ref_path, wk_dir, ref_name):
    """Split the thread budget between alignment, BLAST, indexing and the spinner.

    Args:
        force: when true, existing index files are not looked up and the run is
            planned as if no index were present.
        threads: total number of threads available.
        quiet: when true, no spinner is shown.
        ref_path, wk_dir, ref_name: forwarded to check_index() to find out
            whether an index already exists (only when force is false).

    Returns:
        Tuple (spin_switch, threads_mm, threads_bt, threads_index):
        spin_switch is True when a spinner should run; threads_mm and
        threads_bt are thread counts (threads_bt never exceeds 53 when more
        than two threads are given); threads_index is 1 when indexing should
        run in its own thread and 0 otherwise.
    """
    bt_cap = 53

    # Without --force the existing index decides whether an indexing thread is needed
    index_present = False if force else check_index(ref_path, wk_dir, ref_name)

    if threads == 1:
        return False, 1, 1, 0

    if threads == 2:
        if not index_present:
            # One thread aligns while the other builds the index
            return False, 1, 2, 1
        if quiet:
            return False, 2, 2, 0
        # The spinner takes the second thread
        return True, 1, 1, 0

    # Any other thread count: reserve one thread for the spinner and/or the indexer
    spin_switch = not quiet
    spinner_cost = 1 if spin_switch else 0
    threads_index = 0 if index_present else 1
    threads_mm = threads - spinner_cost - threads_index
    threads_bt = min(threads - spinner_cost, bt_cap)
    return spin_switch, threads_mm, threads_bt, threads_index


# Run the pipeline only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()
