#! /usr/bin/env python
from __future__ import print_function

import os
import sys
import os
from time import time
import itertools
from itertools import islice, chain
from struct import *
import shutil

import argparse
import errno
import math

# import pickle
import dill as pickle 
import gffutils
import pysam

import multiprocessing as mp
from multiprocessing import Pool
from collections import defaultdict


# from modules import create_splice_graph as splice_graph
# from modules import graph_chainer 

from modules import create_augmented_gene as augmented_gene 
from modules import mem_wrapper 
from modules import colinear_solver 
from modules import help_functions
from modules import classify_read_with_mams
from modules import classify_alignment2
from modules import sam_output
from modules import align
from modules import prefilter_genomic_reads


def load_reference(args):
    """
        Read all reference sequences from the file at args.ref using help_functions.readfq.

        returns:
        1. dict mapping accession -> sequence
        2. dict mapping accession -> sequence length
    """
    # The context manager closes the reference file as soon as it has been parsed
    # instead of leaving the handle open for the rest of the run.
    with open(args.ref, "r") as ref_file:
        refs = {acc: seq for acc, (seq, _) in help_functions.readfq(ref_file)}
    refs_lengths = {acc: len(seq) for acc, seq in refs.items()}
    return refs, refs_lengths

def prep_splicing(args, refs_lengths):
    """
        Build (or reuse) the gffutils annotation database 'database.db' in args.outfolder,
        derive the gene/segment data structures from it with
        augmented_gene.create_graph_from_exon_parts and pickle each structure into
        args.outfolder.

        args: parsed command line arguments. Reads outfolder, gtf, disable_infer, flank_size,
              small_exon_threshold and min_segm.
        refs_lengths: reference lengths, forwarded to create_graph_from_exon_parts.
    """
    database = os.path.join(args.outfolder,'database.db')
    if os.path.isfile(database):
        print("Database found in directory using this one.")
        print("If you want to recreate the database, please remove the file: {0}".format(database))
        print()
    else:
        # The annotation path is handed to gffutils as given. gffutils.example_filename()
        # must not be used for this: it resolves names against the example data shipped
        # with gffutils, so a relative path would not be looked up in the working directory.
        create_options = dict(dbfn=database, force=True, keep_order=True,
                              merge_strategy='merge', sort_attribute_values=True)
        if args.disable_infer:
            # Skip gffutils' inference of gene and transcript features.
            create_options.update(disable_infer_genes=True, disable_infer_transcripts=True)
        gffutils.create_db(args.gtf, **create_options)

    # Always work on a handle opened from the database file on disk.
    db = gffutils.FeatureDB(database, keep_order=True)

    segment_to_ref, parts_to_segments, splices_to_transcripts, \
    transcripts_to_splices, all_splice_pairs_annotations, \
    all_splice_sites_annotations, segment_id_to_choordinates, \
    segment_to_gene, gene_to_small_segments, flank_choordinates, \
    max_intron_chr, exon_choordinates_to_id, chr_to_id, id_to_chr, tiling_structures = augmented_gene.create_graph_from_exon_parts(db, args.flank_size, args.small_exon_threshold, args.min_segm, refs_lengths)

    tiling_segment_id_to_choordinates, tiling_segment_to_gene, tiling_segment_to_ref, tiling_parts_to_segments, tiling_gene_to_small_segments = tiling_structures # unpacking tiling structures

    # dump to pickle here! One file per structure, written in this order.
    structures_to_dump = [
        (segment_to_ref, 'segment_to_ref.pickle'),
        (splices_to_transcripts, 'splices_to_transcripts.pickle'),
        (transcripts_to_splices, 'transcripts_to_splices.pickle'),
        (parts_to_segments, 'parts_to_segments.pickle'),
        (all_splice_pairs_annotations, 'all_splice_pairs_annotations.pickle'),
        (all_splice_sites_annotations, 'all_splice_sites_annotations.pickle'),
        (segment_id_to_choordinates, 'segment_id_to_choordinates.pickle'),
        (segment_to_gene, 'segment_to_gene.pickle'),
        (gene_to_small_segments, 'gene_to_small_segments.pickle'),
        (flank_choordinates, 'flank_choordinates.pickle'),
        (max_intron_chr, 'max_intron_chr.pickle'),
        (exon_choordinates_to_id, 'exon_choordinates_to_id.pickle'),
        (chr_to_id, 'chr_to_id.pickle'),
        (id_to_chr, 'id_to_chr.pickle'),
        (tiling_segment_id_to_choordinates, 'tiling_segment_id_to_choordinates.pickle'),
        (tiling_segment_to_gene, 'tiling_segment_to_gene.pickle'),
        (tiling_segment_to_ref, 'tiling_segment_to_ref.pickle'),
        (tiling_parts_to_segments, 'tiling_parts_to_segments.pickle'),
        (tiling_gene_to_small_segments, 'tiling_gene_to_small_segments.pickle'),
    ]
    for structure, file_name in structures_to_dump:
        help_functions.pickle_dump(args, structure, file_name)


def prep_seqs(args, refs, refs_lengths):
    """
        Extract the reference sequences of all indexed regions and pickle them into
        args.outfolder.

        Loads the coordinate structures pickled by prep_splicing, re-keys the reference
        sequences by the annotation's chromosome ids, warns about sequences present in only
        one of annotation/fasta, masks abundant k-mers in part and flank sequences and
        dumps part, segment, flank, exon and tiling segment sequences.

        args: parsed command line arguments. Reads outfolder, min_mem and mask_threshold.
        refs: dict mapping accession -> sequence.
        refs_lengths: dict mapping accession -> sequence length (dumped unchanged).
    """
    def load(file_name):
        return help_functions.pickle_load( os.path.join(args.outfolder, file_name) )

    # Only the structures that are used below are loaded.
    parts_to_segments = load('parts_to_segments.pickle')
    segment_id_to_choordinates = load('segment_id_to_choordinates.pickle')
    flank_choordinates = load('flank_choordinates.pickle')
    exon_choordinates_to_id = load('exon_choordinates_to_id.pickle')
    chr_to_id = load('chr_to_id.pickle')
    tiling_segment_id_to_choordinates = load('tiling_segment_id_to_choordinates.pickle')

    print( "Number of ref seqs in gff:", len(parts_to_segments.keys()))

    # Re-key the reference sequences by chromosome id; remember fasta entries without annotation.
    refs_id = {}
    not_in_annot = set()
    for acc, seq in refs.items():
        if acc in chr_to_id:
            refs_id[chr_to_id[acc]] = seq
        else:
            not_in_annot.add(acc)

    refs_id_lengths = { acc_id : len(seq) for acc_id, seq in refs_id.items()}
    help_functions.pickle_dump(args, refs_id_lengths, 'refs_id_lengths.pickle')
    help_functions.pickle_dump(args, refs_lengths, 'refs_lengths.pickle')

    print( "Number of ref seqs in fasta:", len(refs.keys()))

    not_in_ref = set(chr_to_id.keys()) - set(refs.keys())
    if not_in_ref:
        print("Warning: Detected {0} sequences that are in annotation but not in reference fasta. Using only sequences present in fasta. The following annotations cannot be detected:\n".format(len(not_in_ref)))
        for s in not_in_ref:
            print(s)

    if not_in_annot:
        print("Warning: Detected {0} sequences in reference fasta that are not in annotation. Will not be able to infer annotations from alignments on the following reference sequences:\n".format(len(not_in_annot)))
        for s in not_in_annot:
            print(s, "with length:{0}".format(len(refs[s])))

    ref_part_sequences = augmented_gene.get_sequences_from_choordinates(parts_to_segments, refs_id)
    ref_flank_sequences = augmented_gene.get_sequences_from_choordinates(flank_choordinates, refs_id)

    # Return values are ignored, as in the calling convention used so far;
    # presumably the masking happens in place -- verify against augmented_gene.
    augmented_gene.mask_abundant_kmers(ref_part_sequences, args.min_mem, args.mask_threshold)
    augmented_gene.mask_abundant_kmers(ref_flank_sequences, args.min_mem, args.mask_threshold)

    # The flank sequences are merged into the part sequences, and also dumped separately below.
    ref_part_sequences = help_functions.update_nested(ref_part_sequences, ref_flank_sequences)
    ref_segment_sequences = augmented_gene.get_sequences_from_choordinates(segment_id_to_choordinates, refs_id)
    ref_exon_sequences = augmented_gene.get_sequences_from_choordinates(exon_choordinates_to_id, refs_id)
    help_functions.pickle_dump(args, segment_id_to_choordinates, 'segment_id_to_choordinates.pickle')
    help_functions.pickle_dump(args, ref_part_sequences, 'ref_part_sequences.pickle')
    help_functions.pickle_dump(args, ref_segment_sequences, 'ref_segment_sequences.pickle')
    help_functions.pickle_dump(args, ref_flank_sequences, 'ref_flank_sequences.pickle')
    help_functions.pickle_dump(args, ref_exon_sequences, 'ref_exon_sequences.pickle')

    tiling_ref_segment_sequences = augmented_gene.get_sequences_from_choordinates(tiling_segment_id_to_choordinates, refs_id)
    help_functions.pickle_dump(args, tiling_ref_segment_sequences, 'tiling_ref_segment_sequences.pickle')



def batch(dictionary, size, batch_type):
    """
        Split dictionary (accession -> sequence) into consecutive sub-dictionaries that
        each hold roughly `size` nucleotides.

        A batch is closed as soon as its accumulated sequence length reaches `size`; the
        sequence that crosses the threshold stays in that batch. Whatever is left after
        the last closed batch forms a final, smaller batch.

        dictionary: dict mapping accession -> sequence.
        size: number of nucleotides after which a batch is closed.
        batch_type: not used; kept so that existing callers keep working.

        returns: list of dicts, in the iteration order of dictionary. Empty input gives [].
    """
    batches = []
    sub_dict = {}
    curr_nt_count = 0
    for acc, seq in dictionary.items():
        sub_dict[acc] = seq
        curr_nt_count += len(seq)
        if curr_nt_count >= size:
            batches.append(sub_dict)
            sub_dict = {}
            curr_nt_count = 0

    # Flush on the leftover dict itself rather than on the nucleotide count: trailing
    # zero-length sequences leave the count at 0 and would otherwise be dropped, and
    # no division by `size` is needed.
    if sub_dict:
        batches.append(sub_dict)

    return batches

    # for i, (acc, seq) in enumerate(dictionary.items()):
    #     if i > 0 and i % size == 0:
    #         batches.append(sub_dict)
    #         sub_dict = {}
    #         sub_dict[acc] = seq
    #     else:
    #         sub_dict[acc] = seq

    # if i/size != 0:
    #     sub_dict[acc] = seq
    #     batches.append(sub_dict)
    
    # return batches

def check_alignment_fit(aln_ultra, aln_other):
    """
        Compare the fit of two alignments of the same read.

        The score of an alignment is sum(lengths of CIGAR op 7) minus
        sum(lengths of CIGAR ops 1, 2 and 8), i.e. identities - (INS + DEL + SUBS).
        A record without CIGAR information (cigartuples is None) gets score 0.

        returns:
        1. the difference in scoring, positive if uLTRA is better
        2. the classification obtained by uLTRA (its 'XC' tag)
    """
    diffs = {1,2,8} # cigar IDs for INS, DEL, SUBS

    def fit_score(aln):
        matches = 0
        mismatches = 0
        # cigartuples is None when the record carries no alignment, e.g. unmapped reads
        for type_, length in (aln.cigartuples or ()):
            if type_ == 7:
                matches += length
            elif type_ in diffs:
                mismatches += length
        return matches - mismatches

    return fit_score(aln_ultra) - fit_score(aln_other), aln_ultra.get_tag('XC')


def output_final_alignments( ultra_alignments_path, path_indexed_aligned, path_unindexed_aligned):
    """
        Merge uLTRA's alignments with the alignments of the alternative aligner and
        overwrite the uLTRA alignment file with the merged result.

        For every record in the uLTRA file whose read also has a non-secondary alternative
        alignment:
        - a record with flag == 4 is replaced by the alternative alignment,
        - a non-secondary record is replaced if the alternative alignment scores better
          according to check_alignment_fit.
        All records of path_unindexed_aligned (reads never attempted with uLTRA) are
        appended. Summary statistics and a table of score differences are printed.

        ultra_alignments_path: path (str or bytes) to the SAM file produced by uLTRA.
        path_indexed_aligned: SAM file with alternative alignments of the reads given to uLTRA.
        path_unindexed_aligned: SAM file with alignments of the reads not given to uLTRA.
    """
    # pysam reports AlignmentFile.filename as bytes; accept both str and bytes paths.
    ultra_alignments_path = os.fsdecode(ultra_alignments_path)
    tmp_merged_path = ultra_alignments_path + 'tmp'

    replaced_unaligned_cntr = 0
    not_attempted_cntr = 0
    scoring_dict = defaultdict(int)

    with pysam.AlignmentFile(path_indexed_aligned, "r", check_sq=False) as alt_alignments_file:
        # read in all reads from alternative aligner that was also mapped by uLTRA
        alt_alignments = { read.query_name : read for read in alt_alignments_file.fetch(until_eof=True) if not read.is_secondary }

        with pysam.AlignmentFile( ultra_alignments_path, "r" ) as alignment_infile, \
             pysam.AlignmentFile( tmp_merged_path, "w", template= alignment_infile) as tmp_merged_outfile:
            for read in alignment_infile.fetch():
                if read.query_name in alt_alignments:
                    if read.flag == 4:
                        read = alt_alignments[ read.query_name ] # replace unmapped uLTRA read with alternative alignment
                        replaced_unaligned_cntr += 1
                    elif not read.is_secondary:
                        ultra_scoring_diff, classification = check_alignment_fit(read,  alt_alignments[ read.query_name ])
                        scoring_dict[ultra_scoring_diff] += 1
                        if ultra_scoring_diff < 0:
                            read = alt_alignments[ read.query_name ] # replace uLTRA read with alternative alignment if better fit

                tmp_merged_outfile.write(read)

            # add all reads that we did not attempt to align with uLTRA
            # these reads had a primary alignment to unindexed regions by other pre-processing aligner (minimap2 as of now)
            with pysam.AlignmentFile(path_unindexed_aligned, "r") as unindexed:
                for read in unindexed.fetch():
                    tmp_merged_outfile.write(read)
                    if not read.is_secondary:
                        not_attempted_cntr += 1

    print("{0} reads were not attempted to be aligned with ultra, instead alternative aligner was used.".format(not_attempted_cntr))
    print("{0} reads with primary alignments were replaced with alternative aligner because they were unaligned with uLTRA.".format(replaced_unaligned_cntr))
    print("{0} primary alignments had best fit with uLTRA.".format(sum([v for k,v in scoring_dict.items() if k > 0])))
    print("{0} primary alignments had equal fit.".format(scoring_dict[0]))
    print("{0} primary alignments had best fit with alternative aligner.".format(sum([v for k,v in scoring_dict.items() if k < 0])))

    # Histogram of score differences over half-open bins [lower, upper).
    bin_boundaries = [-2**32, -100,-50,-20,-10,-5,-4,-3,-2,-1, 0, 1, 2, 3, 4, 5, 10, 20, 50, 100, 2**32]
    bins = list(zip(bin_boundaries[:-1], bin_boundaries[1:]))
    bin_counts = [0]*len(bins)
    for score_diff, nr_alignments in scoring_dict.items():
        for i, (lower, upper) in enumerate(bins):
            if lower <= score_diff < upper:
                bin_counts[i] += nr_alignments
                break

    print("Table of score-difference between alignment methods (negative number: alternative aligner better fit, positive number is uLTRA better fit)")
    print("Score is calculated as sum(identities) - sum(ins, del, subs)")
    print("Format: Score difference: Number of primary alignments ")
    for (lower, upper), count in zip(bins, bin_counts):
        print("[{0} - {1}): {2}".format(lower, upper, count))

    # All handles are closed at this point, so the merged file can replace the original.
    shutil.move(tmp_merged_path, ultra_alignments_path)


def _iter_read_sequences(reads_path):
    """Yield (accession, sequence) for every record of reads_path; the file is closed once exhausted."""
    with open(reads_path, 'r') as reads_file:
        for acc, (seq, _) in help_functions.readfq(reads_file):
            yield acc, seq


def _write_reads_for_mem_finding(records, out_path, reduce_read_ployA, reverse_complement=False):
    """Write (accession, sequence) records to out_path as '>acc' / sequence lines after
    help_functions.remove_read_polyA_ends, optionally reverse complemented. Returns out_path."""
    with open(out_path, 'w') as out_file:
        for acc, seq in records:
            seq = help_functions.remove_read_polyA_ends(seq, reduce_read_ployA, 5)
            if reverse_complement:
                seq = help_functions.reverse_complement(seq)
            out_file.write('>{0}\n{1}\n'.format(acc, seq))
    return out_path


def align_reads(args):
    """
        Align the reads in args.reads against the index stored in args.outfolder.

        Steps:
        1. unless args.disable_mm2 is set, prefilter_genomic_reads.main is run and
           args.reads is replaced by the path it returns,
        2. the indexed reference parts are written to 'refs_sequences.fa',
        3. MEMs are found with mem_wrapper.find_mems_slamem for forward and reverse
           complemented reads (one job pair per read batch when args.nr_cores > 1),
        4. reads are aligned with align.align_single or align.align_parallel; the
           per-batch SAM files of the parallel run are merged into 'reads.sam',
        5. unless args.disable_mm2 is set, the result is merged with the alternative
           aligner's output by output_final_alignments,
        6. per-classification counts and the summed alignment coverage are printed.

        args: parsed command line arguments. args.reads, args.reads_tmp and args.reads_rc
              may be set as a side effect.
    """
    if args.nr_cores > 1:
        mp.set_start_method('spawn')
        print(mp.get_context())
        print("Environment set:", mp.get_context())
        print("Using {0} cores.".format(args.nr_cores))

    ref_part_sequences = help_functions.pickle_load( os.path.join(args.outfolder, 'ref_part_sequences.pickle') )
    refs_id_lengths = help_functions.pickle_load( os.path.join(args.outfolder, 'refs_id_lengths.pickle') )
    refs_lengths = help_functions.pickle_load( os.path.join(args.outfolder, 'refs_lengths.pickle') )

    if not args.disable_mm2:
        print("Filtering reads aligned to unindexed regions with minimap2 ")
        nr_reads_to_ignore, path_reads_to_align = prefilter_genomic_reads.main(ref_part_sequences, args.ref, args.reads, args.outfolder, args.nr_cores, args.genomic_frac, args.mm2_ksize)
        args.reads = path_reads_to_align
        print("Done filtering. Reads filtered:{0}".format(nr_reads_to_ignore))

    # One record per indexed reference part, named "<chr_id>^<start>^<stop>".
    ref_path = os.path.join(args.outfolder, "refs_sequences.fa")
    with open(ref_path, 'w') as refs_file:
        for sequence_id, seq in ref_part_sequences.items():
            chr_id, start, stop = unpack('LLL',sequence_id)
            refs_file.write(">{0}^{1}^{2}\n{3}\n".format(chr_id, start, stop, seq))

    del ref_part_sequences

    ######### FIND MEMS WITH MUMMER #############
    #############################################
    #############################################

    mummer_start = time()
    if args.nr_cores == 1:
        # Reads are streamed from disk here; they are only loaded into memory for the alignment step.
        print("Processing reads for MEM finding")
        args.reads_tmp = _write_reads_for_mem_finding(_iter_read_sequences(args.reads),
                                                      os.path.join(args.outfolder, 'reads_tmp.fq'),
                                                      args.reduce_read_ployA)
        print("Finished processing reads for MEM finding ")

        mummer_out_path =  os.path.join( args.outfolder, "mummer_mems_batch_-1.txt" )
        print("Running MEM finding forward")
        mem_wrapper.find_mems_slamem(args.outfolder, args.reads_tmp, ref_path, mummer_out_path, args.min_mem)
        print("Completed MEM finding forward")

        print("Processing reverse complement reads for MEM finding")
        args.reads_rc = _write_reads_for_mem_finding(_iter_read_sequences(args.reads),
                                                     os.path.join(args.outfolder, 'reads_rc.fq'),
                                                     args.reduce_read_ployA, reverse_complement=True)
        print("Finished processing reverse complement reads for MEM finding")

        mummer_out_path =  os.path.join(args.outfolder, "mummer_mems_batch_-1_rc.txt" )
        print("Running MEM finding reverse")
        mem_wrapper.find_mems_slamem(args.outfolder, args.reads_rc, ref_path, mummer_out_path, args.min_mem)
        print("Completed MEM finding reverse")

    else:
        reads = dict(_iter_read_sequences(args.reads))
        total_nt = sum([len(seq) for seq in reads.values() ])
        batch_size = int(total_nt/int(args.nr_cores) + 1)
        print("batch nt:", batch_size, "total_nt:", total_nt)
        read_batches = batch(reads, batch_size, 'nt')

        # Two MEM finding jobs per batch: forward reads and reverse complemented reads.
        batch_args = []
        for i, read_batch_dict in enumerate(read_batches):
            print(len(read_batch_dict))
            read_batch = _write_reads_for_mem_finding(read_batch_dict.items(),
                                                      os.path.join(args.outfolder, "reads_batch_{0}.fa".format(i)),
                                                      args.reduce_read_ployA)
            read_batch_rc = _write_reads_for_mem_finding(read_batch_dict.items(),
                                                         os.path.join(args.outfolder, "reads_batch_{0}_rc.fa".format(i) ),
                                                         args.reduce_read_ployA, reverse_complement=True)
            mummer_batch_out_path =  os.path.join( args.outfolder, "mummer_mems_batch_{0}.txt".format(i) )
            mummer_batch_out_path_rc =  os.path.join(args.outfolder, "mummer_mems_batch_{0}_rc.txt".format(i) )
            batch_args.append( (args.outfolder, read_batch, ref_path, mummer_batch_out_path, args.min_mem ) )
            batch_args.append( (args.outfolder, read_batch_rc, ref_path, mummer_batch_out_path_rc, args.min_mem ) )

        pool = Pool(processes=int(args.nr_cores))
        try:
            pool.starmap(mem_wrapper.find_mems_slamem, batch_args)
        finally:
            # release the worker processes even if a job failed
            pool.close()
            pool.join()

    print("Time for mummer to find mems:{0} seconds.".format(time()-mummer_start))
    #############################################
    #############################################
    #############################################


    print("Starting aligning reads.")

    if args.nr_cores == 1:
        reads = dict(_iter_read_sequences(args.reads))
        classifications, alignment_outfile_name = align.align_single(reads, refs_id_lengths, args, -1)
    else:
        # the read batches created for MEM finding are aligned in parallel,
        # each batch is written to a separate SAM-file which are merged afterwards
        aligning_start = time()
        print('Nr reads:', len(reads), "nr batches:", len(read_batches), [len(b) for b in read_batches])
        classifications, alignment_outfiles = align.align_parallel(read_batches, refs_id_lengths, args)

        print("Time to align reads:{0} seconds.".format(time()-aligning_start))

        # Combine samfiles produced from each batch
        combine_start = time()
        with pysam.AlignmentFile( os.path.join(args.outfolder, "reads.sam"), "w", reference_names=list(refs_lengths.keys()), reference_lengths=list(refs_lengths.values()) ) as alignment_outfile:
            for f in alignment_outfiles:
                with pysam.AlignmentFile(f, "r") as samfile:
                    for read in samfile.fetch():
                        alignment_outfile.write(read)

        alignment_outfile_name = alignment_outfile.filename
        print("Time to merge SAM-files:{0} seconds.".format(time() - combine_start))

    # need to merge genomic/unindexed alignments with the uLTRA-aligned alignments
    if not args.disable_mm2:
        path_indexed_aligned = os.path.join(args.outfolder, "indexed.sam")
        path_unindexed_aligned = os.path.join(args.outfolder, "unindexed.sam")
        output_final_alignments(alignment_outfile_name, path_indexed_aligned, path_unindexed_aligned)

    # classifications[read_acc][0] is used as the class label and [1] as the alignment coverage
    counts = defaultdict(int)
    alignment_coverage = 0
    for read_acc in reads:
        if read_acc in classifications:
            alignment_coverage += classifications[read_acc][1]
            counts[classifications[read_acc][0]] += 1
        else:
            counts['unaligned'] += 1

    print(counts)
    print("total alignment coverage:", alignment_coverage)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="uLTRA -- Align and classify long transcriptomic reads based on colinear chaining algorithms to gene regions")
    parser.add_argument('--version', action='version', version='%(prog)s 0.0.2.4')

    subparsers = parser.add_subparsers(help='Subcommands for eaither constructing a graph, or align reads')
    # parser.add_argument("-v", help='Different subcommands for eaither constructing a graph, or align reads')

    pipeline_parser = subparsers.add_parser('pipeline', help= "Perform all in one: prepare splicing database and reference sequences and align reads.")
    indexing_parser = subparsers.add_parser('index', help= "Construct data structures used for alignment.")
    align_reads_parser = subparsers.add_parser('align', help="Classify and align reads with colinear chaining to DAGs")

    pipeline_parser.add_argument('ref', type=str, help='Reference genome (fasta)')
    pipeline_parser.add_argument('gtf', type=str, help='Path to gtf or gtf file with gene models.')
    pipeline_parser.add_argument('reads', type=str, help='Path to fasta/fastq file with reads.')
    pipeline_parser.add_argument('outfolder', type=str, help='Path to fasta file with a nucleotide sequence (e.g., gene locus) to simulate isoforms from.')
    group = pipeline_parser.add_mutually_exclusive_group()
    group.add_argument('--ont', action="store_true", help='Set parameters suitable for ONT (Currently sets: --min_mem 15, --min_acc 0.6 --alignment_threshold 0.5).')
    group.add_argument('--isoseq', action="store_true", help='Set parameters suitable for IsoSeq (Currently sets: --min_mem 17, --min_acc 0.8 --alignment_threshold 0.5).')
    # group2 = pipeline_parser.add_mutually_exclusive_group()
    # group2.add_argument('--mummer', action="store_true", help='Use mummer to find mems. About 1.5-2x faster than slamem but consumes  >4x more memory (slamem is recommended for human and larger)')
    # group2.add_argument('--slamem', action="store_true", help='Use slaMEM to find mems. About 1.5-2x slower than mumer but consumes less than 25% of the memory compared to mummer')

    pipeline_parser.add_argument('--min_mem', type=int, default=17, help='Threshold for minimum mem size considered.')
    pipeline_parser.add_argument('--min_segm', type=int, default=25, help='Threshold for minimum segment size considered.')
    pipeline_parser.add_argument('--min_acc', type=float, default=0.5, help='Minimum accuracy of MAM to be considered in mam chaining.')
    pipeline_parser.add_argument('--flank_size', type=int, default=1000, help='Size of genomic regions surrounding genes.')
    pipeline_parser.add_argument('--max_intron', type=int, default=1200000, help='Set global maximum size between mems considered in chaining solution. This is otherwise inferred from GTF file per chromosome.')
    pipeline_parser.add_argument('--small_exon_threshold', type=int, default=200, help='Considered in MAM solution even if they cont contain MEMs.')
    pipeline_parser.add_argument('--reduce_read_ployA', type=int, default=8, help='Reduces polyA tails longer than X bases (default 10) in reads to 5bp before MEM finding. This helps MEM matching to spurios regions but does not affect final read alignments.')
    pipeline_parser.add_argument('--alignment_threshold', type=int, default=0.5, help='Lower threshold for considering an alignment. \
                                                                                        Counted as the difference between total match score and total mismatch penalty. \
                                                                                        If a read has 25%% errors (under edit distance scoring), the difference between \
                                                                                        matches and mismatches would be (very roughly) 0.75 - 0.25 = 0.5 with default alignment parameters \
                                                                                        match =2, subs=-2, gap open 3, gap ext=1. Default val (0.5) sets that a score\
                                                                                        higher than 2*0.5*read_length would be considered an alignment, otherwise unaligned.')
    pipeline_parser.add_argument('--t', dest = 'nr_cores', type=int, default=3, help='Number of cores.')
    pipeline_parser.add_argument('--non_covered_cutoff', type=int, default=15, help='Threshold for what is counted as varation/intron in alignment as opposed to deletion.')
    pipeline_parser.add_argument('--dropoff', type=float, default=0.95, help='Ignore alignment to hits with read coverage of this fraction less than the best hit.')
    pipeline_parser.add_argument('--max_loc', type=float, default=5, help='Limit read to be aligned to at most max_loc places (default 5).\
                                                                            This prevents time blowup for reads from highly repetitive regions (e.g. some genomic intra-priming reads)\
                                                                            but may limit all posible alignments to annotated gene families with many highly similar copies.')
    pipeline_parser.add_argument('--ignore_rc', action='store_true', help='Ignore to map to reverse complement.')
    pipeline_parser.add_argument('--disable_infer', action='store_true', help='Makes splice creation step much faster. Thes parameter can be set if gene and transcript name fields are provided in gtf file.')
    pipeline_parser.add_argument('--mask_threshold', type=int, default=200, help='Threshold for what is counted as varation/intron in alignment as opposed to deletion.')
    pipeline_parser.add_argument('--disable_mm2', action='store_true', help='Disables utilizing minimap2 to detect genomic primary alignments and to quality check uLTRAs primary alignments.\
                                                                                An alignment is classified as genomic if more than --genomic_frac (default 10%%) of its aligned length is outside\
                                                                                regions indexed by uLTRA. Note that uLTRA indexes flank regions such as 3 prime, 5 prime and (parts of) introns.')
    pipeline_parser.add_argument('--genomic_frac', type=float, default=0.1, help='If parameter prefilter_genomic is set, this is the threshild for fraction of aligned portion of read that is outside uLTRA indexed regions to be considered genomic (default 0.1).')

    pipeline_parser.set_defaults(which='pipeline')

    indexing_parser.add_argument('ref', type=str, help='Reference genome (fasta)')
    indexing_parser.add_argument('gtf', type=str, help='Path to gtf or gtf file with gene models.')
    indexing_parser.add_argument('outfolder', type=str, help='Path to output folder.')
    indexing_parser.add_argument('--min_segm', type=int, default=25, help='Threshold for minimum segment size considered.')
    indexing_parser.add_argument('--flank_size', type=int, default=1000, help='Size of genomic regions surrounding genes.')
    indexing_parser.add_argument('--small_exon_threshold', type=int, default=200, help='Considered in MAM solution even if they cont contain MEMs.')
    indexing_parser.add_argument('--disable_infer', action='store_true', help='Makes splice creation step much faster. Thes parameter can be set if gene and transcript name fields are provided in gtf file.')
    indexing_parser.add_argument('--min_mem', type=int, default=17, help='Threshold for minimum mem size considered.')
    indexing_parser.add_argument('--mask_threshold', type=int, default=200, help='Threshold for what is counted as varation/intron in alignment as opposed to deletion.')
    indexing_parser.set_defaults(which='index')


    # Command-line interface of the alignment sub-command.

    # Required positional arguments.
    align_reads_parser.add_argument('ref', type=str, help='Reference genome (fasta)')
    align_reads_parser.add_argument('reads', type=str, help='Path to file with the reads to align.')
    align_reads_parser.add_argument('outfolder', type=str, help='Path to output folder.')

    align_reads_parser.add_argument('--t', dest = 'nr_cores', type=int, default=3, help='Number of cores.')
    align_reads_parser.add_argument('--max_intron', type=int, default=1200000, help='Set global maximum size between mems considered in chaining solution. This is otherwise inferred from GTF file per chromosome.')
    # NOTE: the option name contains a typo ("ployA") but is part of the public
    # command-line interface, so it is kept as is.
    align_reads_parser.add_argument('--reduce_read_ployA', type=int, default=8, help='Reduces polyA tails longer than X bases (default 8) in reads to 5bp before MEM finding. This helps MEM matching to spurious regions but does not affect final read alignments.')
    # The threshold is a fraction (default 0.5), so it must be parsed as a float;
    # with type=int argparse rejects any fractional value given on the command line.
    align_reads_parser.add_argument('--alignment_threshold', type=float, default=0.5, help='Lower threshold for considering an alignment. \
                                                                                        Counted as the difference between total match score and total mismatch penalty. \
                                                                                        If a read has 25%% errors (under edit distance scoring), the difference between \
                                                                                        matches and mismatches would be (very roughly) 0.75 - 0.25 = 0.5 with default alignment parameters \
                                                                                        match =2, subs=-2, gap open 3, gap ext=1. Default val (0.5) sets that a score\
                                                                                        higher than 2*0.5*read_length would be considered an alignment, otherwise unaligned.')
    align_reads_parser.add_argument('--non_covered_cutoff', type=int, default=15, help='Threshold for what is counted as varation/intron in alignment as opposed to deletion.')
    align_reads_parser.add_argument('--dropoff', type=float, default=0.95, help='Ignore alignment to hits with read coverage of this fraction less than the best hit.')
    # max_loc is a count of locations, so it is parsed as an int; this also keeps
    # user-supplied values the same type as the (integer) default.
    align_reads_parser.add_argument('--max_loc', type=int, default=5, help='Limit read to be aligned to at most max_loc places (default 5).\
                                                                            This prevents time blowup for reads from highly repetitive regions (e.g. some genomic intra-priming reads)\
                                                                            but may limit all posible alignments to annotated gene families with many highly similar copies.')

    align_reads_parser.add_argument('--ignore_rc', action='store_true', help='Ignore to map to reverse complement.')
    align_reads_parser.add_argument('--min_mem', type=int, default=17, help='Threshold for minimum mem size considered.')
    align_reads_parser.add_argument('--min_acc', type=float, default=0.5, help='Minimum accuracy of MAM to be considered in mam chaining.')

    align_reads_parser.add_argument('--disable_mm2', action='store_true', help='Disables utilizing minimap2 to detect genomic primary alignments and to quality check uLTRAs primary alignments.\
                                                                                An alignment is classified as genomic if more than --genomic_frac (default 10%%) of its aligned length is outside\
                                                                                regions indexed by uLTRA. Note that uLTRA indexes flank regions such as 3 prime, 5 prime and (parts of) introns.')

    align_reads_parser.add_argument('--genomic_frac', type=float, default=0.1, help='If parameter prefilter_genomic is set, this is the threshild for fraction of aligned portion of read that is outside uLTRA indexed regions to be considered genomic (default 0.1).')

    # Technology presets are mutually exclusive. The values quoted in the help
    # text mirror what the preset handling after argument parsing assigns.
    group2 = align_reads_parser.add_mutually_exclusive_group()
    group2.add_argument('--ont', action="store_true", help='Set parameters suitable for ONT (Currently sets: --min_mem 17, --min_acc 0.6).')
    group2.add_argument('--isoseq', action="store_true", help='Set parameters suitable for IsoSeq (Currently sets: --min_mem 20, --min_acc 0.8).')


    # Tag the resulting namespace so the dispatch code can tell which
    # sub-command was selected.
    align_reads_parser.set_defaults(which='align_reads')

    # Invoked without any arguments: show the top-level help and stop.
    # This is checked before parsing so the help text is shown regardless of
    # whether argparse treats the missing sub-command as an error, and so the
    # code below never sees a namespace without 'outfolder'/'which'.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit()

    args = parser.parse_args()

    # Make sure the output folder exists before any sub-command writes to it.
    help_functions.mkdir_p(args.outfolder)

    # Apply technology presets for the sub-commands that align reads.
    if args.which in ('align_reads', 'pipeline'):
        # mm2_ksize is not a command-line option; it is attached to the
        # namespace here (presumably the minimap2 k-mer size -- confirm at the
        # point of use). The ONT preset uses 14, everything else 15.
        args.mm2_ksize = 14 if args.ont else 15

        if args.ont:
            # ONT preset; alignment_threshold is deliberately left untouched.
            args.min_mem, args.min_acc = 17, 0.6

        if args.isoseq:
            # IsoSeq preset; alignment_threshold is deliberately left untouched.
            args.min_mem, args.min_acc = 20, 0.8


    # Run the selected sub-command. 'pipeline' is 'index' followed by
    # 'align_reads', so the two stages are shared instead of duplicated.
    selected = args.which
    if selected not in ('index', 'align_reads', 'pipeline'):
        print('invalid call')
    else:
        if selected != 'align_reads':
            # Indexing stage: load the reference, build the splicing database,
            # then prepare the sequences.
            refs, refs_lengths = load_reference(args)
            prep_splicing(args, refs_lengths)
            prep_seqs(args, refs, refs_lengths)
        if selected != 'index':
            # Alignment stage.
            align_reads(args)
