# top-level targets:
# - download_reads
# - trim_reads
# - smash_reads
# - gather_genbank
# - download_matching_genomes
# - map_reads
# - summarize

import glob, os, csv

# Workflow-wide settings, read from the Snakemake `config` mapping.
# 'sample' is required; 'outdir' and 'db' fall back to defaults.
SAMPLE = config['sample']
print(f'sample: {SAMPLE}')

# strip trailing slashes so paths can be built as f'{outdir}/...'
outdir = config.get('outdir', 'outputs/').rstrip('/')

# handed to trim-low-abund.py via -M in rule kmer_trim_reads
ABUNDTRIM_MEMORY = 10e9

sourmash_db = config.get('db', 'all-gather-genomes.sbt.zip')   # CTB add to conf

# The "semi-wildcard" rules depend on gather output only: they expand over
# the accessions listed in GATHER_CSV, so they need that file to be present
# when the Snakefile is parsed.
GATHER_CSV = f'{outdir}/genbank/{SAMPLE}.x.genbank.gather.csv'

genome_accs = []
if not os.path.exists(GATHER_CSV):
    print(f'{GATHER_CSV} does not exist.')
else:
    print(f'gather output {GATHER_CSV} exists!')
    with open(GATHER_CSV, 'rt') as fp:
        # the accession is the first space-separated token of the 'name' column
        genome_accs.extend(row['name'].split(' ')[0]
                           for row in csv.DictReader(fp))
    print(f'loaded {len(genome_accs)} accessions from gather results.')


###

wildcard_constraints:
    size="\d+",
    sra_id='[a-zA-Z0-9._-]+'                   # should be everything but /

# First rule in this file, so it is the default target when this is the
# top-level Snakefile: requests the genbank gather CSV for the configured
# sample.
rule all:
    input:
        GATHER_CSV

# Top-level target: the paired raw FASTQ files (_1 / _2) for the configured
# sample.
rule download_reads:
    input:
        f"{outdir}/raw/{SAMPLE}_1.fastq.gz",
        f"{outdir}/raw/{SAMPLE}_2.fastq.gz"

# Top-level target: the abundance-trimmed reads for the configured sample.
# NOTE(review): the input is named `url_file` although it is a reads file;
# looks like a copy/paste leftover - confirm nothing refers to it before
# renaming.
rule trim_reads:
    input:
        url_file = f"{outdir}/abundtrim/{SAMPLE}.abundtrim.fq.gz"

# Top-level target: the signature file computed from the abundance-trimmed
# reads of the configured sample.
# NOTE(review): the input is named `url_file` although it is a signature
# file; looks like a copy/paste leftover - confirm before renaming.
rule smash_reads:
    input:
        url_file = f"{outdir}/sigs/{SAMPLE}.abundtrim.sig"

# Top-level target: the genbank gather CSV for the configured sample (the
# same file that is read at parse time to populate genome_accs).
# NOTE(review): the input is named `url_file` although it is a CSV;
# looks like a copy/paste leftover - confirm before renaming.
rule gather_genbank:
    input:
        url_file = GATHER_CSV

# Top-level target: one genome file per accession in genome_accs.
# genome_accs is filled from GATHER_CSV when the Snakefile is parsed, so this
# expands to nothing until that CSV exists.
rule download_matching_genomes:
    input:
        genome = expand(outdir + "/genomes/{acc}.fna.gz",
                        acc=genome_accs)

# Top-level target: the depth summary CSVs for both the full mapping
# (minimap/) and the leftover-read mapping (leftover/) of the configured
# sample.
rule map_reads:
    input:
        f"{outdir}/minimap/depth/{SAMPLE}.summary.csv",
        f"{outdir}/leftover/depth/{SAMPLE}.summary.csv"

# Top-level target: the HTML report for the configured sample (built by
# rule make_notebook).
rule summarize:
    input:
        outdir + f'/reports/report-{SAMPLE}.html'

# print out the configuration
rule showconf:
    run:
        import yaml
        print('# full aggregated configuration:')
        print(yaml.dump(config).strip())
        print('# END')


# Utility rule: bundle the depth summary CSVs and the gather CSVs found
# under outdir into transfer.zip (written to the working directory).
# Declares neither inputs nor outputs; the files are picked up by shell globs.
rule zip:
    shell: """
        zip -r transfer.zip {outdir}/leftover/depth/*.summary.csv \
                {outdir}/minimap/depth/*.summary.csv \
                {outdir}/*.gather.csv {outdir}/genbank/*.csv
    """


# wildcard rule for downloading SRA IDs.
rule wc_download_sra:
    output:
        r1 = protected(outdir + "/raw/{sra_id}_1.fastq.gz"),
        r2 = protected(outdir + "/raw/{sra_id}_2.fastq.gz"),
    conda: "env/sra.yml"
    shell: '''
        fastq-dump --skip-technical  \
               --readids \
               --read-filter pass \
               --dumpbase \
               --split-spot \
               --clip \
               -Z \
               {wildcards.sra_id} | \
               perl -ne 's/\.([12]) /\/$1 /; print $_' | \
               split-paired-reads.py --gzip -1 {output.r1} -2 {output.r2}
        '''

# wildcard rule for adapter trimming
#
# Runs trimmomatic in PE mode on the raw read pair of {sample}, using the
# adapter file inputs/adapters.fa. Four protected outputs are written:
# R1/R2 and o1/o2 (presumably the paired and orphaned reads - the argument
# order R1 o1 R2 o2 matches trimmomatic's PE convention; confirm).
rule adapter_trim:
    input:
        r1 = outdir + "/raw/{sample}_1.fastq.gz",
        r2 = outdir + "/raw/{sample}_2.fastq.gz",
        adapters = "inputs/adapters.fa"
    output:
        r1 = protected(outdir + '/trim/{sample}_R1.trim.fq.gz'),
        r2 = protected(outdir + '/trim/{sample}_R2.trim.fq.gz'),
        o1 = protected(outdir + '/trim/{sample}_o1.trim.fq.gz'),
        o2 = protected(outdir + '/trim/{sample}_o2.trim.fq.gz'),
    conda: 'env/trim.yml'
    shell: """
        trimmomatic PE {input.r1} {input.r2} \
             {output.r1} {output.o1} {output.r2} {output.o2} \
             ILLUMINACLIP:{input.adapters}:2:0:15 MINLEN:25  \
             LEADING:2 TRAILING:2 SLIDINGWINDOW:4:2
    """

# wildcard rule for k-mer abundance trimming
#
# Interleaves the adapter-trimmed R1/R2 files of {sample} and pipes them into
# trim-low-abund.py, whose memory argument (-M) comes from the file-level
# ABUNDTRIM_MEMORY setting. Only the R1/R2 trim outputs are consumed here;
# the o1/o2 files from adapter_trim are not inputs of this rule.
rule kmer_trim_reads:
    input: 
        outdir + "/trim/{sample}_R1.trim.fq.gz", 
        outdir + "/trim/{sample}_R2.trim.fq.gz"
    output:
        protected(outdir + "/abundtrim/{sample}.abundtrim.fq.gz")
    conda: 'env/trim.yml'
    shell: """
        interleave-reads.py {input} | 
            trim-low-abund.py -C 3 -Z 18 -M {ABUNDTRIM_MEMORY} -V - -o {output}
    """

# wildcard rule for mapping abundtrim reads and producing a bam
#
# Maps the abundance-trimmed reads of {sra_id} against genome {acc} with
# minimap2 (preset "sr"), then converts to BAM and sorts with samtools.
# `samtools view -F 4` presumably drops unmapped reads (SAM flag 0x4).
rule minimap:
    output:
        bam = outdir + "/minimap/{sra_id}.x.{acc}.bam",
    input:
        query = outdir + "/genomes/{acc}.fna.gz",
        metagenome = outdir + "/abundtrim/{sra_id}.abundtrim.fq.gz",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.metagenome} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# wildcard rule for extracting FASTQ from BAM
#
# Converts any BAM under minimap/ back to FASTQ with `samtools bam2fq` and
# gzips the result to <bam>.mapped.fq.gz.
# NOTE(review): `threads: 4` is declared but the shell command never uses
# {threads}; the rule reserves four cores for a single-threaded pipeline.
rule samtools_fastq:
    output:
        mapped = outdir + "/minimap/{bam}.mapped.fq.gz",
    input:
        bam = outdir + "/minimap/{bam}.bam",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        samtools bam2fq {input.bam} | gzip > {output.mapped}
    """

# wildcard rule for getting per-base depth information from BAM
#
# Works for any {dir} under outdir (the file uses it for minimap/ and
# leftover/): runs `samtools depth -aa` on <dir>/<bam>.bam and stores the
# text output in <dir>/depth/<bam>.txt.
rule samtools_depth:
    input:
        bam = outdir + "/{dir}/{bam}.bam",
    output:
        depth = outdir + "/{dir}/depth/{bam}.txt",
    conda: "env/minimap2.yml"
    shell: """
        samtools depth -aa {input.bam} > {output.depth}
    """

# SEMI-WILDCARD rule for summarize depth into a CSV
rule summarize_samtools_depth:
    output: f"{outdir}/{{dir}}/depth/{SAMPLE}.summary.csv"
    input:
        expand(outdir + "/{{dir}}/depth/{s}.x.{g}.txt",
               s=SAMPLE, g=genome_accs)
    run:
        import pandas as pd

        runs = {}
        for n, sra_stat in enumerate(input):
            print(f'reading from {sra_stat} - {n+1}/{len(input)}...')
            data = pd.read_table(sra_stat, names=["contig", "pos", "coverage"])
            sra_id = sra_stat.split("/")[-1].split(".")[0]
            genome_id = sra_stat.split("/")[-1].split(".")[2]

            d = {}
            value_counts = data['coverage'].value_counts()
            d['genome bp'] = int(len(data))
            d['missed'] = int(value_counts.get(0, 0))
            d['percent missed'] = 100 * d['missed'] / d['genome bp']
            d['coverage'] = data['coverage'].sum() / len(data)
            if d['missed'] != 0:
                uniq_cov = d['coverage'] / (1 - d['missed'] / d['genome bp'])
                d['unique_mapped_coverage'] = uniq_cov
            else:
                d['unique_mapped_coverage'] = d['coverage']
            d['covered_bp'] = (1 - d['percent missed']/100.0) * d['genome bp']
            d['genome_id'] = genome_id
            d['sample_id'] = sra_id
            runs[genome_id] = d

        pd.DataFrame(runs).T.to_csv(output[0])

# wildcard rule for computing sourmash signature from abundtrim reads
#
# Runs `sourmash compute` on the abundance-trimmed reads of {sra_id} with
# k = 21, 31 and 51, scaled=1000 and abundance tracking; the signature is
# named after the sra_id wildcard.
rule sourmash_reads_genbank:
    input:
        metagenome = outdir + "/abundtrim/{sra_id}.abundtrim.fq.gz",
    output:
        sig = outdir + "/sigs/{sra_id}.abundtrim.sig"
    conda: "env/sourmash.yml"
    shell: """
        sourmash compute -k 21,31,51 --scaled=1000 {input} -o {output} \
           --name {wildcards.sra_id} --track-abundance
    """

# wildcard rule for running sourmash gather on abundtrim read signature
#
# Searches the read signature of {sra_id} against the database configured
# via `db` (sourmash_db). The CSV goes to <sra_id>.gather.csv and the
# command's stdout is captured in <sra_id>.gather.out.
rule sourmash_gather_reads:
    input:
        sig = outdir + "/sigs/{sra_id}.abundtrim.sig",
        db = sourmash_db,
    output:
        csv = outdir + "/{sra_id}.gather.csv",
        out = outdir + "/{sra_id}.gather.out",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} > {output.out}
    """

# NON-WILDCARD rule for configuring ipython kernel for papermill
#
# Registers a per-user Jupyter kernel named "genome_grist", pip-installs
# matplotlib, numpy and pandas, and then touches a marker file
# (<outdir>/.kernel.set) so the setup is tracked as done.
rule set_kernel:
    output:
        f"{outdir}/.kernel.set"
    conda: 'env/papermill.yml'
    shell: """
        python -m ipykernel install --user --name genome_grist
        python -m pip install matplotlib numpy pandas
        touch {output}
    """


# NON-WILDCARD rule for papermill -> reporting notebook + html
#
# Executes the report template with papermill (kernel genome_grist,
# parameter sample_id = SAMPLE, working directory <outdir>/reports/); the
# executed notebook is written to stdout and redirected into output.nb.
# nbconvert then renders that notebook to HTML without input cells.
# The CSV inputs and the kernel marker are not passed on the command line;
# they are declared so they exist before the notebook runs (presumably the
# notebook reads them relative to --cwd - confirm against the notebook).
rule make_notebook:
    input:
        nb = 'genome_grist/notebooks/report-sample.ipynb',
        all_csv = f"{outdir}/minimap/depth/{SAMPLE}.summary.csv",
        depth_csv = f"{outdir}/leftover/depth/{SAMPLE}.summary.csv",
        gather_csv = GATHER_CSV,
        kernel_set = rules.set_kernel.output,
    output:
        nb = outdir + f'/reports/report-{SAMPLE}.ipynb',
        html = outdir + f'/reports/report-{SAMPLE}.html',
    params:
        cwd = outdir + '/reports/'
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} - -k genome_grist \
              -p sample_id {SAMPLE:q} -p render '' \
              --cwd {params.cwd} \
              > {output.nb}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# NON-WILDCARD rule for mapped reads to leftover reads
# @CTB update subtract-gather to take sample ID as param
# @CTB update for intersected/overlapping reads too
#
# Consumes the mapped-read FASTQ of every accession in genome_accs and
# declares one leftover FASTQ per accession as output.
# NOTE(review): only the gather CSV is passed to subtract_gather; the
# read inputs and all output paths are presumably derived inside that module
# (see the @CTB notes above) - verify they match the paths declared here.
rule extract_leftover_reads:
    input:
        csv = GATHER_CSV,
        reads = expand(f"{outdir}/minimap/{SAMPLE}.x.{{acc}}.mapped.fq.gz",
                       acc=genome_accs),
    output:
        expand(f"{outdir}/minimap/{SAMPLE}.x.{{acc}}.leftover.fq.gz",
               acc=genome_accs),
    conda: "env/sourmash.yml"
    shell: """
        python -m genome_grist.subtract_gather {input.csv}
    """

# rule for mapping leftover reads to genomes -> BAM
#
# Same minimap2 | samtools view | samtools sort pipeline as rule minimap,
# but fed with the leftover reads and writing under leftover/.
# `all_csv` is not used by the shell command; it presumably only forces the
# full-mapping summary to be built first - confirm.
rule map_leftover_reads:
    input:
        all_csv = f"{outdir}/minimap/depth/{{sra_id}}.summary.csv",
        query = f"{outdir}/genomes/{{acc}}.fna.gz",
        reads = outdir + "/minimap/{sra_id}.x.{acc}.leftover.fq.gz",
    output:
        bam=outdir + "/leftover/{sra_id}.x.{acc}.bam",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.reads} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# wildcard rule for running sourmash gather x genbank
#
# Runs `sourmash gather` for the read signature of {sra_id}, writing the
# result CSV and saving the matching signatures next to it.
# NOTE(review): both database locations are hard-coded absolute paths under
# /home. The glob is evaluated when the Snakefile is parsed and yields an
# empty list where those files do not exist, and the .sbt.zip named in the
# shell command is not declared as an input. The configurable `sourmash_db`
# is not used here - confirm whether it should be.
rule sourmash_gather_reads_genbank:
    input:
        sig = outdir + "/sigs/{sra_id}.abundtrim.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = outdir + "/genbank/{sra_id}.x.genbank.gather.csv",
        matches = outdir + "/genbank/{sra_id}.x.genbank.gather.sig",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} /home/ctbrown/genome-grist/all-gather-genomes.sbt.zip {input.db} -o {output.csv} --save-matches {output.matches}
    """

# wildcard rule for extracting accessions from gather output
rule extract_genome_accs_from_gather:
    input:
        csv = outdir + "/genbank/{sra_id}.x.genbank.gather.csv",
    output:
        acc_file = outdir + "/genbank/{sra_id}.genomes.accs.txt"
    run:
        n = None
        with open(output.acc_file, 'wt') as outfp:
            with open(input.csv, 'rt') as infp:
                r = csv.DictReader(infp)
                for n, row in enumerate(r):
                    acc = row['name'].split(' ')[0]
                    print(acc, file=outfp)

        if n is not None:
            print(f'found {n+1} accessions in {input.csv}')
              

# wildcard rule for getting genome info (URL, name) from genbank
#
# Passes the accession list of {sra_id} to genome_grist.genbank_genomes,
# which writes <sra_id>.genomes.info.csv. Rule
# download_matching_genomes_one_by_one reads the 'acc', 'genome_url' and
# 'ncbi_tax_name' columns of that file.
rule get_matching_genome_info:
    input:
        acc_file = outdir + "/genbank/{sra_id}.genomes.accs.txt"
    output:
        info_file = outdir + "/genbank/{sra_id}.genomes.info.csv"
    shell: """
        python -m genome_grist.genbank_genomes {input.acc_file} \
          -o {output.info_file}
    """

# NON-WILDCARD/gather-dependent rule for downloading actual genomes
rule download_matching_genomes_one_by_one:
    input:
        info_file = outdir + f"/genbank/{SAMPLE}.genomes.info.csv"
    output:
        genome = outdir + "/genomes/{acc}.fna.gz"
    run:
        with open(input.info_file, 'rt') as infp:
            r = csv.DictReader(infp)
            for row in r:
                acc = row['acc']
                if acc != wildcards.acc: continue
                url = row['genome_url']
                name = row['ncbi_tax_name']

                print(f"downloading genome for acc {acc}/{name} from NCBI...")
                with open(output.genome, 'wb') as outfp:
                    with urllib.request.urlopen(url) as response:
                        content = response.read()
                        outfp.write(content)
                        print(f"...wrote {len(content)} bytes to {output.genome}")
