import glob, os, csv

# Sample identifier; 'sample' is a required key of the workflow config.
SAMPLE = config['sample']
print(f'sample: {SAMPLE}')

# Output directory from the config (default 'outputs/'), stored without any
# trailing '/' so that paths can be built as f'{outdir}/...'.
outdir = config.get('outdir', 'outputs/').rstrip('/')
print('XXX', outdir)

# Gather results for this sample; read below to find the genome accessions.
GATHER_CSV = '{}/big/{}.x.genbank.gather.csv'.format(outdir, SAMPLE)

# Database passed to the generic `sourmash gather` rule.
sourmash_db = config.get('db', 'all-gather-genomes.sbt.zip')   # CTB add to conf

# Read the gather CSV (when it already exists) at parse time.
#   gather_rows - every CSV row, as a dict keyed by column name
#   genome_accs - first space-separated token of each row's 'name' column
# Both stay empty when the CSV has not been produced yet.
genome_accs = []
gather_rows = []
if os.path.exists(GATHER_CSV):
    print(f'{GATHER_CSV} exists!')
    # newline='' lets the csv module do its own newline handling, as its
    # documentation requires for file objects handed to a reader.
    with open(GATHER_CSV, 'rt', newline='') as fp:
        for row in csv.DictReader(fp):
            gather_rows.append(row)
            genome_accs.append(row['name'].split(' ')[0])
else:
    print(f'{GATHER_CSV} does not exist.')


# load in all the genomes; this could be changed to a more targeted approach!
# acc_to_genome maps an accession -- the first two '_'-separated fields of
# the file name -- to the path of the matching *.fna.gz file.
genomes_pattern = f'{outdir}/genomes/*.fna.gz'
acc_to_genome = {}
for genome_path in glob.glob(genomes_pattern):   # CTB add to conf
    name_fields = os.path.basename(genome_path).split('_')
    acc_to_genome['_'.join(name_fields[:2])] = genome_path
if len(acc_to_genome) == 0:
    print(f"couldn't load any genomes?! - check {genomes_pattern}")

# @CTB investigate
# Drop this accession from the list (first occurrence only) when present.
try:
    genome_accs.remove('GCA_002160645.1')
except ValueError:
    # not in the list: nothing to remove
    pass

def input_acc_to_genome(w):
    """Input function: return the genome file recorded for wildcard ``acc``.

    Looks ``w.acc`` up in the module-level ``acc_to_genome`` mapping and
    raises ``KeyError`` when no genome file was found for that accession.
    """
    accession = w.acc
    return acc_to_genome[accession]

###

# Global regular-expression constraints applied to wildcards of these names.
wildcard_constraints:
    # raw string: "\d" inside a plain literal is an invalid escape sequence
    # (the regex itself is unchanged)
    size=r"\d+",
    sra_id='[a-zA-Z0-9._-]+'                   # should be everything but /

# Default target: per-genome mapped reads and leftover BAMs for SAMPLE,
# plus the two depth summary tables.
rule all:
    input:
        expand("{out}/minimap/{sample}.x.{acc}.mapped.fq.gz",
               out=outdir, sample=SAMPLE, acc=genome_accs),
        expand("{out}/leftover/{sample}.x.{acc}.bam",
               out=outdir, sample=SAMPLE, acc=genome_accs),
        "{}/minimap/depth/{}.summary.csv".format(outdir, SAMPLE),
        "{}/leftover/depth/{}.summary.csv".format(outdir, SAMPLE)


# top-level targets:
# - download_reads
# - trim_reads
# - smash_reads
# - gather_genbank @CTB not yet done
# - download_matching_genomes
# - map_reads
# - summarize

# Target rule: request the raw reads file for SAMPLE.
# NOTE(review): the keyword `url_file` names a FASTQ path here -- it looks
# copy-pasted; kept as-is.
rule download_reads:
    input:
        url_file = "{}/raw/{}.fastq.gz".format(outdir, SAMPLE)

# Target rule: request the abundance-trimmed reads file for SAMPLE.
rule trim_reads:
    input:
        url_file = "{}/abundtrim/{}.abundtrim.fq.gz".format(outdir, SAMPLE)

# Target rule: request the sourmash signature of the trimmed reads.
rule smash_reads:
    input:
        url_file = "{}/sigs/{}.abundtrim.sig".format(outdir, SAMPLE)

# Target rule: request both depth summaries (mapped and leftover reads).
rule map_reads:
    input:
        "{}/minimap/depth/{}.summary.csv".format(outdir, SAMPLE),
        "{}/leftover/depth/{}.summary.csv".format(outdir, SAMPLE)

# Target rule: request the rendered HTML report for SAMPLE.
rule summarize:
    input:
        f'{outdir}/reports/report-{SAMPLE}.html'

# Download every URL listed in the sample's URL file into the genomes
# directory.
rule download_matching_genomes:
    input:
        csv = GATHER_CSV,
        url_file = f"{outdir}/genbank/{SAMPLE}.genomes.urls.txt"
    output:
        # Follow the configured outdir rather than a hard-coded "outputs/":
        # the genome lookup table is globbed from {outdir}/genomes/*.fna.gz,
        # so downloads must land there. Identical path with the default.
        genomes_dir = directory(f"{outdir}/genomes/")
    conda: 'env/wget.yml'
    shell: """
        wget -i {input.url_file} -nc -P {output.genomes_dir} --limit-rate=400k
    """

# Write the collected genome accessions to a text file, one per line.
# NOTE(review): the accessions come from the module-level `genome_accs`
# list filled when the Snakefile was parsed; the CSV input is declared as a
# dependency but is not re-read in this rule.
rule extract_genome_accs_from_gather:
    input:
        csv = GATHER_CSV,
    output:
        acc_file = "{}/genbank/{}.genomes.accs.txt".format(outdir, SAMPLE)
    run:
        with open(output.acc_file, 'wt') as out_fp:
            out_fp.write("\n".join(genome_accs) + "\n")

# Run genome_grist.genbank_genomes on the accession list to produce the
# per-genome info CSV.
rule get_matching_genome_info:
    input:
        csv = GATHER_CSV,
        acc_file = "{}/genbank/{}.genomes.accs.txt".format(outdir, SAMPLE)
    output:
        info_file = "{}/genbank/{}.genomes.info.csv".format(outdir, SAMPLE)
    shell: """
        python -m genome_grist.genbank_genomes {input.acc_file} \
          -o {output.info_file}
    """


# Copy the "genome_url" column of the info CSV into a plain text file, one
# URL per line (the format download_matching_genomes feeds to `wget -i`).
rule get_matching_genome_urls:
    input:
        csv = GATHER_CSV,
        info_file = f"{outdir}/genbank/{SAMPLE}.genomes.info.csv"
    output:
        url_file = f"{outdir}/genbank/{SAMPLE}.genomes.urls.txt"
    run:
        # `csv` below is the module imported at the top of the file; the
        # gather CSV input is only reachable as `input.csv`.
        # newline='' is the mode the csv module documents for reader input.
        with open(input.info_file, 'rt', newline='') as rfp:
            with open(output.url_file, 'wt') as ofp:
                for row in csv.DictReader(rfp):
                    print(row["genome_url"], file=ofp)

# Print the workflow configuration as YAML, framed by two marker lines.
rule showconf:
    run:
        import yaml
        print('# full aggregated configuration:')
        config_yaml = yaml.dump(config)
        print(config_yaml.strip())
        print('# END')


# Bundle the summary and gather CSVs into transfer.zip in the working
# directory. The rule declares no inputs or outputs, so the globs are
# expanded by the shell at run time; {outdir} is the module-level variable.
rule zip:
    shell: """
        zip -r transfer.zip {outdir}/leftover/depth/*.summary.csv \
                {outdir}/minimap/depth/*.summary.csv \
                {outdir}/*.gather.csv {outdir}/genbank/*.csv
    """


# Dump the reads of an SRA accession with fastq-dump and store them as one
# gzipped FASTQ file. The perl filter rewrites the first ".1 " / ".2 " on
# each line to "/1 " / "/2 ".
rule download_sra_general:
    output:
        # Built from the configured outdir (was hard-coded "outputs/") so it
        # matches the path the `download_reads` target asks for. Identical
        # with the default outdir.
        outdir + "/raw/{sra_id}.fastq.gz",
    conda: "env/sra.yml"
    shell: '''
        fastq-dump --skip-technical  \
               --readids \
               --read-filter pass \
               --dumpbase \
               --split-spot \
               --clip \
               -Z \
               {wildcards.sra_id} | \
               perl -ne 's/\.([12]) /\/$1 /; print $_' | \
               gzip > {output}
        '''

# Split the raw reads file into separate _1 / _2 gzipped FASTQ files with
# split-paired-reads.py.
rule split_reads:
    input:
        # Read from the configured outdir (was hard-coded "outputs/"), where
        # the raw reads target lives. Identical with the default outdir.
        outdir + "/raw/{sample}.fastq.gz"
    output:
        # NOTE(review): written under inputs/raw/, not outdir -- left as-is
        # because adapter_trim reads these exact paths.
        r1 = "inputs/raw/{sample}_1.fastq.gz",
        r2 = 'inputs/raw/{sample}_2.fastq.gz',
    conda: 'env/trim.yml'
    shell: '''
        split-paired-reads.py --gzip -1 {output.r1} -2 {output.r2} {input}
    '''

# Adapter/quality trimming of the split reads with trimmomatic in PE mode.
# Produces paired (R1/R2) and orphan (o1/o2) outputs.
rule adapter_trim:
    input:
        r1 = "inputs/raw/{sample}_1.fastq.gz",
        r2 = 'inputs/raw/{sample}_2.fastq.gz',
        adapters = 'inputs/adapters.fa'
    output:
        # Built from the configured outdir (was hard-coded 'outputs/'):
        # kmer_trim_reads reads its inputs from outdir + "/trim/...".
        # Identical with the default outdir.
        r1 = outdir + '/trim/{sample}_R1.trim.fq.gz',
        r2 = outdir + '/trim/{sample}_R2.trim.fq.gz',
        o1 = outdir + '/trim/{sample}_o1.trim.fq.gz',
        o2 = outdir + '/trim/{sample}_o2.trim.fq.gz'
    conda: 'env/trim.yml'
    shell:'''
     trimmomatic PE {input.r1} {input.r2} \
             {output.r1} {output.o1} {output.r2} {output.o2} \
             ILLUMINACLIP:{input.adapters}:2:0:15 MINLEN:25  \
             LEADING:2 TRAILING:2 SLIDINGWINDOW:4:2
    '''

# Interleave the trimmed R1/R2 reads and pipe them through
# trim-low-abund.py to produce the abundance-trimmed reads file.
rule kmer_trim_reads:
    input:
        f"{outdir}/trim/{{sample}}_R1.trim.fq.gz",
        f"{outdir}/trim/{{sample}}_R2.trim.fq.gz"
    output: f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz"
    conda: 'env/trim.yml'
    shell:'''
    interleave-reads.py {input} | 
        trim-low-abund.py -C 3 -Z 18 -M 10e9 -V - -o {output}
    '''

# Map the abundance-trimmed reads of a sample against one genome with
# minimap2 (short-read preset); keep only mapped reads (-F 4) as sorted BAM.
rule minimap:
    output:
        # Built from the configured outdir (was hard-coded "outputs/") so it
        # matches the {outdir}/minimap/... targets requested by rule all.
        # Identical with the default outdir.
        bam = outdir + "/minimap/{sra_id}.x.{acc}.bam",
    input:
        # genome file looked up from the `acc` wildcard
        query = input_acc_to_genome,
        metagenome = outdir + "/abundtrim/{sra_id}.abundtrim.fq.gz",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.metagenome} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# Convert a BAM in the minimap directory to gzipped FASTQ.
rule samtools_fastq:
    output:
        # Both paths follow the configured outdir (were hard-coded
        # "outputs/"), matching the *.mapped.fq.gz targets of rule all.
        # Identical with the default outdir.
        mapped = outdir + "/minimap/{bam}.mapped.fq.gz",
    input:
        bam = outdir + "/minimap/{bam}.bam",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        samtools bam2fq {input.bam} | gzip > {output.mapped}
    """

# Per-position depth table for a BAM file in {dir}; `-aa` is passed so the
# table is not limited to covered positions.
rule samtools_depth:
    output:
        # Both paths follow the configured outdir (were hard-coded
        # "outputs/"). Identical with the default outdir.
        depth = outdir + "/{dir}/depth/{bam}.txt",
    input:
        bam = outdir + "/{dir}/{bam}.bam",
    conda: "env/minimap2.yml"
    shell: """
        samtools depth -aa {input.bam} > {output.depth}
    """

# Summarise the per-genome depth tables of SAMPLE in directory {dir} into a
# single CSV with one row per genome.
rule summarize_samtools_depth:
    output: f"{outdir}/{{dir}}/depth/{SAMPLE}.summary.csv"
    input:
        # Built from the configured outdir (was hard-coded "outputs/") so the
        # inputs sit next to this rule's own output. Identical with the
        # default outdir.
        expand(outdir + "/{{dir}}/depth/{s}.x.{g}.txt",
               s=SAMPLE, g=genome_accs)
    run:
        import pandas as pd

        runs = {}
        for n, sra_stat in enumerate(input):
            print(f'reading from {sra_stat} - {n+1}/{len(input)}...')
            data = pd.read_table(sra_stat, names=["contig", "pos", "coverage"])

            # File names look like "<sample>.x.<genome>.txt". Fields are
            # taken by splitting on "."; NOTE(review): a genome id such as
            # "GCA_000001.1" is therefore recorded without its ".1" suffix,
            # and a "." inside the sample name would shift the fields.
            name_fields = sra_stat.split("/")[-1].split(".")
            sra_id = name_fields[0]
            genome_id = name_fields[2]

            d = {}
            value_counts = data['coverage'].value_counts()
            # one row per reported position
            d['genome bp'] = int(len(data))
            # positions with zero coverage
            d['missed'] = int(value_counts.get(0, 0))
            d['percent missed'] = 100 * d['missed'] / d['genome bp']
            # mean depth over all reported positions
            d['coverage'] = data['coverage'].sum() / len(data)
            if d['missed'] != 0:
                # mean depth rescaled by the fraction of positions covered
                uniq_cov = d['coverage'] / (1 - d['missed'] / d['genome bp'])
                d['unique_mapped_coverage'] = uniq_cov
            else:
                d['unique_mapped_coverage'] = d['coverage']
            d['covered_bp'] = (1 - d['percent missed']/100.0) * d['genome bp']
            d['genome_id'] = genome_id
            d['sample_id'] = sra_id
            runs[genome_id] = d

        # transpose so that each genome becomes a row
        pd.DataFrame(runs).T.to_csv(output[0])


# Compute a sourmash signature (k=21,31,51, scaled=1000, with abundance
# tracking) for the abundance-trimmed reads of a sample.
rule sourmash_reads:
    input:
        metagenome = f"{outdir}/abundtrim/{{sra_id}}.abundtrim.fq.gz",
    output:
        sig = f"{outdir}/sigs/{{sra_id}}.abundtrim.sig"
    conda: "env/sourmash.yml"
    shell: """
        sourmash compute -k 21,31,51 --scaled=1000 {input} -o {output} \
           --name {wildcards.sra_id} --track-abundance
    """


# Run `sourmash gather` for a sample signature against the configured
# database; the CSV goes to -o and stdout is captured in the .out file.
rule sourmash_gather_reads:
    input:
        sig = f"{outdir}/sigs/{{sra_id}}.abundtrim.sig",
        db = sourmash_db,
    output:
        csv = f"{outdir}/{{sra_id}}.gather.csv",
        out = f"{outdir}/{{sra_id}}.gather.out",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} > {output.out}
    """

# Install a user-level Jupyter kernel named "charcoal" and leave a marker
# file behind so the step is recorded as done.
rule set_kernel:
    output:
        outdir + "/.kernel.set"
    conda: 'env/papermill.yml'
    shell: """
        python -m ipykernel install --user --name charcoal
        touch {output}
    """

# Execute the report notebook template with papermill for SAMPLE (kernel
# "charcoal") and write the executed notebook to the reports directory.
rule make_notebook:
    input:
        nb='genome_grist/notebooks/report-sample.ipynb',
        # marker file written once the kernel has been installed
        kernel_set = rules.set_kernel.output
    output:
        nb = f'{outdir}/reports/report-{SAMPLE}.ipynb',
    params:
        cwd = f'{outdir}/reports/'
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} - -k charcoal \
              -p sample_id {SAMPLE:q} -p render '' \
              --cwd {params.cwd} \
              > {output.nb}
    """

# Render the executed report notebook to HTML with nbconvert (code cells
# hidden via --no-input); the HTML is taken from stdout.
rule make_html:
    input:
        source_nb='genome_grist/notebooks/report-sample.ipynb',
        notebook = f'{outdir}/reports/report-{SAMPLE}.ipynb'
    output:
        f'{outdir}/reports/report-{SAMPLE}.html'
    conda: 'env/papermill.yml'
    shell: """
        python -m nbconvert {input.notebook} --to html --stdout --no-input --ExecutePreprocessor.kernel_name=charcoal > {output}
    """
     

# @CTB update subtract-gather to take sample ID as param
# Run genome_grist.subtract_gather on the gather CSV. The command line only
# receives the CSV; the per-genome mapped reads are declared as inputs and
# the per-genome leftover reads as outputs.
rule extract_leftover_reads:
    input:
        csv = GATHER_CSV,
        reads = expand("{out}/minimap/{sample}.x.{acc}.mapped.fq.gz",
                       out=outdir, sample=SAMPLE, acc=genome_accs),
    output:
        expand("{out}/minimap/{sample}.x.{acc}.leftover.fq.gz",
               out=outdir, sample=SAMPLE, acc=genome_accs),
    conda: "env/sourmash.yml"
    shell: """
        python -m genome_grist.subtract_gather {input.csv}
    """


# Map the leftover reads of a sample/genome pair back to that genome with
# minimap2; keep only mapped reads (-F 4) as sorted BAM.
rule map_leftover_reads:
    output:
        # Built from the configured outdir (was hard-coded "outputs/") so it
        # matches the {outdir}/leftover/... targets requested by rule all.
        # Identical with the default outdir.
        bam = outdir + "/leftover/{sra_id}.x.{acc}.bam",
    input:
        # genome file looked up from the `acc` wildcard
        query = input_acc_to_genome,
        reads = outdir + "/minimap/{sra_id}.x.{acc}.leftover.fq.gz",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.reads} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# One-off gather of the SRR606249 signature against test.sbt.zip.
rule sourmash_gather_reads_podar:
    input:
        sig = f"{outdir}/sigs/SRR606249.abundtrim.sig",
        db = 'test.sbt.zip'
    output:
        csv = f"{outdir}/big/SRR606249.x.podar.gather.csv"
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv}
    """

# One-off gather of the SRR606249 signature against the GenBank databases
# matched by the glob below (evaluated when the Snakefile is parsed); the
# matching signatures are saved next to the CSV.
rule sourmash_gather_reads_genbank:
    input:
        sig = f"{outdir}/sigs/SRR606249.abundtrim.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = f"{outdir}/big/SRR606249.x.genbank.gather.csv",
        matches = f"{outdir}/big/SRR606249.x.genbank.gather.sig",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} --save-matches {output.matches}
    """

# One-off gather of the p8808mo11 signature (absolute path) against the
# GenBank databases matched by the glob below (evaluated at parse time).
rule sourmash_gather_reads_ter1:
    input:
        sig = "/home/tereiter/github/2020-ibd/outputs/sigs/p8808mo11.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = f"{outdir}/big/p8808mo11.x.genbank.gather.csv",
        matches = f"{outdir}/big/p8808mo11.x.genbank.gather.sig"
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} --save-matches {output.matches}
    """

# One-off gather of the p8808mo9 signature (absolute path) against the
# GenBank databases matched by the glob below (evaluated at parse time).
rule sourmash_gather_reads_ter2:
    input:
        sig = "/home/tereiter/github/2020-ibd/outputs/sigs/p8808mo9.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = f"{outdir}/big/p8808mo9.x.genbank.gather.csv",
        matches = f"{outdir}/big/p8808mo9.x.genbank.gather.sig"
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} --save-matches {output.matches}
    """

# One-off gather of the twofoo-head signature against the GenBank databases
# matched by the glob below (evaluated at parse time).
rule sourmash_gather_reads_twofoo_head:
    input:
        sig = f"{outdir}/sigs/twofoo-head.abundtrim.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = f"{outdir}/genbank/twofoo-head.x.genbank.gather.csv",
        matches = f"{outdir}/genbank/twofoo-head.x.genbank.gather.sig",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} --save-matches {output.matches}
    """

# One-off gather of the SRR1976948 signature against the GenBank databases
# matched by the glob below (evaluated at parse time).
# Paths are built from the configured outdir (were hard-coded "outputs/"),
# like the sibling gather rules; identical with the default outdir.
rule sourmash_gather_reads_hu:
    input:
        sig = outdir + "/sigs/SRR1976948.abundtrim.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = outdir + "/genbank/SRR1976948.x.genbank.gather.csv",
        matches = outdir + "/genbank/SRR1976948.x.genbank.gather.sig",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} --save-matches {output.matches}
    """

# One-off gather of the SRR1976948 "remaining" signature against the GenBank
# databases matched by the glob below (evaluated at parse time).
# Paths are built from the configured outdir (were hard-coded "outputs/"),
# like the sibling gather rules; identical with the default outdir.
rule sourmash_gather_reads_hu_remaining:
    input:
        sig = outdir + "/sigs/SRR1976948.abundtrim.remaining.sig",
        db = glob.glob('/home/irber/sourmash_databases/outputs/sbt/genbank-*x1e5*k31*')
    output:
        csv = outdir + "/genbank/SRR1976948-remaining.x.genbank.gather.csv",
        matches = outdir + "/genbank/SRR1976948-remaining.x.genbank.gather.sig",
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.sig} {input.db} -o {output.csv} --save-matches {output.matches}
    """
