# -*- coding: utf-8 -*-
"""
amplimap's ``Snakefile`` is used by the ``amplimap`` command-line executable
to determine the shell commands and library functions to execute to generate a given file.
Each "rule" in this file either executes a shell command (eg. running ``bwa``),
executes a small set of Python commands or calls a function from the amplimap Python package.

When running this manually, set ``config['general']['amplimap_parent_dir']`` to specify the
directory to load amplimap from; otherwise the amplimap package installed system-wide will be used.
"""

import os
import shutil
import sys

#Prefer the amplimap copy named in config['general']['amplimap_parent_dir'] by putting that
#directory at the front of sys.path. If the setting is missing we only print a notice and
#let Python pick up whichever amplimap package is importable by default.
try:
    _amplimap_parent_dir = config['general']['amplimap_parent_dir']
except KeyError:
    sys.stderr.write('Failed to add exact amplimap directory to Python path, using system version of amplimap.\n')
else:
    sys.path.insert(0, _amplimap_parent_dir)

#package title and version (imported only after sys.path has been adjusted above)
from amplimap.version import __title__, __version__

#list of rules that don't run on the cluster:
#everything named in this "localrules" declaration is executed by the main Snakemake process
#instead of being submitted as a cluster job
localrules: all, \
    start_analysis, check_test_variant, \
    copy_probes, copy_snps, copy_targets_bed, convert_targets_csv,
    link_tagged_bais_1, link_tagged_bais_2, link_tagged_bams, link_unmapped_bams, \
    convert_mipgen_probes, convert_mipgen_probes_txt, convert_heatseq_probes, \
    link_reads, merge_targets, stats_samples_agg, stats_reads_agg, stats_alignment_agg, coverage_agg, stats_replacements_agg, \
    variants_summary, variants_summary_excel

#find samples
#the first input directory that exists decides where sample names come from:
#bams_in, mapped_bams_in, unmapped_bams_in (one BAM per sample) or, as the fallback, reads_in (fastq files)
if os.path.isdir('bams_in'):
    SAMPLES, = glob_wildcards("bams_in/{sample_full}.bam")
    #BAM input is always treated as a single lane
    config['general']['lanes_actual'] = 1
elif os.path.isdir('mapped_bams_in'):
    SAMPLES, = glob_wildcards("mapped_bams_in/{sample_full}.bam")
    config['general']['lanes_actual'] = 1
elif os.path.isdir('unmapped_bams_in'):
    SAMPLES, = glob_wildcards("unmapped_bams_in/{sample_full}.bam")
    config['general']['lanes_actual'] = 1
else:
    #check that samples for R1 and R2 agree
    #note that the second pattern has to look for R2 files: globbing R1 twice would compare
    #the R1 list with itself and the assertion below could never fail
    samples_with_lane1, = glob_wildcards("reads_in/{sample_with_lane}_R1_001.fastq.gz")
    samples_with_lane2, = glob_wildcards("reads_in/{sample_with_lane}_R2_001.fastq.gz")
    samples_with_lane1.sort()
    samples_with_lane2.sort()
    assert samples_with_lane1 == samples_with_lane2, 'Did not find R1 and R2 fastq files for all samples!'

    #now get list of samples and lanes
    SAMPLES_FULL, LANES = glob_wildcards("reads_in/{sample_full}_L{lane}_R1_001.fastq.gz")

    #make unique (each sample shows up once per lane in SAMPLES_FULL)
    SAMPLES = list(set(SAMPLES_FULL))

    #auto-detect number of input lanes: a configured value of 0 means "use the highest lane number found"
    if config['general']['lanes'] == 0:
        config['general']['lanes_actual'] = max((int(l) for l in LANES)) if len(LANES) > 0 else 1
    else:
        config['general']['lanes_actual'] = config['general']['lanes']

#sort and make sure we don't use "Undetermined" sample
SAMPLES.sort()
SAMPLES = [s for s in SAMPLES if s != 'Undetermined_S0']
assert len(SAMPLES) > 0, '\n\nNo samples found. Please check input filenames.\n\n'

#versions: starts out with the amplimap version only (key '_amplimap'),
#this dict is what start_analysis writes to versions.yaml
VERSIONS = {'_amplimap': str(__version__)}

def load_module_command(name, config):
    """Return the shell snippet that loads the environment module for a tool.

    Looks up ``config['modules'][name]``. If the ``modules`` section or the
    entry for ``name`` is missing, or the entry is empty, an empty string is
    returned so that the result can always be prepended to a shell command.
    Otherwise the result is ``'module load <entry>;'``.
    """
    module_configured = (
        'modules' in config
        and name in config['modules']
        and len(config['modules'][name]) > 0
    )
    if not module_configured:
        return ''
    return 'module load {};'.format(config['modules'][name])

def get_path(name, my_config = None, default = None):
    """Look up a genome-specific setting from the ``paths`` section of the config.

    The genome is selected through ``general: genome_name``; the value returned
    is ``paths[genome_name][name]``.

    Args:
        name: Key to look up for the configured genome.
        my_config: Config dict to use. Falls back to the global ``config`` if None.
        default: Value to return if the key is missing, None or empty.
            If this is None a missing value is treated as an error.

    Raises:
        Exception: If there is no ``paths`` entry for the configured genome, or
            if the value is missing/empty and no default was given.
    """
    cfg = config if my_config is None else my_config
    genome_name = cfg['general']['genome_name']

    if genome_name not in cfg['paths']:
        raise Exception('Could not find list of paths for genome_name: "{}".'.format(genome_name))

    genome_paths = cfg['paths'][genome_name]
    value = genome_paths[name] if name in genome_paths else None

    #a usable value has to be present, not None and not empty
    if value is not None and len(value) > 0:
        return value

    if default is not None:
        return default

    raise Exception('\nCould not find path for "%s" for genome "%s"! Please specify this in your configuration file under:\n\npaths:\n  %s:\n    %s: /my/path\n' % (
        name, genome_name, genome_name, name))

def get_annovar_name():
    """Return the genome name to use for annovar.

    This is the ``annovar_name`` entry configured for the current genome
    (looked up through get_path, even though it is not actually a path).
    If no such entry is set, ``general: genome_name`` is used instead.
    """
    #empty default so that a missing entry does not raise an error
    configured_name = get_path('annovar_name', default = '')
    return config['general']['genome_name'] if len(configured_name) == 0 else configured_name

def get_alignment_index_filename(real_path):
    """Return the alignment index to use for the configured aligner.

    Args:
        real_path: If False, return the index name as it is handed to the
            aligner. If True, return the name of an actual file belonging to
            that index, which can be used as a Snakemake input.

    The index comes from the genome's configured paths (the ``fasta`` entry
    for the naive aligner, otherwise the entry named after the aligner). If
    ``general: mask_bed`` is set, the index under ``masked_reference/`` is
    used instead.

    Raises:
        Exception: If the configured aligner is not one of
            naive, bwa, bowtie2 or star.
    """
    aligner = config['align']['aligner']
    index_path = get_path('fasta' if aligner == 'naive' else aligner)

    #masked index name
    mask_bed = config['general']['mask_bed']
    if mask_bed and len(mask_bed) > 0:
        if aligner == 'star':
            index_path = 'masked_reference/star/'
        elif aligner in ('naive', 'bwa', 'bowtie2'):
            index_path = 'masked_reference/genome.fa'
        else:
            raise Exception('unexpected aligner')

    if not real_path:
        return index_path

    #add extension to index name to get actual filename
    if aligner == 'star':
        return os.path.join(index_path, 'Genome')

    index_suffixes = {'naive': '.fai', 'bwa': '.bwt', 'bowtie2': '.1.bt2'}
    if aligner not in index_suffixes:
        raise Exception('unexpected aligner')
    return index_path + index_suffixes[aligner]

##### ALL-RULE (DEFAULT = FIRST RULE)
#Default target: always requests the aggregated per-sample stats table.
#The read stats and alignment stats tables are only requested if use_raw_reads is not set
#(the empty list adds no input).
rule all:
    input:
        "analysis/reads_parsed/stats_samples.csv",
        "analysis/reads_parsed/stats_reads.csv" if not config['general']['use_raw_reads'] else [],
        "analysis/stats_alignment/stats_alignment.csv" if not config['general']['use_raw_reads'] else [],

##### CHECKING
#Record what this analysis is run with, as three YAML files: the VERSIONS dict, the SAMPLES list
#and the result of amplimap.reader.get_file_hashes().
#Several other rules list this rule's output as an input, which forces it to run before them.
rule start_analysis:
    output:
        #force this to be 'analysis' right now
        versions = "analysis/versions.yaml",
        samples = "analysis/samples.yaml",
        file_hashes = "analysis/file_hashes.yaml",
    run:
        import yaml

        #default_flow_style=False makes yaml write block style (one entry per line)
        with open(output['versions'], 'w') as f:
            yaml.dump(VERSIONS, f, default_flow_style=False)

        with open(output['samples'], 'w') as f:
            yaml.dump(SAMPLES, f, default_flow_style=False)

        import amplimap.reader
        with open(output['file_hashes'], 'w') as f:
            yaml.dump(amplimap.reader.get_file_hashes(), f, default_flow_style=False)
    
##### target rules
#Target rule: everything from rule all plus the full coverage table for the aligned BAMs.
#The coverage table for the UMI-deduplicated BAMs is only requested if ignore_umis is not set.
rule coverages:
    input:
        rules.all.input,
        "analysis/bams/coverages/coverage_full.csv",
        "analysis/bams_umi_dedup/coverages/coverage_full.csv" if not config['general']['ignore_umis'] else []

#Target rule: pileup tables only, without the inputs of rule all.
#Which tables are requested depends on the files present in the working directory when the
#Snakefile is parsed: target pileups/coverage need targets.bed or targets.csv, SNP pileups need snps.txt.
rule pileups_only:
    input:
        rules.start_analysis.output, #enforce packages (and thus also config) check here
        "analysis/pileups/pileups_long.csv" if os.path.isfile("targets.bed") or os.path.isfile("targets.csv") else [],
        "analysis/pileups/target_coverage_long.csv" if os.path.isfile("targets.bed") or os.path.isfile("targets.csv") else [],
        "analysis/pileups_snps/target_snps_pileups_long.csv" if os.path.isfile("snps.txt") else []

#Target rule: the inputs of rule all combined with the inputs of pileups_only.
rule pileups:
    input:
        rules.all.input,
        rules.pileups_only.input,

#separate rule to only run SNP pileups
#(unlike pileups_only this requests the SNP pileup table unconditionally, without checking for snps.txt)
rule pileups_snps:
    input:
        rules.all.input,
        "analysis/pileups_snps/target_snps_pileups_long.csv",

#Target rule: everything from rule all plus the full and filtered variant summary tables in variants_raw.
rule variants:
    input:
        rules.all.input,
        "analysis/variants_raw/variants_summary.csv",
        "analysis/variants_raw/variants_summary_filtered.csv",

#Target rule: everything from rule all plus the full and filtered variant summary tables in variants_umi.
rule variants_umi:
    input:
        rules.all.input,
        "analysis/variants_umi/variants_summary.csv",
        "analysis/variants_umi/variants_summary_filtered.csv"
        
#Target rule: everything from rule all plus the full and filtered variant summary tables in variants_low_frequency.
rule variants_low_frequency:
    input:
        rules.all.input,
        "analysis/variants_low_frequency/variants_summary.csv",
        "analysis/variants_low_frequency/variants_summary_filtered.csv"

#Target rule: everything from rule all plus one aligned BAM file and BAM index for every sample in SAMPLES.
rule bams:
    input:
        rules.all.input,
        expand("analysis/bams/{sample_full}.bam", sample_full = SAMPLES),
        expand("analysis/bams/{sample_full}.bam.bai", sample_full = SAMPLES)

#Target rule: everything from rule all plus one *_CONSENSUS.bam file and index for every sample in SAMPLES.
rule consensus_bams:
    input:
        rules.all.input,
        expand("analysis/bams_consensus/{sample_full}_CONSENSUS.bam", sample_full = SAMPLES),
        expand("analysis/bams_consensus/{sample_full}_CONSENSUS.bam.bai", sample_full = SAMPLES)

#Target rule: everything from rule all plus one BAM file and index in bams_umi_dedup for every sample in SAMPLES.
rule dedup_bams:
    input:
        rules.all.input,
        expand("analysis/bams_umi_dedup/{sample_full}.bam", sample_full = SAMPLES),
        expand("analysis/bams_umi_dedup/{sample_full}.bam.bai", sample_full = SAMPLES)

#Wildcard target for test runs: requesting test__{search}_{replace}_{percentage}/test_variants.done
#requires the replacement stats and the filtered raw variant summary inside that test directory.
#The .done marker is an empty file created through touch() once the inputs exist.
rule test_variants:
    input:
        "test__{search}_{replace}_{percentage}/stats_replacements/stats_replacements.csv",
        "test__{search}_{replace}_{percentage}/variants_raw/variants_summary_filtered.csv"
    output:
        touch("test__{search}_{replace}_{percentage}/test_variants.done")

#Wildcard target for test runs: requesting test__{search}_{replace}_{percentage}/test_pileups.done
#requires the replacement stats and the long pileup table inside that test directory.
#The .done marker is an empty file created through touch() once the inputs exist.
rule test_pileups:
    input:
        "test__{search}_{replace}_{percentage}/stats_replacements/stats_replacements.csv",
        "test__{search}_{replace}_{percentage}/pileups/pileups_long.csv"
    output:
        touch("test__{search}_{replace}_{percentage}/test_pileups.done")


##### TOOL CHECKING AND VERSIONS
#Determine the version of the external tool named by the {tool} wildcard and write it to
#{analysis_dir}/versions/{tool}.txt.
#Raises an Exception if the tool is not known, if its command fails, or (for tools that have to be
#called in a way that returns an error code) if no version number is found in the command output.
rule tool_version:
    output:
        "{analysis_dir}/versions/{tool}.txt"
    run:
        #import explicitly instead of relying on re happening to be available in the Snakefile namespace
        import re

        print('Checking version for {} and writing output to {}'.format(
            wildcards['tool'],
            output[0]
        ))

        #mutect2 is run through the gatk executable, so it uses gatk's module
        #the version command further down is still chosen based on the original wildcard
        if wildcards['tool'] == 'mutect2':
            tool_name = 'gatk'
        else:
            tool_name = wildcards['tool']

        if tool_name == 'naive':
            with open(output[0], 'wt') as f:
                f.write('naive aligner from amplimap {}\n'.format(__version__))
        else:
            import subprocess

            commands = load_module_command(tool_name, config)

            #most tools write their version to stdout, which we redirect into the output file
            #bwa, picard and annovar are run without redirect and handled separately below
            if tool_name == 'bwa':
                commands += 'bwa'
            elif tool_name == 'bowtie2':
                commands += 'bowtie2 --version > "%s"' % (output[0])
            elif tool_name == 'star':
                commands += 'STAR --version > "%s"' % (output[0])

            elif tool_name == 'samtools':
                commands += 'samtools --version > "%s"' % (output[0])
            elif tool_name == 'bedtools':
                commands += 'bedtools --version > "%s"' % (output[0])
            elif tool_name == 'bcftools':
                commands += 'bcftools --version > "%s"' % (output[0])
            elif tool_name == 'picard':
                commands += '%s SamToFastq --version' % (config['tools']['picard']['prefix'])

            elif tool_name == 'annovar':
                commands += 'table_annovar.pl'

            elif tool_name == 'platypus':
                commands += 'platypus > "%s"' % (output[0])
            #need to check the wildcard here (and before the gatk branch), since tool_name has
            #already been changed to gatk above and would never be equal to mutect2
            elif wildcards['tool'] == 'mutect2':
                commands += 'gatk Mutect2 --version > "%s"' % (output[0])
            elif tool_name == 'gatk':
                commands += 'gatk HaplotypeCaller --version > "%s"' % (output[0])
            elif tool_name in config['tools'] and 'version_command' in config['tools'][tool_name]:
                commands += config['tools'][tool_name]['version_command']
            else:
                raise Exception('Unexpected tool: {}/{}. Please report this bug.'.format(
                    tool_name, wildcards['tool']
                ))

            #run command and then manually check result, so that we can handle cases like
            #bwa and we can provide stderr if we get a returncode != 0
            result = subprocess.run(
                commands,
                stdout = subprocess.PIPE,
                stderr = subprocess.PIPE,
                shell = True,
                check = False
            )

            #normally we just check for a 0 exit status
            if result.returncode != 0:
                #bwa, annovar and picard will return an error code if we run them like this, but it's the only way to get the version
                if tool_name in ['bwa', 'annovar', 'picard']:
                    #bwa has version number in stdERR
                    if tool_name == 'bwa' \
                    and re.search(r'^Version:\s+.*$', result.stderr.decode('utf-8'), re.MULTILINE):
                        with open(output[0], 'wb') as f:
                            f.write(result.stderr)
                    #annovar has version number in stdOUT
                    elif tool_name == 'annovar' \
                    and re.search(r'^\s*Version:\s+.*$', result.stdout.decode('utf-8'), re.MULTILINE):
                        with open(output[0], 'wb') as f:
                            f.write(result.stdout)
                    #picard has version number in stdERR, should return 1 and version number should be only result
                    elif tool_name == 'picard' \
                    and result.returncode == 1 \
                    and re.match(r'^[0-9.]+.*$', result.stderr.decode('utf-8')):
                        with open(output[0], 'wb') as f:
                            f.write(result.stderr)
                    #didn't get a version number, so tool may just not be available
                    else:
                        raise Exception(
                            '\nFailed to run tool {} and find version number (code={})\n{}\nPlease check that {} is available.\n'.format(
                                wildcards['tool'],
                                result.returncode,
                                result.stderr.decode('utf-8'),
                                wildcards['tool'],
                            )
                        )
                else:
                    raise Exception(
                        '\nFailed to run tool {} (code={})\n{}\nPlease check that {} is available.\n'.format(
                            wildcards['tool'],
                            result.returncode,
                            result.stderr.decode('utf-8'),
                            wildcards['tool'],
                        )
                    )
            else:
                #check that we actually got the right version
                if tool_name == 'bedtools':
                    pass  # TODO: should check that we have at least v2.24.0, since the interface for bedtools coverage changed
                elif tool_name == 'picard':
                    pass  # TODO: should we require v2+?
                else:
                    pass #we succeeded


##### TARGET CONVERSION
#Read targets.bed through amplimap.reader and write the targets that were read
#to {analysis_dir}/targets.bed.
rule copy_targets_bed:
    input:
        targets = "targets.bed"
    output:
        targets = "{analysis_dir}/targets.bed"
    run:
        bed_in = input['targets']
        bed_out = output['targets']
        print('Reading', bed_in)

        from amplimap.reader import read_targets, write_targets_bed
        parsed_targets = read_targets(bed_in, reference_type = 'genome', file_type = 'bed')
        print(bed_in, 'checked.')

        write_targets_bed(bed_out, parsed_targets)
        print(len(parsed_targets), 'written to', bed_out)

#Read targets.csv through amplimap.reader and write the targets that were read
#in BED format to {analysis_dir}/targets.bed.
rule convert_targets_csv:
    input:
        targets = "targets.csv"
    output:
        targets = "{analysis_dir}/targets.bed"
    run:
        csv_in = input['targets']
        bed_out = output['targets']
        print('Reading', csv_in)

        from amplimap.reader import read_targets, write_targets_bed
        parsed_targets = read_targets(csv_in, reference_type = 'genome', file_type = 'csv')
        print(csv_in, 'checked.')

        write_targets_bed(bed_out, parsed_targets)
        print(len(parsed_targets), 'written to', bed_out)

##### PROBE CSV CONVERSION
#Read the probe design from probes.csv through amplimap.reader and write the
#resulting table to {analysis_dir}/probes.csv (without the index column).
rule copy_probes:
    input:
        rules.start_analysis.output,
        probes = "probes.csv"
    output:
        probes = "{analysis_dir}/probes.csv"
    run:
        from amplimap.reader import read_new_probe_design
        probe_table = read_new_probe_design(input['probes'], reference_type = 'genome')
        probe_table.to_csv(output['probes'], index=False)

#only defined if a MIPGEN probe file in csv format is present in the working directory
if os.path.isfile("probes_mipgen.csv"):
    #Convert probes_mipgen.csv through amplimap.reader and write the
    #resulting table to {analysis_dir}/probes.csv (without the index column).
    rule convert_mipgen_probes:
        input:
            rules.start_analysis.output,
            mipgen = "probes_mipgen.csv"
        output:
            probes = "{analysis_dir}/probes.csv"
        run:
            from amplimap.reader import read_and_convert_mipgen_probes
            probe_table = read_and_convert_mipgen_probes(input['mipgen'])
            probe_table.to_csv(output['probes'], index=False)

#only defined if a picked_mips.txt file is present in the working directory
if os.path.isfile("picked_mips.txt"):
    #Convert the tab-separated picked_mips.txt through amplimap.reader and write the
    #resulting table to {analysis_dir}/probes.csv (without the index column).
    rule convert_mipgen_probes_txt:
        input:
            rules.start_analysis.output,
            mipgen = "picked_mips.txt"
        output:
            probes = "{analysis_dir}/probes.csv"
        run:
            from amplimap.reader import read_and_convert_mipgen_probes
            probe_table = read_and_convert_mipgen_probes(input['mipgen'], sep='\t')
            probe_table.to_csv(output['probes'], index=False)

#only defined if a heatseq probe file is present in the working directory
#the file checked here has to be the same file the rule uses as its input (probes_heatseq.tsv),
#otherwise the rule is either never defined or defined with an input that does not exist
if os.path.isfile("probes_heatseq.tsv"):
    #Convert probes_heatseq.tsv through amplimap.reader and write the
    #resulting table to {analysis_dir}/probes.csv (without the index column).
    rule convert_heatseq_probes:
        input:
            rules.start_analysis.output,
            heatseq = "probes_heatseq.tsv"
        output:
            probes = "{analysis_dir}/probes.csv"
        run:
            import amplimap.reader
            design = amplimap.reader.read_and_convert_heatseq_probes(input['heatseq'])
            design.to_csv(output['probes'], index=False)

##### SNPS CHECKING/COPYING
#Read snps.txt through amplimap.reader (the result is not used) and then
#copy the file unchanged to {analysis_dir}/snps.txt.
rule copy_snps:
    input:
        snps = "snps.txt"
    output:
        snps = "{analysis_dir}/snps.txt"
    run:
        from amplimap.reader import read_snps_txt
        snps_in, snps_out = input['snps'], output['snps']

        read_snps_txt(snps_in, reference_type = 'genome')
        print(snps_in, 'checked.')

        #just copy the file over to the analysis directory
        shutil.copyfile(snps_in, snps_out)
    
##### MASKING
#Create a masked copy of the reference genome in masked_reference/genome.fa:
#bedtools maskfasta is run on the configured reference fasta with the bed file from
#config['general']['mask_bed'], then samtools faidx creates the .fai index for the masked fasta.
#The two %s placeholders in the shell command are filled in with the "module load" commands
#(or empty strings) for bedtools and samtools.
rule mask_genome:
    input:
        ref_fasta = get_path('fasta'),
        mask_bed = config['general']['mask_bed'],
    output:
        masked_fasta = 'masked_reference/genome.fa',
        masked_fasta_index = 'masked_reference/genome.fa.fai',
    shell:
        """
        %s
        bedtools maskfasta \
            -fi {input[ref_fasta]:q} \
            -bed {input[mask_bed]:q} \
            -fo {output[masked_fasta]:q} \
        ;

        %s
        samtools faidx {output[masked_fasta]:q};
        """ % (load_module_command('bedtools', config), load_module_command('samtools', config))

#Build the bwa index for the masked reference fasta.
#Only the .bwt file is declared as output; this is the same file that
#get_alignment_index_filename(True) returns for bwa with masking enabled.
#The %s placeholder is filled in with the "module load" command (or empty string) for bwa.
rule mask_genome_index_bwa:
    input:
        masked_fasta = 'masked_reference/genome.fa',
    output:
        masked_index = 'masked_reference/genome.fa.bwt',
    shell:
        """        
        %s
        bwa index {input[masked_fasta]:q};
        """ % (load_module_command('bwa', config))

#Build the bowtie2 index for the masked reference fasta, using the fasta filename as index name.
#Only the .1.bt2 file is declared as output; this is the same file that
#get_alignment_index_filename(True) returns for bowtie2 with masking enabled.
#The %s placeholder is filled in with the "module load" command (or empty string) for bowtie2.
rule mask_genome_index_bowtie2:
    input:
        masked_fasta = 'masked_reference/genome.fa',
    output:
        masked_index = 'masked_reference/genome.fa.1.bt2',
    shell:
        """        
        %s
        bowtie2-build {input[masked_fasta]:q} {input[masked_fasta]:q};
        """ % (load_module_command('bowtie2', config))

#Build the STAR genome index for the masked reference fasta in masked_reference/star/,
#running STAR in genomeGenerate mode with the annotation configured as annotation_gff.
#NOTE(review): annotation_gff falls back to an empty string if it is not configured, which is
#then used both as an input file and as the --sjdbGTFfile argument -- confirm that this is intended.
#The %s placeholder is filled in with the "module load" command (or empty string) for star.
rule mask_genome_index_star:
    input:
        masked_fasta = 'masked_reference/genome.fa',
        annotation_gff = get_path('annotation_gff', default = ''),
    output:
        masked_index_directory = 'masked_reference/star/',
        masked_index_file = 'masked_reference/star/Genome',
    threads: 4
    shell:
        """        
        %s
        STAR \
            --runThreadN {threads} \
            --runMode genomeGenerate \
            --genomeDir {output[masked_index_directory]:q} \
            --genomeFastaFiles {input[masked_fasta]:q} \
            --sjdbGTFfile {input[annotation_gff]:q} \
            --sjdbOverhang 100 \
        ;
        """ % (load_module_command('star', config))

##### RAW BAM CONVERSION
#only defined if the working directory contains a bams_in directory
if os.path.isdir('bams_in'):
    #Convert each BAM file in bams_in into a pair of fastq.gz files in reads_in using
    #bedtools bamtofastq. The output files are always named as lane 1 (L001).
    rule bam_to_fastq:
        input:
            "bams_in/{sample_full}.bam",
        output:
            r1 = "reads_in/{sample_full}_L001_R1_001.fastq.gz",
            r2 = "reads_in/{sample_full}_L001_R2_001.fastq.gz",
        run:
            #the %s is replaced by the "module load" command (or empty string) for bedtools
            shell("%s bedtools bamtofastq -i {input[0]:q} -fq {output[r1]:q} -fq2 {output[r2]:q}" % load_module_command('bedtools', config))   

##### MERGE RAW READS 
#only defined if we have a reads_in directory and use_raw_reads is set
if os.path.isdir('reads_in') and config['general']['use_raw_reads']:
    #Combine the fastq.gz files from all lanes of a sample into one R1 and one R2 file
    #in reads_merged, and write a per-sample stats_samples file in which every read pair
    #is counted as a pair with good arms.
    rule merge_raw_reads:
        input:
            #one file per lane, with lane numbers running from 1 to lanes_actual
            r1 = ["{analysis_dir}/reads_in/{sample_full}_L%03d_R1_001.fastq.gz" % (lane+1) for lane in range(config['general']['lanes_actual'])],
            r2 = ["{analysis_dir}/reads_in/{sample_full}_L%03d_R2_001.fastq.gz" % (lane+1) for lane in range(config['general']['lanes_actual'])],
            config_check = rules.start_analysis.output #make sure we copied the config we used
        output:
            r1 = "{analysis_dir}/reads_merged/{sample_full}__MERGED__R1_001.fastq.gz",
            r2 = "{analysis_dir}/reads_merged/{sample_full}__MERGED__R2_001.fastq.gz",
            #note that we only provide stats_samples if we actually have use_raw_reads set, otherwise we might accidentally use this rule instead of parse_reads
            stats_samples = "{analysis_dir}/reads_parsed/stats_samples__sample_{sample_full}.csv" if config['general']['use_raw_reads'] else [],
        run:
            import os
            assert len(input['r1']) == len(input['r2'])

            n_files = len(input['r1'])

            for read_number in ['r1', 'r2']:
                if n_files == 1:
                    #we don't actually need to merge, just symlink
                    #(link target is given relative to the directory of the output file)
                    os.symlink(os.path.relpath(input[read_number][0], os.path.dirname(output[read_number])), output[read_number])
                else:
                    #merge files
                    #the compressed files are appended to each other byte by byte, without decompressing them
                    with open(output[read_number], 'wb') as wfd:
                        for f in input[read_number]:
                            with open(f, 'rb') as fd:
                                #use shutil to copy 10MB chunks at a time
                                shutil.copyfileobj(fd, wfd, 1024*1024*10)

            #count reads by looping through (merged) R1 file
            import gzip
            opener = gzip.open if output['r1'].endswith('.gz') else open
            #start at -1 so that an empty file ends up with a count of 0 after adding 1 below
            n_fastq_lines = -1
            with opener(output['r1'], 'rt') as fmerged:
                for n_fastq_lines, _ in enumerate(fmerged):
                    pass
            n_fastq_lines += 1

            #we should have four lines per read
            assert n_fastq_lines % 4 == 0, 'encountered fastq file with %d lines, which is not a multiple of four' % (n_fastq_lines)
            n_pairs = int(n_fastq_lines // 4)

            #align_pe rule expects stats_samples.csv, which is made from per-sample stats_samples.csv
            #we will just fake this for now
            import amplimap.parse_reads
            amplimap.parse_reads.output_stats_samples(output['stats_samples'], {
                'sample': [wildcards['sample_full']],
                'files': [n_files],
                'pairs_total': [n_pairs],
                'pairs_unknown_arms': [0],
                'pairs_good_arms': [n_pairs],
                'pairs_r1_too_short': [0],
                'pairs_r2_too_short': [0]
            })

##### TAGGED, ALIGNED BAMS
#only defined if the working directory contains a mapped_bams_in directory
if os.path.isdir('mapped_bams_in'):
    #Make the BAM files from mapped_bams_in available in {analysis_dir}/bams through symlinks
    #and write a per-sample stats_samples file based on the number of alignments in the BAM.
    rule link_tagged_bams:
        input:
            "mapped_bams_in/{sample_full}.bam",
            config_check = rules.start_analysis.output
        output:
            "{analysis_dir}/bams/{sample_full}.bam",
            #note that we only provide stats_samples if we actually have use_raw_reads set, otherwise we might accidentally use this rule instead of parse_reads
            stats_samples = "{analysis_dir}/reads_parsed/stats_samples__sample_{sample_full}.csv" if config['general']['use_raw_reads'] else [],
        run:
            import os
            #link using a path relative to the directory of the output file (as in link_reads), rather than
            #a fixed "../..", so that the link is still correct if sample_full or analysis_dir contain a subdirectory
            #could use os.path.abspath() here, but then we can't move working directory without breaking things
            os.symlink(os.path.relpath(input[0], os.path.dirname(output[0])), output[0])

            #count reads (same as for unmapped_bam_to_fastqs)
            n_alignments = 0
            import pysam
            #need check_sq=False to handle unmapped file
            with pysam.AlignmentFile(input[0], "rb", check_sq=False) as bam:
                n_alignments = bam.count(until_eof=True)

            #align_pe rule expects stats_samples.csv, which is made from per-sample stats_samples.csv
            #we will just fake this for now
            #(number of pairs is taken to be half the number of alignments)
            import amplimap.parse_reads
            amplimap.parse_reads.output_stats_samples(output['stats_samples'], {
                'sample': [wildcards['sample_full']],
                'files': [1],
                'pairs_total': [n_alignments//2],
                'pairs_unknown_arms': [0],
                'pairs_good_arms': [n_alignments//2],
                'pairs_r1_too_short': [0],
                'pairs_r2_too_short': [0]
            })

    #Make BAM index files named {sample_full}.bam.bai in mapped_bams_in available
    #in {analysis_dir}/bams through symlinks.
    rule link_tagged_bais_1:
        input:
            "mapped_bams_in/{sample_full}.bam.bai",
        output:
            "{analysis_dir}/bams/{sample_full}.bam.bai",
        run:
            import os
            #link using a path relative to the directory of the output file, rather than a fixed "../..",
            #so that the link is still correct if sample_full or analysis_dir contain a subdirectory
            #could use os.path.abspath() here, but then we can't move working directory without breaking things
            os.symlink(os.path.relpath(input[0], os.path.dirname(output[0])), output[0])

    #Make BAM index files named {sample_full}.bai (without .bam) in mapped_bams_in available
    #as {sample_full}.bam.bai in {analysis_dir}/bams through symlinks.
    rule link_tagged_bais_2:
        input:
            "mapped_bams_in/{sample_full}.bai",
        output:
            "{analysis_dir}/bams/{sample_full}.bam.bai",
        run:
            import os
            #link using a path relative to the directory of the output file, rather than a fixed "../..",
            #so that the link is still correct if sample_full or analysis_dir contain a subdirectory
            #could use os.path.abspath() here, but then we can't move working directory without breaking things
            os.symlink(os.path.relpath(input[0], os.path.dirname(output[0])), output[0])

##### IDT INPUT
#only defined if the working directory contains an unmapped_bams_in directory
if os.path.isdir('unmapped_bams_in'):
    #Make the BAM files from unmapped_bams_in available in {analysis_dir}/bams_unmapped through symlinks.
    rule link_unmapped_bams:
        input:
            "unmapped_bams_in/{sample_full}.bam",
            config_check = rules.start_analysis.output
        output:
            "{analysis_dir}/bams_unmapped/{sample_full}.bam",
        run:
            import os
            #link using a path relative to the directory of the output file (as in link_reads), rather than
            #a fixed "../..", so that the link is still correct if sample_full or analysis_dir contain a subdirectory
            #could use os.path.abspath() here, but then we can't move working directory without breaking things
            os.symlink(os.path.relpath(input[0], os.path.dirname(output[0])), output[0])

    #Convert an unmapped BAM file into a pair of fastq.gz files (named with an .UNTAGGED suffix
    #after the sample name) using picard SamToFastq, and write a per-sample stats_samples file
    #based on the number of records in the BAM.
    rule unmapped_bam_to_fastqs:
        input:
            "{analysis_dir}/bams_unmapped/{sample_full}.bam",
            #version file is requested to make sure that picard is available
            "{analysis_dir}/versions/picard.txt",
        output:
            "{analysis_dir}/reads_merged/{sample_full}.UNTAGGED__MERGED__R1_001.fastq.gz",
            "{analysis_dir}/reads_merged/{sample_full}.UNTAGGED__MERGED__R2_001.fastq.gz",
            #note that we only provide stats_samples if we actually have use_raw_reads set, otherwise we might accidentally use this rule instead of parse_reads
            stats_samples = "{analysis_dir}/reads_parsed/stats_samples__sample_{sample_full}.csv" if config['general']['use_raw_reads'] else [],
        run:
            #the %s is replaced by the "module load" command (or empty string) for picard
            shell("""
                %s
                {config[tools][picard][prefix]} SamToFastq \
                    INPUT={input[0]:q} \
                    FASTQ={output[0]:q} \
                    SECOND_END_FASTQ={output[1]:q} \
                    INCLUDE_NON_PF_READS=false \
                    CLIPPING_ATTRIBUTE=XT \
                    CLIPPING_ACTION=X \
                    CLIPPING_MIN_LENGTH=10 \
                ;
            """ % load_module_command('picard', config))

            #count reads
            n_alignments = 0
            import pysam
            #need check_sq=False to handle unmapped file
            with pysam.AlignmentFile(input[0], "rb", check_sq=False) as bam:
                n_alignments = bam.count(until_eof=True)

            #align_pe rule expects stats_samples.csv, which is made from per-sample stats_samples.csv
            #we will just fake this for now
            #(number of pairs is taken to be half the number of records)
            import amplimap.parse_reads
            amplimap.parse_reads.output_stats_samples(output['stats_samples'], {
                'sample': [wildcards['sample_full']],
                'files': [1],
                'pairs_total': [n_alignments//2],
                'pairs_unknown_arms': [0],
                'pairs_good_arms': [n_alignments//2],
                'pairs_r1_too_short': [0],
                'pairs_r2_too_short': [0]
            })

    #Create the final BAM file for a sample that came from an unmapped BAM, in three steps:
    #picard MergeBamAlignment combines the aligned .UNTAGGED BAM with the original unmapped BAM
    #into a temporary file, picard AddOrReplaceReadGroups sets all read group fields based on
    #the sample name, and samtools index creates the index for the result.
    rule attach_tags_from_unmapped:
        input:
            "{analysis_dir}/bams_unmapped/{sample_full}.bam",
            #TODO: this uses the normal align_pe rule for the alignment, which includes coordinate sorting
            #the merge scripts wants queryname sort, but detects and fixes that automatically.
            #just a bit of a waste...
            "{analysis_dir}/bams/{sample_full}.UNTAGGED.bam",
            "{analysis_dir}/bams/{sample_full}.UNTAGGED.bam.bai",
            #version files are requested to make sure that picard and samtools are available
            "{analysis_dir}/versions/picard.txt",
            "{analysis_dir}/versions/samtools.txt",
        output:
            "{analysis_dir}/bams/{sample_full}.bam",
            "{analysis_dir}/bams/{sample_full}.bam.bai",
            #intermediate file written by MergeBamAlignment, marked as temp() for Snakemake
            temp("{analysis_dir}/bams/{sample_full}.attach_tags_tmp.bam")
        run:
            #the three %s are replaced, in this order, by the "module load" commands (or empty strings)
            #for picard and samtools and by the path of the reference fasta
            shell("""
                %s%s

                #attach tags from unmapped bam
                {config[tools][picard][prefix]} MergeBamAlignment \
                    ALIGNED={input[1]:q} \
                    UNMAPPED={input[0]:q} \
                    OUTPUT={output[2]:q} \
                    REFERENCE_SEQUENCE="%s" \
                    EXPECTED_ORIENTATIONS=FR \
                    MAX_GAPS=-1 \
                    SORT_ORDER=coordinate \
                    ALIGNER_PROPER_PAIR_FLAGS=false \
                    CLIP_OVERLAPPING_READS=false \
                    VALIDATION_STRINGENCY=LENIENT \
                ;

                #fix the read groups, which will otherwise come from unmapped bam
                #note SILENT stringency to deal with invalid read groups from before
                #(LENIENT would still spam the logs with errors)
                {config[tools][picard][prefix]} AddOrReplaceReadGroups \
                    I={output[2]:q} \
                    O={output[0]:q} \
                    RGID={wildcards.sample_full:q} \
                    RGLB={wildcards.sample_full:q} \
                    RGPL=illumina \
                    RGPU={wildcards.sample_full:q} \
                    RGSM={wildcards.sample_full:q} \
                    VALIDATION_STRINGENCY=SILENT \
                ;

                samtools index {output[0]:q};
            """ % (load_module_command('picard', config), load_module_command('samtools', config), get_path('fasta')))

##### ACTUAL PIPELINE
#only defined if the working directory contains a reads_in directory
if os.path.isdir('reads_in'):
    #Make the fastq.gz files from reads_in available in analysis/reads_in through symlinks.
    rule link_reads:
        input:
            "reads_in/{sample_with_lane}_R{read}_001.fastq.gz",
            config_check = rules.start_analysis.output
        output:
            "analysis/reads_in/{sample_with_lane}_R{read}_001.fastq.gz",
        run:
            import os
            fastq_source = input[0]
            fastq_link = output[0]
            #the link target is given relative to the directory that the link is created in
            link_directory = os.path.dirname(fastq_link)
            relative_target = os.path.relpath(fastq_source, link_directory)
            os.symlink(relative_target, fastq_link)

if os.path.isdir('reads_in') and not config['general']['use_raw_reads']:
    rule parse_reads_pe:
        input:
            r1 = ["{analysis_dir}/reads_in/{sample_full}_L%03d_R1_001.fastq.gz" % (lane+1) for lane in range(config['general']['lanes_actual'])],
            r2 = ["{analysis_dir}/reads_in/{sample_full}_L%03d_R2_001.fastq.gz" % (lane+1) for lane in range(config['general']['lanes_actual'])],
            probes = "{analysis_dir}/probes.csv",
            config_check = rules.start_analysis.output #make sure we copied the config we used
        output:
            r1 = "{analysis_dir}/reads_parsed/fastq/{sample_full}__MIP_TRIMMED__R1_001.fastq.gz",
            r2 = "{analysis_dir}/reads_parsed/fastq/{sample_full}__MIP_TRIMMED__R2_001.fastq.gz",
            stats_samples = "{analysis_dir}/reads_parsed/stats_samples__sample_{sample_full}.csv",
            stats_reads = "{analysis_dir}/reads_parsed/stats_reads__sample_{sample_full}.csv"
        run:
            #prepare stats dicts
            import collections
            stats_samples = collections.OrderedDict()
            stats_reads = []

            #load probes
            import amplimap.reader
            #read new format
            probes = amplimap.reader.read_new_probe_design(input['probes'], reference_type = 'genome')
            #much faster to access than using DataFrame.loc every time
            probes_dict = probes.to_dict()

            #run
            import amplimap.parse_reads
            amplimap.parse_reads.parse_read_pairs(
                wildcards['sample_full'],
                input['r1'],
                input['r2'],
                output['r1'],
                output['r2'],
                probes_dict,
                stats_samples,
                stats_reads,
                unknown_arms_directory = os.path.join(wildcards['analysis_dir'], 'reads_parsed'),
                umi_one = config['parse_reads']['umi_one'], umi_two = config['parse_reads']['umi_two'],
                mismatches = config['parse_reads']['max_mismatches'],
                trim_primers = config['parse_reads']['trim_primers'],
                trim_min_length = config['parse_reads']['trim_min_length'],
                trim_primers_strict = False,
                trim_primers_smart = False,
                quality_trim_threshold = config['parse_reads']['quality_trim_threshold'] if config['parse_reads']['quality_trim_threshold'] != False else None,
                quality_trim_phred_base = config['parse_reads']['quality_trim_phred_base'],
                allow_multiple_probes = False,
                consensus_fastqs = None, #output['consensus_fastqs'],
                min_consensus_count = config['general']['umi_min_consensus_count'],
                min_consensus_fraction = config['general']['umi_min_consensus_percentage'] / 100.0
            )

            #output stats files
            amplimap.parse_reads.output_stats_samples(output['stats_samples'], stats_samples)
            amplimap.parse_reads.output_stats_reads(output['stats_reads'], stats_reads, config['parse_reads']['umi_one'], config['parse_reads']['umi_two'])

rule merge_targets:
    input:
        "{analysis_dir}/targets.bed",
        "{analysis_dir}/versions/bedtools.txt",
    output:
        "{analysis_dir}/targets_merged.bed"
    run:
        shell("%s bedtools merge -i <(sort -k1,1 -k2,2n {input[0]:q}) -c 4 -o collapse > {output[0]:q}" % load_module_command('bedtools', config))

rule stats_samples_agg:
    input:
        parsed_reads = expand("{{analysis_dir}}/reads_parsed/stats_samples__sample_{sample_full}.csv", sample_full = SAMPLES)
    output:
        "{analysis_dir}/reads_parsed/stats_samples.csv"
    run:
        import pandas as pd
        merged = None
        for file in input:
            print('Reading', file, '...')
            df = pd.read_csv(file, index_col = False)

            if merged is None:
                merged = df
            else:
                merged = merged.append(df, ignore_index = True)

        print('Merged data shape:', str(merged.shape))
        merged.sort_values(['sample'], inplace=True)

        assert merged['pairs_good_arms'].sum() > 0, (
            '\n\nABORTED: Did not find any read pairs with the expected primers sequences.'
            ' This is usually caused by incorrect primer sequences (eg. flipped or reverse complemented)'
            ' in the probes file or incorrect/missing UMI lengths in the config file.'
            ' The UMI length settings are currently {} bp for read one and {} bp for read two.'
            ' Please check the reads, probe design, and config files for errors.\n\n'
        ).format(
            config['parse_reads']['umi_one'],
            config['parse_reads']['umi_two']
        )

        #should we ensure that we have at least X good reads?
        percentage_good = (100.0 * merged['pairs_good_arms'].sum() / merged['pairs_total'].sum())
        print('Overall good read percentage =', percentage_good)

        if not percentage_good > config['parse_reads']['min_percentage_good']: #in [0, 100] checked above
            merged.to_csv(output[0]+'.error.csv', index = False)
            print('Wrote sample stats to %s' % output[0]+'.error.csv')

        assert percentage_good > config['parse_reads']['min_percentage_good'], \
            '\n\nABORTED: Found too few read pairs with the expected primers sequences. The average percentage across all samples was %.1f%%. Please check the reads, probe design, and config files for errors or lower the min_percentage_good setting in config.yaml to continue anyway.\n\n' % (percentage_good)

        merged.to_csv(output[0], index = False)

rule stats_reads_agg:
    input:
        parsed_reads = expand("{{analysis_dir}}/reads_parsed/stats_reads__sample_{sample_full}.csv", sample_full = SAMPLES)
    output:
        "{analysis_dir}/reads_parsed/stats_reads.csv",
    run:
        import pandas as pd
        merged = None
        for file in input:
            print('Reading', file, '...')
            try:
                df = pd.read_csv(file, index_col = False)
                print('Data shape:', str(df.shape))

                if merged is None:
                    merged = df
                else:
                    merged = merged.append(df, ignore_index = True)
            except pd.io.common.EmptyDataError:
                print('No data for', file, ', skipping.')

        assert merged is not None and len(merged) > 0, \
            '\n\nABORTED: Did not find any read pairs with the expected primers sequences. Please check the reads, probe design, and config files for errors.\n\n'

        print('Merged data shape:', str(merged.shape))
        merged.sort_values(['sample', 'probe'], inplace=True)

        merged.to_csv(output[0], index = False)

rule align_pe:
    input:
        #use either _MIP_TRIMMED_ or merged raw fastq files here, depending on config
        "{analysis_dir}/reads_merged/{sample_full}__MERGED__R1_001.fastq.gz" if config['general']['use_raw_reads'] \
            else "{analysis_dir}/reads_parsed/fastq/{sample_full}__MIP_TRIMMED__R1_001.fastq.gz",
        "{analysis_dir}/reads_merged/{sample_full}__MERGED__R2_001.fastq.gz" if config['general']['use_raw_reads'] \
            else "{analysis_dir}/reads_parsed/fastq/{sample_full}__MIP_TRIMMED__R2_001.fastq.gz",
        get_alignment_index_filename(real_path = True),
        "{analysis_dir}/reads_parsed/stats_samples.csv", #to make sure this gets generated and checked first
        "{analysis_dir}/versions/%s.txt" % config['align']['aligner'],
        "{analysis_dir}/versions/samtools.txt",
        probes = "{analysis_dir}/probes.csv" if config['align']['aligner'] == 'naive' else [], #the naive aligner also needs probes
    output:
        "{analysis_dir}/bams/{sample_full}.bam",
        "{analysis_dir}/bams/{sample_full}.bam.bai",
    threads: 4
    run:
        cmds = """
        %s
        %s
        %s
        samtools index {output[0]:q};
        """

        my_index = get_alignment_index_filename(real_path = False)
        if config['align']['aligner'] == 'naive':
            #load probes
            import amplimap.reader
            probes = amplimap.reader.read_new_probe_design(input['probes'], reference_type = 'genome')
            probes_dict = probes.to_dict()

            #just place the reads based on the probe location
            import amplimap.naive_mapper
            amplimap.naive_mapper.create_bam(
                wildcards['sample_full'],
                [input[0]],
                [input[1]],
                my_index,
                probes_dict,
                output[0]+"__unsorted",
                has_trimmed_primers = config['parse_reads']['trim_primers'] and not config['general']['use_raw_reads'],
                debug = False
            )

            #we still need to sort and index these            
            cmds = cmds % (
                load_module_command('samtools', config),
                '',
                """
                samtools sort \
                    -T "{output[0]}__sort_tmp" \
                    -o {output[0]:q} \
                    "{output[0]}__unsorted" \
                ;
                """
            )
        elif config['align']['aligner'] == 'bwa':
            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('bwa', config),
                """
                bwa mem \
                    -t {threads} \
                    -R '@RG\\tID:{wildcards.sample_full}\\tSM:{wildcards.sample_full}' \
                    %s \
                    "%s" \
                    {input[0]:q} {input[1]:q} \
                | samtools fixmate -O bam - - \
                | samtools sort -T "{output[0]}__sort_tmp" \
                    -o {output[0]:q} - \
                ;
                """ % (
                    '' if config['general']['use_raw_reads'] else '-C', #only use -C parameter if we have processed reads
                    my_index
                )
            )
        elif config['align']['aligner'] == 'bowtie2':
            extra = ''
            if config['align']['bowtie2']['report_n'] > 1:
                extra += '-k %d' % config['align']['bowtie2']['report_n']

            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('bowtie2', config),
                """
                bowtie2 \
                    %s \
                    --rg-id {wildcards.sample_full:q} \
                    --rg 'SM:{wildcards.sample_full}' \
                    --threads {threads} \
                    -x "%s" \
                    -1 {input[0]:q} \
                    -2 {input[1]:q} \
                | samtools sort -T "{output[0]}__sort_tmp" \
                    -o {output[0]:q} - \
                ;
                """ % (extra, my_index)
            )
        elif config['align']['aligner'] == 'star':
            extra = ''
            if input[0].endswith('.gz'):
                extra += ' --readFilesCommand "gzip -cd"'

            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('star', config),
                """
                STAR \
                    %s \
                    --runMode alignReads \
                    --runThreadN {threads} \
                    --outSAMattrRGline ID:{wildcards.sample_full:q} SM:{wildcards.sample_full:q} \
                    --genomeDir "%s" \
                    --readFilesIn {input[0]:q} {input[1]:q} \
                    --outSAMtype BAM SortedByCoordinate \
                    --outFileNamePrefix "{output[0]}." \
                    --outSAMunmapped Within \
                    --outFilterType BySJout \
                    --outFilterMultimapNmax 5 \
                ;

                #fix names -- this will fail if no reads but should handle only unmapped thanks to --outSAMunmapped
                mv "{output[0]}.Aligned.sortedByCoord.out.bam" {output[0]:q};
                """ % (extra, my_index)
            )
        elif config['align']['aligner'] in config['tools'] and 'align_command' in config['tools'][config['align']['aligner']]:
            tool_command = config['tools'][config['align']['aligner']]['align_command'] % (
                get_path('fasta'),
            )

            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('star', config),
                tool_command % (my_index)
            )
        else:
            raise Exception("Invalid aligner specified in config!")

        shell(cmds)

rule align_pe_CONSENSUS:
    input:
        "{analysis_dir}/reads_parsed/fastq/consensus/{sample_full}__MIP_TRIMMED_CONSENSUS__R1_001.fastq.gz",
        "{analysis_dir}/reads_parsed/fastq/consensus/{sample_full}__MIP_TRIMMED_CONSENSUS__R2_001.fastq.gz",
        get_alignment_index_filename(real_path = True),
        "{analysis_dir}/reads_parsed/stats_samples.csv", #to make sure this gets generated and checked first
        "{analysis_dir}/versions/%s.txt" % config['align']['aligner'],
        "{analysis_dir}/versions/samtools.txt",
    output:
        #note the `,.*` to allow empty subdir
        "{analysis_dir}/bams_consensus/{sample_full}_CONSENSUS.bam",
        "{analysis_dir}/bams_consensus/{sample_full}_CONSENSUS.bam.bai",
    threads: 4
    run:
        cmds = """
        %s
        %s
        %s
        samtools index {output[0]:q};
        """

        my_index = get_alignment_index_filename(real_path = False)
        if config['align']['aligner'] == 'bwa':
            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('bwa', config),
                """
                bwa mem \
                    -t {threads} \
                    -R '@RG\\tID:{wildcards.sample_full}\\tSM:{wildcards.sample_full}' \
                    -C \
                    "%s" \
                    {input[0]:q} {input[1]:q} \
                | samtools fixmate -O bam - - \
                | samtools sort -T "{output[0]}__sort_tmp" \
                    -o {output[0]:q} - \
                ;
                """ % my_index)
        elif config['align']['aligner'] == 'bowtie2':
            extra = ''
            if config['align']['bowtie2']['report_n'] > 1:
                extra += '-k %d' % config['align']['bowtie2']['report_n']

            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('bowtie2', config),
                """
                bowtie2 \
                    %s \
                    --rg-id {wildcards.sample_full:q} \
                    --rg 'SM:{wildcards.sample_full}' \
                    --threads {threads} \
                    -x "%s" \
                    -1 {input[0]:q} \
                    -2 {input[1]:q} \
                | samtools sort -T "{output[0]}__sort_tmp" \
                    -o {output[0]:q} - \
                ;
                """ % (extra, my_index))
        elif config['align']['aligner'] == 'star':
            extra = ''
            if input[0].endswith('.gz'):
                extra += ' --readFilesCommand "gzip -cd"'

            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('star', config),
                """
                STAR \
                    %s \
                    --runMode alignReads \
                    --runThreadN {threads} \
                    --outSAMattrRGline ID:{wildcards.sample_full:q} SM:{wildcards.sample_full:q} \
                    --genomeDir "%s" \
                    --readFilesIn {input[0]:q} {input[1]:q} \
                    --outSAMtype BAM SortedByCoordinate \
                    --outFileNamePrefix "{output[0]}." \
                    --outSAMunmapped Within \
                    --outFilterType BySJout \
                    --outFilterMultimapNmax 5 \
                ;

                #fix names -- this will fail if no reads but should handle only unmapped thanks to --outSAMunmapped
                mv "{output[0]}.Aligned.sortedByCoord.out.bam" {output[0]:q};
                """ % (extra, my_index))
        elif config['align']['aligner'] in config['tools'] and 'align_command' in config['tools'][config['align']['aligner']]:
            tool_command = config['tools'][config['align']['aligner']]['align_command'] % (
                get_path('fasta'),
            )

            cmds = cmds % (
                load_module_command('samtools', config),
                load_module_command('star', config),
                tool_command % (my_index)
            )
        else:
            raise Exception("Invalid aligner specified in config!")

        shell(cmds)

rule stats_alignment:
    input:
        "{analysis_dir}/bams/{sample_full}.bam",
        "{analysis_dir}/bams/{sample_full}.bam.bai",
        probes = "{analysis_dir}/probes.csv",
    output:
        "{analysis_dir}/stats_alignment/{sample_full}.stats_alignment.csv"
    run:
        import amplimap.stats_alignment
        amplimap.stats_alignment.process_file(
            probes_path = input['probes'],
            #targets_path = input['targets'],
            input_path = input[0],
            output_path = output[0],
            min_mapq = config['pileup']['min_mapq'],
            min_consensus_count = config['general']['umi_min_consensus_count'],
            include_primers = not config['parse_reads']['trim_primers'],
            use_naive_groups = not config['general']['ignore_umis'],
            ignore_groups = config['general']['ignore_umis']
        )

rule stats_alignment_agg:
    input:
        csvs = expand("{{analysis_dir}}/stats_alignment/{sample_full}.stats_alignment.csv", sample_full = SAMPLES)
    output:
        "{analysis_dir}/stats_alignment/stats_alignment.csv"
    run:
        import amplimap.stats_alignment
        amplimap.stats_alignment.aggregate(folder = "{}/stats_alignment/".format(wildcards['analysis_dir']))

rule do_pileup:
    input:
        "{analysis_dir}/bams/{sample_full}.bam", # if config['general']['ignore_umis'] else rules.umi_group.output,
        "{analysis_dir}/bams/{sample_full}.bam.bai",
        targets = "{analysis_dir}/targets.bed",
        probes = "{analysis_dir}/probes.csv" if config['pileup']['validate_probe_targets'] and not config['general']['use_raw_reads'] else []
    output:
        "{analysis_dir}/pileups/{sample_full}.pileup.csv",
        "{analysis_dir}/pileups/{sample_full}.targets.csv"
    run:
        import amplimap.pileup
        amplimap.pileup.process_file(
            input = input[0],
            output = "{}/pileups/{}".format(wildcards['analysis_dir'], wildcards['sample_full']),
            reference_type = 'genome',
            subsample_reads = config['pileup']['subsample_reads'],
            probes_file = input['probes'] if config['pileup']['validate_probe_targets'] and not config['general']['use_raw_reads'] else None,
            snps_file = None, #different from other pileup command
            targets_file = input['targets'], #different from other pileup command
            validate_probe_targets = config['pileup']['validate_probe_targets'] and not config['general']['use_raw_reads'],
            filter_softclipped = config['pileup']['filter_softclipped'],
            fasta_file = get_path('fasta'),
            min_mapq = config['pileup']['min_mapq'],
            min_baseq = config['pileup']['min_baseq'],
            ignore_groups = config['general']['ignore_umis'],
            group_with_mate_positions = config['pileup']['group_with_mate_positions'],
            min_consensus_count = config['general']['umi_min_consensus_count'],
            min_consensus_fraction = config['general']['umi_min_consensus_percentage'] / 100.0,

            #to allow running straight from tagged bams:
            no_probe_data = config['general']['use_raw_reads'],
            umi_tag_name = config['general']['umi_tag_name'],
        )

rule pileup_agg:
    input:
        csvs = expand("{{analysis_dir}}/pileups/{sample_full}.pileup.csv", sample_full = SAMPLES),
        targets = "{analysis_dir}/targets.bed"
    output:
        "{analysis_dir}/pileups/pileups_long.csv",
        "{analysis_dir}/pileups/pileups_long_detailed.csv",
        "{analysis_dir}/pileups/pileups_wide.csv",
        "{analysis_dir}/pileups/target_coverage.csv",
        "{analysis_dir}/pileups/target_coverage_long.csv"
    run:
        import amplimap.pileup
        amplimap.pileup.aggregate(
            folder = "{}/pileups/".format(wildcards['analysis_dir']),
            snps_file = None,
            ref = True,
            generate_calls = False
        )

rule do_pileup_snps:
    input:
        "{analysis_dir}/bams/{sample_full}.bam", # if config['general']['ignore_umis'] else rules.umi_group.output,
        "{analysis_dir}/bams/{sample_full}.bam.bai",
        snps = "{analysis_dir}/snps.txt",
        probes = "{analysis_dir}/probes.csv" if config['pileup']['validate_probe_targets'] and not config['general']['use_raw_reads'] else []
    output:
        "{analysis_dir}/pileups_snps/{sample_full}.pileup.csv"
    run:
        import amplimap.pileup
        amplimap.pileup.process_file(
            input = input[0],
            output = "{}/pileups_snps/{}".format(wildcards['analysis_dir'], wildcards['sample_full']),
            reference_type = 'genome',
            subsample_reads = config['pileup']['subsample_reads'],
            probes_file = input['probes'] if config['pileup']['validate_probe_targets'] and not config['general']['use_raw_reads'] else None,
            snps_file = input['snps'], #different from other pileup command
            targets_file = None, #different from other pileup command
            validate_probe_targets = config['pileup']['validate_probe_targets'] and not config['general']['use_raw_reads'],
            filter_softclipped = config['pileup']['filter_softclipped'],
            fasta_file = get_path('fasta'),
            min_mapq = config['pileup']['min_mapq'],
            min_baseq = config['pileup']['min_baseq'],
            ignore_groups = config['general']['ignore_umis'],
            group_with_mate_positions = config['pileup']['group_with_mate_positions'],
            min_consensus_count = config['general']['umi_min_consensus_count'],
            min_consensus_fraction = config['general']['umi_min_consensus_percentage'] / 100.0,

            #to allow running straight from tagged bams:
            no_probe_data = config['general']['use_raw_reads'],
            umi_tag_name = config['general']['umi_tag_name'],
        )

rule pileup_snps_agg:
    input:
        csvs = expand("{{analysis_dir}}/pileups_snps/{sample_full}.pileup.csv", sample_full = SAMPLES),
        snps = "{analysis_dir}/snps.txt"
    output:
        "{analysis_dir}/pileups_snps/target_snps_pileups_long.csv",
        "{analysis_dir}/pileups_snps/target_snps_pileups_long_detailed.csv",
        "{analysis_dir}/pileups_snps/target_snps_pileups_wide.csv"
    run:
        import amplimap.pileup
        amplimap.pileup.aggregate(
            folder = "{}/pileups_snps/".format(wildcards['analysis_dir']),
            snps_file = input['snps'],
            ref = True,
            generate_calls = False
        )

#deduplicate the paired-end alignments of one sample with umi_tools (directional method),
#then sort and index the result with samtools
#NOTE(review): only the samtools module is loaded here; umi_tools is presumably expected
#to be available on the PATH already -- confirm
rule umi_dedup:
    input:
        "{analysis_dir}/bams/{sample_full}.bam",
        "{analysis_dir}/bams/{sample_full}.bam.bai",
        "{analysis_dir}/versions/samtools.txt",
    output:
        "{analysis_dir}/bams_umi_dedup/{sample_full}.bam",
        "{analysis_dir}/bams_umi_dedup/{sample_full}.bam.bai",
        #intermediate umi_tools output before sorting, marked as temporary
        temp("{analysis_dir}/bams_umi_dedup/{sample_full}.unsorted.bam")
    #the leading %s of the template receives the module loading command for samtools
    shell:
        """
        %s
        umi_tools dedup \
            --log2stderr \
            --method directional \
            --paired \
            -I {input[0]:q} \
            -S {output[2]:q} \
        ;
        samtools sort -o {output[0]:q} {output[2]:q};
        samtools index {output[0]:q};
        """ % load_module_command('samtools', config)

#calculate the per-base coverage (bedtools coverage -d) of the target regions
#for any BAM file in any BAM directory of the analysis
rule calc_coverage:
    input:
        "{analysis_dir}/{bamdir}/{file}.bam",
        "{analysis_dir}/{bamdir}/{file}.bam.bai",
        "{analysis_dir}/versions/bedtools.txt",
        targets = "{analysis_dir}/targets.bed",
    output:
        "{analysis_dir}/{bamdir}/coverages/{file}.coverage_raw.txt",
    #the leading %s of the template receives the module loading command for bedtools
    shell:        
        """
        %s
        bedtools coverage \
            -a {input.targets:q} \
            -b {input[0]:q} \
            -d \
        > {output[0]:q};
        """ % (load_module_command('bedtools', config))

#same as calc_coverage, but the alignments are first filtered with samtools view -q
#using the pileup min_mapq setting from the config, and then piped into bedtools coverage
rule calc_coverage_mapq:
    input:
        "{analysis_dir}/{bamdir}/{file}.bam",
        "{analysis_dir}/{bamdir}/{file}.bam.bai",
        "{analysis_dir}/versions/bedtools.txt",
        targets = "{analysis_dir}/targets.bed",
    output:
        "{analysis_dir}/{bamdir}/coverage_mapq/{file}.coverage_raw.txt",
    #the two leading %s of the template receive the module loading commands for samtools and bedtools
    shell:
        """
        %s
        %s
        samtools view -b \
            -q {config[pileup][min_mapq]:q} \
            {input[0]:q} \
        | \
        bedtools coverage \
            -a {input.targets:q} \
            -b stdin \
            -d \
        > {output[0]:q};
        """ % (load_module_command('samtools', config), load_module_command('bedtools', config))

rule coverage_process:
    input:
        "{prefix}.coverage_raw.txt"
    output:
        "{prefix}.coverage.csv"
    run:
        import amplimap.coverage
        amplimap.coverage.process_file(input[0], output[0])

rule coverage_agg:
    input:
        csvs = expand("{{bamdir}}/coverages/{sample_full}.coverage.csv", sample_full = SAMPLES),
        sample_info = ["sample_info.csv"] if os.path.isfile("sample_info.csv") else []
    output:
        merged = "{bamdir}/coverages/coverage_full.csv",
        min_coverage = "{bamdir}/coverages/min_coverage.csv",
        cov_per_bp = "{bamdir}/coverages/cov_per_bp.csv",
        fraction_zero_coverage = "{bamdir}/coverages/fraction_zero_coverage.csv",
    run:
        import amplimap.coverage
        amplimap.coverage.aggregate(input, output)

rule call_variants_raw:
    input:
        "{analysis_dir}/bams/{sample_full}.bam",
        "{analysis_dir}/bams/{sample_full}.bam.bai",
        "{analysis_dir}/versions/%s.txt" % config['variants']['caller'],
        targets = "{analysis_dir}/targets_merged.bed",
    output:
        "{analysis_dir}/variants_raw/{sample_full}.vcf"
    run:
        #same as below
        if config['variants']['caller'] == 'platypus':
            shell("""
                %s
                platypus callVariants \
                    -o {output[0]:q} \
                    --refFile "%s" \
                    --bamFiles {input[0]:q} \
                    --regions {input.targets:q} \
                    {config[variants][platypus][parameters]} \
                ;
                """ % (load_module_command('platypus', config), get_path('fasta'))
            )
        elif config['variants']['caller'] == 'gatk':   
            #NB: We are not using duplicate marking, BQSR and VQSR here. this may result in suboptimal variant calls and increased false positives.
            #However, none of these tools seem to be suitable for targeted sequencing data:
            #- duplicate calling: we actually expect to see duplicates, since all reads are generated with the same primers and read length.
            #- BQSR, VQSR: we won't have a large genome-wide set of known SNPs to work with.
            #Nevertheless, it may be worth running them manually and comparing results.
            shell("""
                %s
                gatk HaplotypeCaller \
                    {config[variants][gatk][parameters]} \
                    --reference "%s" \
                    --intervals {input.targets:q} \
                    --input {input[0]:q} \
                    --output {output[0]:q} \
                ;
                """ % (load_module_command('gatk', config), get_path('fasta'))
            )
        elif config['variants']['caller'] in config['tools'] and 'call_command' in config['tools'][config['variants']['caller']]:
            tool_command = config['tools'][config['variants']['caller']]['call_command'] % (
                get_path('fasta'),
            )

            shell("""
                %s
                %s
                """ % (load_module_command(config['variants']['caller'], config), tool_command)
            )
        else:
            raise Exception('Invalid variant caller specified')

rule call_variants_umi:
    input:
        rules.umi_dedup.output,
        targets = "{analysis_dir}/targets_merged.bed",
        version = "{analysis_dir}/versions/%s.txt" % config['variants']['caller'],
    output:
        "{analysis_dir}/variants_umi/{sample_full}.vcf"
    run:
        #same as above
        if config['variants']['caller'] == 'platypus':
            shell("""
                %s
                platypus callVariants \
                    -o {output[0]:q} \
                    --refFile "%s" \
                    --bamFiles {input[0]:q} \
                    --regions {input.targets:q} \
                    {config[variants][platypus][parameters]} \
                ;
                """ % (load_module_command('platypus', config), get_path('fasta'))
            )
        elif config['variants']['caller'] == 'gatk':   
            shell("""
                %s
                gatk HaplotypeCaller \
                    {config[variants][gatk][parameters]} \
                    --reference "%s" \
                    --intervals {input.targets:q} \
                    --input {input[0]:q} \
                    --output {output[0]:q} \
                ;
                """ % (load_module_command('gatk', config), get_path('fasta'))
            )
        elif config['variants']['caller'] in config['tools'] and 'call_command' in config['tools'][config['variants']['caller']]:
            tool_command = config['tools'][config['variants']['caller']]['call_command'] % (
                get_path('fasta'),
            )

            shell("""
                %s
                %s
                """ % (load_module_command(config['variants']['caller'], config), tool_command)
            )
        else:
            raise Exception('Invalid variant caller specified')

rule call_variants_low_frequency:
    input:
        "{analysis_dir}/bams/{sample_full}.bam",
        "{analysis_dir}/bams/{sample_full}.bam.bai",
        targets = "{analysis_dir}/targets_merged.bed",
        version = "{analysis_dir}/versions/%s.txt" % config['variants']['caller_low_frequency'],
    output:
        "{analysis_dir}/variants_low_frequency/{sample_full}.vcf",
        "{analysis_dir}/variants_low_frequency_unfiltered/{sample_full}.vcf"
    run:
        #same as below
        if config['variants']['caller_low_frequency'] == 'mutect2':
            shell("""
                %s
                gatk Mutect2 \
                    {config[variants][mutect2][parameters]} \
                    --reference "%s" \
                    --intervals {input.targets:q} \
                    --input {input[0]:q} \
                    --output {output[1]:q} \
                    --tumor-sample {wildcards.sample_full:q} \
                ;

                gatk FilterMutectCalls \
                    --variant {output[1]:q} \
                    --output {output[0]:q} \
                ;
                """ % (load_module_command('gatk', config), get_path('fasta'))
            )
        else:
            raise Exception('Invalid variant caller specified')

rule normalize_variants:
    input:
        "{analysis_dir}/{variants_dir}/{sample_full}.vcf",
        "{analysis_dir}/versions/bcftools.txt",
    output:
        "{analysis_dir}/{variants_dir}/{sample_full}.vcf.normalized.vcf.gz",
    run:
        lines_before = 0
        with open(input[0], 'rb') as f:
            for line in f:
                if not line.startswith(b'#'):
                    lines_before += 1

        if lines_before == 0:
            print('No variants found, creating empty file!')
            open(output[0], 'a').close()
        else:
            shell("""
                %s
                bcftools norm \
                    --multiallelics=-both \
                    --fasta-ref="%s" \
                    --output={output[0]:q} \
                    --output-type=z \
                    {input[0]:q} \
                ;
            """ % (load_module_command('bcftools', config), get_path('fasta')))
            
            lines_after = 0
            import gzip
            with gzip.open(output[0], 'rb') as f:
                for line in f:
                    if not line.startswith(b'#'):
                        lines_after += 1

            print('Converted raw VCF with {} lines into normalized VCF with {} lines'.format(
                lines_before,
                lines_after
            ))
            
            assert lines_after >= lines_before, 'Error: Got fewer lines after VCF normalization than before for {}'.format(input[0])


rule annotate_variants:
    input:
        "{analysis_dir}/{variants_dir}/{sample_full}.vcf.normalized.vcf.gz",
        "{analysis_dir}/versions/annovar.txt",
    output:
        "{analysis_dir}/{variants_dir}/{sample_full}.vcf.normalized.vcf.gz."+get_annovar_name()+"_multianno.csv",
        temp("{analysis_dir}/{variants_dir}/{sample_full}.for_annovar.bed")
    run:
        import gzip
        lines_before = 0
        with gzip.open(input[0], 'rb') as f:
            for line in f:
                if not line.startswith(b'#'):
                    lines_before += 1

        if lines_before == 0:
            print('No variants found, creating empty files!')
            open(output[0], 'a').close()
            open(output[1], 'a').close()
        else:
            shell("""
                %s
                convert2annovar.pl {input[0]:q} -format vcf4old --includeinfo > {output[1]:q};

                table_annovar.pl \
                    {output[1]:q} \
                    "%s" \
                    --buildver "%s" \
                    --otherinfo -csvout -remove \
                    -protocol {config[annotate][annovar][protocols]:q} \
                    -operation {config[annotate][annovar][operations]:q} \
                    --outfile {input[0]:q} \
                ;
            """ % (load_module_command('annovar', config), get_path('annovar'), get_annovar_name()))
            
            lines_after = -1  # adjust for header line
            with open(output[0], 'rb') as f:
                for line in f:
                    lines_after += 1

            print('Annotated VCF with {} lines into CSV with {} lines'.format(
                lines_before,
                lines_after
            ))
            
            assert lines_after >= lines_before, 'Error: Got fewer lines after VCF annotation than before for {}'.format(input[0])

# only provide this rule if we have annovar
if get_path('annovar', default='') != '':
    rule merge_variants_from_annovar:
        input:
            csvs = expand("{{analysis_dir}}/{{variants_dir}}/{sample_full}.vcf.normalized.vcf.gz."+get_annovar_name()+"_multianno.csv", sample_full = SAMPLES)
        output:
            "{analysis_dir}/{variants_dir}/variants_merged.csv"
        run:
            import amplimap.variants
            amplimap.variants.merge_variants_from_annovar(input, output)

# otherwise we just merge the files together
else:
    rule variants_merge_unannotated:
        input:
            vcfs = expand("{{analysis_dir}}/{{variants_dir}}/{sample_full}.vcf.normalized.vcf.gz", sample_full = SAMPLES)
        output:
            "{analysis_dir}/{variants_dir}/variants_merged.csv"
        run:
            import amplimap.variants
            amplimap.variants.merge_variants_unannotated(input['vcfs'], output[0])

rule variants_summary:
    input:
        merged = "{analysis_dir}/{variants_dir}/variants_merged.csv",
        targets = "{analysis_dir}/targets.bed",
        sample_info = ["sample_info.csv"] if os.path.isfile("sample_info.csv") else []
    output:
        "{analysis_dir}/{variants_dir}/variants_summary.csv"
    run:
        import amplimap.variants
        amplimap.variants.make_summary(input, output, config,
            get_path('exon_table') if 'include_exon_distance' in config['annotate'] and config['annotate']['include_exon_distance'] else None)

rule variants_summary_condensed:
    input:
        summary = rules.variants_summary.output[0],
        sample_info = ["sample_info.csv"] if os.path.isfile("sample_info.csv") else []
    output:
        filtered = "{analysis_dir}/{variants_dir}/variants_summary_filtered.csv",
        unfiltered = "{analysis_dir}/{variants_dir}/variants_summary_UNfiltered.csv"
    run:
        import amplimap.variants
        amplimap.variants.make_summary_condensed(input, output)

rule variants_summary_excel:
    input:
        rules.variants_summary.output[0]
    output:
        "{analysis_dir}/{variants_dir}/variants_summary.xlsx"
    run:
        import amplimap.variants
        amplimap.variants.make_summary_excel(input, output)

rule check_test_variant:
    output:
        touch("checks/check_test__{search}_{replace}_{percentage}.done")
    run:
        import amplimap.simulate
        amplimap.simulate.check_parameters(wildcards)

rule test_make_reads:
    input:
        "analysis/reads_in/{sample_with_lane}_R1_001.fastq.gz",
        "analysis/reads_in/{sample_with_lane}_R2_001.fastq.gz",
        "checks/check_test__{search}_{replace}_{percentage}.done",
    output:
        "test__{search}_{replace}_{percentage}/reads_in/{sample_with_lane}_R1_001.fastq.gz",
        "test__{search}_{replace}_{percentage}/reads_in/{sample_with_lane}_R2_001.fastq.gz",
        stats = "test__{search}_{replace}_{percentage}/stats_replacements/stats_replacements__sample_{sample_with_lane}.csv"
    run:
        import amplimap.simulate
        amplimap.simulate.make_simulated_reads(input, output, wildcards, config)

rule stats_replacements_agg:
    input:
        expand("test__{{params}}/stats_replacements/stats_replacements__sample_{sample_full}_L{lane}.csv", sample_full = SAMPLES, lane = ['%03d' % l for l in range(1, config['general']['lanes_actual']+1)])
    output:
        "test__{params}/stats_replacements/stats_replacements.csv",
    run:
        import amplimap.simulate
        amplimap.simulate.stats_replacements_agg(input, output)
