#!/usr/bin/env python
# Default worker count; used as the default value of the -t/--threads option.
threads = 8


import sys
import subprocess as sp
import os
from collections import defaultdict as dd
import argparse

# local files
import isocirc.error_corr as ec
import isocirc.cons_call as cc
import isocirc.cons_align as ca
import isocirc.bam_classify as bc
import isocirc.hcBSJ_fullIso as hf
import isocirc.utils as ut
#import isocirc_plot as pp
import isocirc.basic_stats as bs
from isocirc.__init__ import __program__
from isocirc.__init__ import __version__


# Minimum accepted versions of the external tools; check_depen() compares the
# output of `<tool> --version` against these dotted strings.
bedtools_v = '2.27.0'
minimap2_v = '2.11'

def _version_numbers(text):
	"""Return the leading dotted-integer components of `text` as a list of ints.

	Parsing stops at the first component that is not purely numeric; leading
	digits of that component are still kept, so '2.11-r797' gives [2, 11] and
	'2.29.2-17-gabc' gives [2, 29, 2]. An empty list means nothing numeric
	was found at the start of `text`.
	"""
	numbers = []
	for part in text.strip().split('.'):
		digits = ''
		for ch in part:
			if ch not in '0123456789':
				break
			digits += ch
		if not digits:
			break
		numbers.append(int(digits))
		if len(digits) != len(part):
			# a non-numeric suffix (build tag, git hash, ...) ends the version
			break
	return numbers


def satisfied_version(res, tool, version):
	"""Tell whether the version reported by `tool` is at least `version`.

	:param res: text printed by `<tool> --version`. For 'bedtools' the
		version is the second whitespace-separated token with an optional
		leading 'v' (e.g. 'bedtools v2.27.0'); for 'minimap2' it is the start
		of the text (e.g. '2.11-r797').
	:param tool: 'bedtools' or 'minimap2'; any other name yields False.
	:param version: minimum required version as a dotted string.
	:return: True when the reported version is >= `version`, else False.
		False is also returned when no version number can be extracted.
	"""
	if tool == 'bedtools':
		fields = res.split()
		if len(fields) < 2:
			return False
		found_str = fields[1]
		if found_str.startswith('v'):
			found_str = found_str[1:]
	elif tool == 'minimap2':
		found_str = res
	else:
		return False

	found = _version_numbers(found_str)
	required = _version_numbers(version)
	if not found:
		return False

	# Pad the shorter list with zeros so that e.g. 2.17.1 can be compared with
	# 2.11 instead of being rejected just because the component counts differ.
	width = max(len(found), len(required))
	found.extend([0] * (width - len(found)))
	required.extend([0] * (width - len(required)))
	return found >= required

def check_depen(bedtools='bedtools', minimap2='minimap2'):
	"""Verify that bedtools and minimap2 can be run and are recent enough.

	Each tool is invoked with `--version`; its output is compared with the
	module-level minimum (`bedtools_v`, `minimap2_v`) through
	satisfied_version(). Problems are reported via ut.fatal_format_time().

	:param bedtools: executable name or path of bedtools.
	:param minimap2: executable name or path of minimap2.
	"""
	names = ['bedtools', 'minimap2']
	urls = ['https://bedtools.readthedocs.io/', 'https://github.com/lh3/minimap2']
	ut.err_format_time('check_dependencies', 'Checking dependencies ...')
	for t, n, v, u in zip([bedtools, minimap2], names, [bedtools_v, minimap2_v], urls):
		try:
			r = sp.check_output([t, '--version'])
		except (OSError, sp.CalledProcessError):
			# OSError: executable missing or not runnable; CalledProcessError:
			# non-zero exit status. Anything else (e.g. Ctrl-C) must propagate
			# rather than being reported as a missing tool.
			# NOTE(review): assumes ut.fatal_format_time() terminates the
			# process; otherwise `r` is unbound below -- confirm.
			ut.fatal_format_time('check_dependencies', '{} ({}, >= v{}) is not installed.'.format(n, u, v))
		# check_output() returns bytes on Python 3; convert to text first
		if not isinstance(r, str): r = r.decode('utf-8')
		if not satisfied_version(r, n, v):
			ut.fatal_format_time('check_dependencies', '{} ({}) is too old, >= v{} is needed.'.format(n, u, v))
	ut.err_format_time('check_dependencies', 'Checking dependencies done!')



def not_fasta(in_fn):
	"""Report whether `in_fn` needs converting to FASTA before use.

	A name ending in '.gz' is reported as True without opening the file.
	Otherwise the first character of each line read decides: '@' gives True,
	'>' gives False, and anything else is passed to ut.fatal_format_time().
	A file with no lines falls through and yields None.
	"""
	if in_fn.endswith('.gz'):
		return True
	with open(in_fn) as reads_fp:
		for record_line in reads_fp:
			first_char = record_line[:1]
			if first_char == '@':
				return True
			if first_char == '>':
				return False
			ut.fatal_format_time('main', 'Input data is not FASTA/FASTQ file.')


def isocirc_core(args):
	"""Run every pipeline stage in sequence for one set of parsed options.

	Stages, in order: optional hybrid error correction (only when
	`args.short_read` is set), tandem-repeat consensus calling with TRF,
	mapping of the consensus to the genome, classification of the alignments
	into a high and a low BAM file, and the final hf.hcBSJ_fullIso() step.
	All intermediate and final files are written below `args.out_dir`, which
	is created when it does not exist yet.

	:param args: option namespace as returned by parser_argv().
	"""
	out_dir = args.out_dir
	if not os.path.exists(out_dir):
		os.makedirs(out_dir)

	# 0. error correction with short-read (skipped without short reads)
	corr_long = args.in_long_read
	if args.short_read:
		ut.err_format_time('Error-correction', 'Hybrid error correction using {} ...'.format(args.short_read))
		corr_long = out_dir + '/long_corrected.fa'
		ec.error_corr(args.in_long_read, args.short_read, corr_long, args.lordec, args.kmer, args.solid, args.threads)
		ut.err_format_time('Error-correction', 'Hybrid error correction done!')

	cons_fa = out_dir + '/cons.fa'
	cons_info = out_dir + '/cons.info'

	# 1. generate cons with TRF
	ut.err_format_time('Tandem-Repeats-Finder', 'Finding tandem repeats with TRF ...')
	if not_fasta(corr_long):
		# convert the reads with `fxtools qa` and continue with the result
		fasta_fn = out_dir + '/long_reads.fa'
		ut.exec_cmd(sys.stderr, 'fxtools', '{} qa {} > {} 2> /dev/null'.format(cc.fxtools, corr_long, fasta_fn))
		corr_long = fasta_fn
	trf_out = out_dir + '/trf.out'
	# positional arguments shared by the serial and the parallel TRF runner
	trf_args = (
		corr_long, cons_fa, cons_info, trf_out, args.trf,
		args.match, args.mismatch, args.indel, args.match_frac, args.indel_frac,
		args.min_score, args.max_period, args.min_len, args.min_copy, args.min_frac,
	)
	if args.threads > 1:
		cc.run_trf_parall(*(trf_args + (args.threads,)))
	else:
		cc.run_trf(*trf_args)
	ut.err_format_time('Tandem-Repeats-Finder', 'Finding tandem repeats with TRF done!')

	# 2. extract consensus and align it to genome
	ut.err_format_time('Mapping', 'Mapping consensus sequence to genome ...')
	cons_all_sam = cons_fa + '.sam'
	ca.align_cons(cons_fa, args.ref, cons_all_sam, args.minimap2, args.threads)
	ut.err_format_time('Mapping', 'Mapping consensus sequence to genome done!')

	# 3. classify alignment
	ut.err_format_time('Classifying', 'Classifying consensus alignment ...')
	high_bam = out_dir + '/high.bam'
	low_bam = out_dir + '/low.bam'
	bc.bam_classify(
		cons_all_sam, high_bam, low_bam,
		args.high_max_ratio, args.high_min_ratio, args.high_iden_ratio,
		args.high_repeat_ratio, args.low_repeat_ratio)
	ut.err_format_time('Classifying', 'Classifying consensus alignment done!')

	# 4. evaluate alignment with annotations
	# repeat annotation files by name; unknown names map to ''
	itst_bed_dict = dd(lambda: '')
	itst_bed_dict['Alu'] = args.Alu
	itst_bed_dict['allRepeat'] = args.all_repeat

	# Use a '.len' file next to the reads when present, otherwise one in
	# out_dir, generating the latter with `fxtools lp` if it is missing too.
	long_len_fn = corr_long + '.len'
	if not os.path.exists(long_len_fn):
		long_len_fn = out_dir + '/' + os.path.basename(corr_long) + '.len'
		if not os.path.exists(long_len_fn):
			ut.exec_cmd(sys.stderr, 'fxtools', '{} lp {} > {} 2> /dev/null'.format(cc.fxtools, corr_long, long_len_fn))

	isoform_out = '{}/{}.out'.format(out_dir, __program__)
	bed_out = '{}/{}.bed'.format(out_dir, __program__)
	stats_out = '{}/{}_stats.out'.format(out_dir, __program__)
	hf.hcBSJ_fullIso(
		high_bam, low_bam, long_len_fn, cons_info, cons_fa,
		args.ref, args.gene_anno, args.circRNA_anno, itst_bed_dict, args.bedtools,
		args.flank_len, args.cano_motif, args.bsj_xid, args.key_bsj_xid,
		args.min_circ_dis, args.rescue_low,
		args.fsj_xid, args.key_fsj_xid,
		isoform_out, bed_out, stats_out)


def parser_argv():
	"""Define the command-line interface and parse the process arguments.

	Five positional arguments (long reads, reference, gene annotation,
	circRNA annotation, output directory) are followed by option groups whose
	defaults are taken from the pipeline modules. Help output shows defaults
	because argparse.ArgumentDefaultsHelpFormatter is used.

	:return: the argparse namespace produced by parse_args().
	"""
	cli = argparse.ArgumentParser(
		description="{}: computational pipeline to identify high-confidence BSJs and full-length circRNA isoforms from isoCirc long-read data".format(__program__),
		formatter_class=argparse.ArgumentDefaultsHelpFormatter)

	# positional arguments, in command-line order
	cli.add_argument("in_long_read", type=str, metavar='long.fa/fq', help='Long read sequencing data generated with isoCirc.')
	cli.add_argument("ref", type=str, metavar='ref.fa', help='Reference genome sequence file.')
	cli.add_argument('gene_anno', type=str, metavar='anno.gtf', help='Gene annotation file in GTF format.')
	cli.add_argument('circRNA_anno', type=str, metavar='circRNA.bed/gtf', help='circRNA database annotation file in BED or GTF format. Use \',\' to separate multiple circRNA annotation files.')
	cli.add_argument('out_dir', type=str, metavar='out_dir', help='Output directory for final result and temporary files.')

	cli.add_argument('-v', '--version', action='version', version=__program__ + ' ' + __version__)

	# threads and external tool locations
	general_grp = cli.add_argument_group('General options')
	general_grp.add_argument('-t', '--threads', type=int, default=threads, help='Number of thread to use.')
	general_grp.add_argument('--bedtools', default=hf.bedtools, help='Path to bedtools.')
	general_grp.add_argument('--minimap2', default=ca.minimap2, help='Path to minimap2.')

	# short-read based error correction
	corr_grp = cli.add_argument_group('Hybrid error-correction with short-read data (LoRDEC)')
	corr_grp.add_argument("--short-read", type=str, metavar='short.fa/fq', default='', help="Short-read data for error correction. Use \',\' to connect multiple or paired-end short read data.")
	corr_grp.add_argument('--lordec', type=str, default=ec.lordec, help='Path to lordec-correct.')
	corr_grp.add_argument('--kmer', type=int, default=ec.kmer, help='k-mer size.')
	corr_grp.add_argument('--solid', type=int, default=ec.solid, help='Solid k-mer abundance threshold.')

	# TRF executable and its scoring parameters
	trf_grp = cli.add_argument_group('Consensus calling with Tandem Repeats Finder (TRF))')
	trf_grp.add_argument('--trf', type=str, default=cc.trf, help='Path to trf program.')
	trf_grp.add_argument('--match', type=int, default=cc.match, help='Match score.')
	trf_grp.add_argument('--mismatch', type=int, default=cc.mismatch, help='Mismatch penalty.')
	trf_grp.add_argument('--indel', type=int, default=cc.indel, help='Indel penalty.')
	trf_grp.add_argument('--match-frac', type=int, default=cc.match_frac, help='Match probability.')
	trf_grp.add_argument('--indel-frac', type=int, default=cc.indel_frac, help='Indel probability.')
	trf_grp.add_argument('--min-score', type=int, default=cc.min_score, help='Minimum alignment score to report.')
	trf_grp.add_argument('--max-period', type=int, default=cc.max_period, help='Maximum period size to report.')

	# consensus filtering and alignment classification thresholds
	map_grp = cli.add_argument_group('Filtering and mapping of consensus sequences (minimap2)')
	map_grp.add_argument('--min-len', type=int, default=cc.cons_min_len, help='Minimum consensus length to keep.')
	map_grp.add_argument('--min-copy', type=float, default=cc.cons_min_copy, help='Minimum copy number of consensus to keep.')
	map_grp.add_argument('--min-frac', type=float, default=cc.cons_min_frac, help='Minimum fraction of original long read to keep.')
	map_grp.add_argument('--high-max-ratio', type=float, default=bc.high_max_ratio, help='Maximum mappedLen / consLen ratio for high-quality alignment.')
	map_grp.add_argument('--high-min-ratio', type=float, default=bc.high_min_ratio, help='Minimum mappedLen /consLen ratio for high-quality alignment.')
	map_grp.add_argument('--high-iden-ratio', type=float, default=bc.high_iden_ratio, help='Minimum identicalBases/ consLen ratio for high-quality alignment.')
	map_grp.add_argument('--high-repeat-ratio', type=float, default=bc.high_repeat_ratio, help='Maximum mappedLen / consLen ratio for high-quality self-tandem consensus.')
	map_grp.add_argument('--low-repeat-ratio', type=float, default=bc.low_repeat_ratio, help='Minimum mappedLen / consLen ratio for low-quality self-tandem alignment.')

	# back-splice / forward-splice junction criteria
	circ_grp = cli.add_argument_group('Identifying high-confidence BSJs and full-length circRNAs')
	circ_grp.add_argument('--cano-motif', type=str, choices=['GT/AG', 'all'], default='GT/AG', help='Canonical back-splice motif (GT/AG or all three motifs: GT/AG, GC/AG, AT/AC).')
	circ_grp.add_argument('--bsj-xid', type=int, default=1, help='Maximum allowed mis/ins/del for 20-bp exonic sequence flanking the BSJ (10-bp each side).')
	circ_grp.add_argument('--key-bsj-xid', type=int, default=0, help='Maximum allowed mis/ins/del for 4-bp exonic sequence flanking the BSJ (2-bp each side).')
	circ_grp.add_argument('--min-circ-dis', type=int, default=150, help='Minimum distance between the genomic coordinates of the two back-splice sites.')
	circ_grp.add_argument('--rescue-low', action='store_true', default=False, help='Use high mapping quality reads to rescue low mapping quality reads.')
	circ_grp.add_argument('--fsj-xid', type=int, default=1, help='Maximum allowed mis/ins/del for 20-bp exonic sequence flanking the FSJ (10-bp each side).')
	circ_grp.add_argument('--key-fsj-xid', type=int, default=0, help='Maximum allowed mis/ins/del for 4-bp exonic sequence flanking the FSJ (2-bp each side).')

	# repeat annotations
	other_grp = cli.add_argument_group('Other options')
	other_grp.add_argument('--Alu', type=str, default='', help='Alu repetitive element annotation in BED format. ')
	other_grp.add_argument('--flank-len', type=int, default=hf.flank_len, help='Length of upstream and downstream flanking sequence to search for Alu.')
	other_grp.add_argument('--all-repeat', type=str, default='', help='All repetitive element annotation in BED format.')

	return cli.parse_args()


def main():
	"""Parse the command line, check the external tools, then run the pipeline."""
	cli_args = parser_argv()
	check_depen(bedtools=cli_args.bedtools, minimap2=cli_args.minimap2)
	isocirc_core(cli_args)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
	main()
