# top-level targets:
# - download_reads
# - trim_reads
# - smash_reads
# - summarize_sample_info
# - gather_genbank
# - download_matching_genomes
# - map_reads
# - summarize
# - build_consensus
# - make_sgc_conf

import glob, os, csv, tempfile, shutil, sys

# Sample ID(s) to process; 'sample' is a required config key.
SAMPLES = config['sample']
print('sample: {}'.format(SAMPLES))

# Output directory, stored without a trailing slash so that paths can be
# built uniformly as outdir + "/...".
outdir = config.get('outdir', 'outputs/').rstrip('/')
print(f'outdir: {outdir}')

# Create one scratch directory for this run inside the first usable location
# listed under 'tempdir' in the config. It is removed again by the
# onsuccess/onerror handlers.
base_tempdir = None
try_temp_locations = config.get('tempdir', [])
if isinstance(try_temp_locations, str):
    # allow a single path as well as a list of candidate paths
    try_temp_locations = [try_temp_locations]

for temp_loc in try_temp_locations:
    try:
        base_tempdir = tempfile.mkdtemp(dir=temp_loc)
    except OSError:
        # missing, unwritable, or not a directory: try the next candidate
        continue
    # stop at the first success so that only one directory is created
    # (and therefore only one needs cleaning up).
    break

if not base_tempdir:
    print(f"Could not create a temporary directory in any of {try_temp_locations}", file=sys.stderr)
    print("Please set 'tempdir' in the config.", file=sys.stderr)
    sys.exit(-1)

# memory (in bytes) handed to trim-low-abund.py via -M
ABUNDTRIM_MEMORY = float(config.get('metagenome_trim_memory', '1e9'))

# sourmash databases to search, located by glob pattern
sourmash_db_pattern = config.get('sourmash_database_glob_pattern', 'MUST SPECIFY IN CONFIG')
SOURMASH_DB_LIST = glob.glob(sourmash_db_pattern)
if not SOURMASH_DB_LIST:
    print(f"WARNING: no sourmash databases found!? glob pattern was '{sourmash_db_pattern}'")
    # placeholder so that rules depending on the databases can still be
    # parsed; any rule that actually needs it will fail on the missing input.
    SOURMASH_DB_LIST = ["DB_DOES_NOT_EXIST_BUT_IS_USED"]
SOURMASH_DB_KSIZE = config.get('sourmash_database_ksize', ['31'])
SOURMASH_DATABASE_THRESHOLD_BP = config.get('sourmash_database_threshold_bp',
                                            1e5)
# k-mer sizes to sketch. These are interpolated unquoted into a shell
# command line, so the values must not contain whitespace.
SOURMASH_COMPUTE_KSIZES = config.get('sourmash_compute_ksizes',
                                     ['21', '31', '51'])
SOURMASH_COMPUTE_SCALED = config.get('sourmash_scaled', '1000')
SOURMASH_COMPUTE_TYPE = config.get('sourmash_sigtype', 'DNA')
if SOURMASH_COMPUTE_TYPE not in ('DNA', 'protein'):
    raise ValueError(
        f"'sourmash_sigtype' must be 'DNA' or 'protein', "
        f"not {SOURMASH_COMPUTE_TYPE!r}")

# memory budget (in bytes) for the sourmash prefetch rule
PREFETCH_MEMORY = float(config.get('prefetch_memory', '100e9'))

# make a `sourmash sketch` -p param string.
def make_param_str(ksizes, scaled):
    """Return the value for the `-p` option of `sourmash sketch`.

    ksizes: iterable of k-mer sizes; each becomes a `k=<size>` entry.
    scaled: the scaled value, appended as `scaled=<value>`.

    The entries are comma-separated and the string always ends in `abund`.
    """
    ksize_part = ",".join("k=" + str(size) for size in ksizes)
    return ksize_part + ",scaled=" + str(scaled) + ",abund"

# Make sure the shared genome download directory exists. This stays
# best-effort: failure is reported but not fatal.
try:
    os.mkdir('genbank_genomes/')
except FileExistsError:
    pass    # already present, e.g. from a previous run
except OSError as exc:
    print(f"WARNING: could not create 'genbank_genomes/': {exc}",
          file=sys.stderr)

###

# remove the per-run scratch directory when the workflow succeeds...
onsuccess:
    shutil.rmtree(base_tempdir, ignore_errors=True)

# ...and also when it fails; ignore_errors keeps the cleanup itself from raising.
onerror:
    shutil.rmtree(base_tempdir, ignore_errors=True)

# Global wildcard constraints. Raw strings are used so that regex escapes
# such as \d are not treated as (invalid) Python string escapes.
wildcard_constraints:
    size=r"\d+",
    sample=r'[a-zA-Z0-9._-]+'                  # letters, digits, '.', '_' and '-' only (so never '/')

# default target: the gather CSV for every configured sample.
rule all:
    input:
        expand(outdir + '/genbank/{sample}.x.genbank.gather.csv',
               sample=SAMPLES)

# target: raw paired reads plus the raw-read signature for every sample.
rule download_reads:
    input:
        expand(outdir + "/raw/{sample}_1.fastq.gz", sample=SAMPLES),
        expand(outdir + "/raw/{sample}_2.fastq.gz", sample=SAMPLES),
        expand(outdir + "/raw/{sample}.raw.sig", sample=SAMPLES),

# target: abundance-trimmed reads for every sample.
rule trim_reads:
    input:
        url_file = expand(outdir + "/abundtrim/{sample}.abundtrim.fq.gz",
                          sample=SAMPLES)

# target: distinct k-mer reports for every sample.
rule estimate_distinct_kmers:
    input:
        url_file = expand(outdir + "/abundtrim/{sample}.abundtrim.fq.gz.kmer-report.txt",
                          sample=SAMPLES)

# target: read/base count reports for every sample.
rule count_trimmed_reads:
    input:
        url_file = expand(outdir + "/abundtrim/{sample}.abundtrim.fq.gz.reads-report.txt",
                          sample=SAMPLES)

# target: sourmash signatures of the abundance-trimmed reads for every sample.
rule smash_reads:
    input:
        url_file = expand(outdir + "/sigs/{sample}.abundtrim.sig",
                          sample=SAMPLES)

# target: prefetch match collections for every sample.
rule search_genbank:
    input:
        expand(f"{outdir}/genbank/{{sample}}.x.genbank.prefetch.zip",
               sample=SAMPLES)

# target: per-sample info YAML files.
rule summarize_sample_info:
    input:
        expand(f'{outdir}/{{sample}}.info.yaml', sample=SAMPLES)

# target: gather CSVs for every sample.
# NOTE(review): declared as a checkpoint but has no output and is not
# referenced via `checkpoints.` in this file - confirm it needs to be one.
checkpoint gather_genbank:
    input:
        expand(outdir + '/genbank/{sample}.x.genbank.gather.csv',
               sample=SAMPLES)

# per-sample checkpoint that Checkpoint_GatherResults waits on before it
# reads the gather CSV; the flag file exists only to give it an output.
checkpoint gather_genbank_wc:
    input:
        gather_csv = outdir + '/genbank/{sample}.x.genbank.gather.csv'
    output:
        touch(outdir + "/.gather.{sample}")   # checkpoints need an output ;)

class Checkpoint_GatherResults:
    """Input function that expands a filename pattern over gather matches.

    An instance is used as a rule input. When Snakemake calls it with the
    rule's wildcards, it waits for the `gather_genbank_wc` checkpoint, reads
    the accessions out of the sample's gather CSV, and returns `pattern`
    expanded over those accessions (`{acc}`) and the wildcards.

    If the calling rule has no `sample` wildcard (e.g. a target rule without
    outputs), the pattern is expanded for every sample in SAMPLES instead.
    """

    def __init__(self, pattern):
        # filename pattern containing {acc} and, optionally, other wildcards
        self.pattern = pattern

    def get_genome_accs(self, sample):
        """Return the accessions listed in the gather CSV for `sample`.

        The accession is the first space-separated token of each row's
        'name' column. The CSV must already exist.
        """
        gather_csv = f'{outdir}/genbank/{sample}.x.genbank.gather.csv'
        assert os.path.exists(gather_csv)

        with open(gather_csv, 'rt', newline='') as fp:
            genome_accs = [row['name'].split(' ')[0]
                           for row in csv.DictReader(fp)]
        print(f'loaded {len(genome_accs)} accessions from {gather_csv}.')

        return genome_accs

    def _expand_for(self, wc):
        """Expand the pattern for one sample; `wc` must contain 'sample'."""
        # wait for the results of 'gather_genbank_wc'; this will trigger
        # exception until that rule has been run.
        checkpoints.gather_genbank_wc.get(**wc)

        genome_accs = self.get_genome_accs(wc['sample'])
        return expand(self.pattern, acc=genome_accs, **wc)

    def __call__(self, w):
        wc = dict(**w)
        if 'sample' in wc:
            return self._expand_for(wc)

        # no sample wildcard available: cover all configured samples.
        samples = [SAMPLES] if isinstance(SAMPLES, str) else list(SAMPLES)
        paths = []
        for sample in samples:
            paths.extend(self._expand_for(dict(wc, sample=sample)))

        # samples can share genomes; drop repeats but keep first-seen order.
        return list(dict.fromkeys(paths))

# target: combined genome info CSVs for every sample.
rule download_matching_genomes_info:
    input:
        expand(outdir + "/genbank/{sample}.genomes.info.csv",
               sample=SAMPLES)

# target: genome files for the accessions found by gather.
# NOTE(review): this rule has no output and hence no `sample` wildcard;
# confirm Checkpoint_GatherResults can resolve its input without one.
rule download_matching_genomes:
    input:
        Checkpoint_GatherResults("genbank_genomes/{acc}_genomic.fna.gz"),

# target: depth summaries for both the full and the leftover mappings.
rule map_reads:
    input:
        expand(outdir + "/minimap/depth/{sample}.summary.csv", sample=SAMPLES),
        expand(outdir + "/leftover/depth/{sample}.summary.csv", sample=SAMPLES)

# target: consensus sequences for both the full and the leftover mappings.
# NOTE(review): this rule has no output and hence no `sample` wildcard;
# confirm Checkpoint_GatherResults can resolve its inputs without one.
rule build_consensus:
    input:
        Checkpoint_GatherResults(f"{outdir}/minimap/{{sample}}.x.{{acc}}.consensus.fa.gz"),
        Checkpoint_GatherResults(f"{outdir}/leftover/{{sample}}.x.{{acc}}.consensus.fa.gz"),

# target: HTML reports for every sample.
rule summarize:
    input:
        expand(outdir + '/reports/report-{sample}.html',
               sample=SAMPLES)

# target: spacegraphcats config files for every sample.
rule make_sgc_conf:
    input:
        expand(outdir + "/sgc/{sample}.conf", sample=SAMPLES)

# print out the configuration
rule showconf:
    run:
        import yaml

        # dump the merged config as YAML between two marker lines
        print('# full aggregated configuration:')
        config_text = yaml.dump(config)
        print(config_text.strip())
        print('# END')

# check config files only
rule check:
    run:
        # intentionally empty: requesting this target only makes Snakemake
        # load this file (and with it the config handling above).
        pass

# bundle the depth summaries, genbank CSVs and reports into transfer.zip
# (created in the working directory; any previous archive is removed first).
rule zip:
    shell: """
        rm -f transfer.zip
        zip -r transfer.zip {outdir}/leftover/depth/*.summary.csv \
                {outdir}/minimap/depth/*.summary.csv \
                {outdir}/genbank/*.csv {outdir}/reports
    """


# download SRA IDs.
# Runs fasterq-dump into a scratch directory, makes sure R1, R2 and unpaired
# FASTQ files all exist (creating empty ones as needed), then rewrites the
# read names and gzips each file into its permanent location. The three
# conversions run in parallel and each one's exit status is checked.
rule download_sra_wc:
    output:
        r1  = protected(outdir + "/raw/{sample}_1.fastq.gz"),
        r2  = protected(outdir + "/raw/{sample}_2.fastq.gz"),
        unp = protected(outdir + "/raw/{sample}_unpaired.fastq.gz"),
        # temporary output
        temp_dir = temp(directory(f"{base_tempdir}/{{sample}}.d")),
        temp_r1 =  f"{base_tempdir}/{{sample}}.d/{{sample}}_1.fastq",
        temp_r2 =  f"{base_tempdir}/{{sample}}.d/{{sample}}_2.fastq",
        temp_unp = f"{base_tempdir}/{{sample}}.d/{{sample}}.fastq",
    threads: 6
    conda: "env/sra.yml"
    resources:
        mem_mb=40000,
    shell: '''
        echo tmp directory: {output.temp_dir}
        echo running fasterq-dump for {wildcards.sample}

        fasterq-dump {wildcards.sample} -O {output.temp_dir} \
        -t {output.temp_dir} -e {threads} -p --split-files
        ls -h {output.temp_dir}

        # make unpaired file if needed
        if [ -f {output.temp_r1} -a -f {output.temp_r2} -a \! -f {output.temp_unp} ];
          then
            echo "no unpaired; creating empty unpaired file {output.temp_unp} for simplicity"
            touch {output.temp_unp}
          # make r1, r2 files if needed
        elif [ -f {output.temp_unp} -a \! -f {output.temp_r1} -a \! -f {output.temp_r2} ];
          then
            echo "unpaired file found; creating empty r1 ({output.temp_r1}) and r2 ({output.temp_r2}) files for simplicity"
            touch {output.temp_r1}
            touch {output.temp_r2}
        fi

        # now process the files and move to a permanent location
        echo processing R1...
        seqtk seq -C {output.temp_r1} | \
            perl -ne 's/\.([12])$/\/$1/; print $_' | \
            gzip -c > {output.r1} &
        pid_r1=$!

        echo processing R2...
        seqtk seq -C {output.temp_r2} | \
            perl -ne 's/\.([12])$/\/$1/; print $_' | \
            gzip -c > {output.r2} &
        pid_r2=$!

        echo processing unpaired...
        seqtk seq -C {output.temp_unp} | \
            perl -ne 's/\.([12])$/\/$1/; print $_' | \
            gzip -c > {output.unp} &
        pid_unp=$!

        # a bare 'wait' always succeeds, which would hide a failed pipeline;
        # wait for every job, then fail if any of them failed.
        status=0
        wait $pid_r1 || status=1
        wait $pid_r2 || status=1
        wait $pid_unp || status=1
        if [ $status -ne 0 ];
          then
            echo "ERROR: processing raw reads failed for {wildcards.sample}" >&2
            exit 1
        fi
        echo finished downloading raw reads
        '''

# compute sourmash signature from raw reads
# sketches all three raw read files into a single signature named
# 'rawreads:<sample>'; 'translate' is used for protein sketches, 'dna' otherwise.
rule smash_raw_reads_wc:
    input:
        r1 = ancient(f"{outdir}/raw/{{sample}}_1.fastq.gz"),
        r2 = ancient(f"{outdir}/raw/{{sample}}_2.fastq.gz"),
        unp = ancient(f"{outdir}/raw/{{sample}}_unpaired.fastq.gz"),
    output:
        sig = f"{outdir}/raw/{{sample}}.raw.sig"
    params:
        action = "translate" if SOURMASH_COMPUTE_TYPE == "protein" else "dna",
        param_str = make_param_str(SOURMASH_COMPUTE_KSIZES,
                                   SOURMASH_COMPUTE_SCALED),
    conda: "env/sourmash.yml"
    shell: """
        sourmash sketch {params.action} -p {params.param_str} \
           {input} -o {output} --name 'rawreads:{wildcards.sample}'
    """

# adapter trimming
# fastp on the paired reads; the interleaved result is written to stdout and
# gzipped, alongside fastp's JSON and HTML reports.
rule trim_adapters_wc:
    input:
        r1 = ancient(f"{outdir}/raw/{{sample}}_1.fastq.gz"),
        r2 = ancient(f"{outdir}/raw/{{sample}}_2.fastq.gz"),
    output:
        interleaved = protected(f"{outdir}/trim/{{sample}}.trim.fq.gz"),
        json = f"{outdir}/trim/{{sample}}.trim.json",
        html = f"{outdir}/trim/{{sample}}.trim.html",
    threads: 4
    resources:
        mem_mb=5000,
        runtime_min=600,
    conda: 'env/trim.yml'
    shadow: "shallow"
    shell: """
        fastp --in1 {input.r1} --in2 {input.r2} \
             --detect_adapter_for_pe  --qualified_quality_phred 4 \
             --length_required 25 --correction --thread {threads} \
             --json {output.json} --html {output.html} \
             --low_complexity_filter --stdout | gzip -9 > {output.interleaved}
    """

# adapter trimming for the singleton reads
# fastp on the unpaired reads, with the same quality/length settings as the
# paired-read rule; all three outputs are protected.
rule trim_unpaired_adapters_wc:
    input:
        unp = ancient(f"{outdir}/raw/{{sample}}_unpaired.fastq.gz"),
    output:
        unp = protected(f"{outdir}/trim/{{sample}}_unpaired.trim.fq.gz"),
        json = protected(f"{outdir}/trim/{{sample}}_unpaired.trim.json"),
        html = protected(f"{outdir}/trim/{{sample}}_unpaired.trim.html"),
    conda: 'env/trim.yml'
    threads: 4
    resources:
        mem_mb=5000,
        runtime_min=600,
    shadow: "shallow"
    shell: """
        fastp --in1 {input.unp} --out1 {output.unp} \
            --detect_adapter_for_se  --qualified_quality_phred 4 \
            --low_complexity_filter --thread {threads} \
            --length_required 25 --correction \
            --json {output.json} --html {output.html}
    """

# k-mer abundance trimming
# trim-low-abund.py on the adapter-trimmed interleaved reads. ABUNDTRIM_MEMORY
# is in bytes: passed as-is to -M and converted to MB for the scheduler.
rule kmer_trim_reads_wc:
    input:
        interleaved = ancient(f"{outdir}/trim/{{sample}}.trim.fq.gz"),
    output:
        protected(f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz")
    params:
        mem = ABUNDTRIM_MEMORY,
    resources:
        mem_mb = int(ABUNDTRIM_MEMORY / 1e6),
    conda: 'env/trim.yml'
    shell: """
            trim-low-abund.py -C 3 -Z 18 -M {params.mem} -V \
            {input.interleaved} -o {output} --gzip
    """

# count k-mers
# writes a unique-kmers.py report for the abundance-trimmed reads, using the
# database k-mer size.
rule estimate_distinct_kmers_wc:
    input:
        f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz"
    output:
        report = f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz.kmer-report.txt",
    params:
        ksize = SOURMASH_DB_KSIZE,
    conda: 'env/trim.yml'
    message: """
        Count distinct k-mers for {wildcards.sample} using 'unique-kmers.py' from the khmer package.
    """
    shell: """
        unique-kmers.py {input} -k {params.ksize} -R {output.report}
    """

# count reads and bases
# writes a two-line CSV (header 'n_reads,n_bases', then the counts) computed
# from every 4th line (offset 2) of the decompressed FASTQ.
rule count_trimmed_reads_wc:
    input:
        f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz"
    output:
        report = f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz.reads-report.txt",
    message: """
        Count reads & bp in trimmed file for {wildcards.sample}.
    """
    # from Matt Bashton, in https://bioinformatics.stackexchange.com/questions/935/fast-way-to-count-number-of-reads-and-number-of-bases-in-a-fastq-file
    shell: """
        gzip -dc {input} |
             awk 'NR%4==2{{c++; l+=length($0)}}
                  END{{
                        print "n_reads,n_bases"
                        print c","l
                      }}' > {output.report}
    """

# map abundtrim reads and produce a bam
# minimap2 (short-read preset) of the trimmed reads against one genome;
# unmapped reads are dropped (-F 4) and the BAM is sorted.
rule minimap_wc:
    input:
        query = ancient("genbank_genomes/{acc}_genomic.fna.gz"),
        metagenome = f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz",
    output:
        bam = f"{outdir}/minimap/{{sample}}.x.{{acc}}.bam",
    threads: 4
    conda: "env/minimap2.yml"
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.metagenome} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# extract FASTQ from BAM
# converts any BAM under minimap/ into a gzipped FASTQ of its reads.
rule bam_to_fastq_wc:
    input:
        bam = f"{outdir}/minimap/{{bam}}.bam",
    output:
        mapped = f"{outdir}/minimap/{{bam}}.mapped.fq.gz",
    conda: "env/minimap2.yml"
    shell: """
        samtools bam2fq {input.bam} | gzip > {output.mapped}
    """

# get per-base depth information from BAM
# runs `samtools depth -aa` on a BAM in any output subdirectory ({dir}).
rule bam_to_depth_wc:
    input:
        bam = f"{outdir}/{{dir}}/{{bam}}.bam",
    output:
        depth = f"{outdir}/{{dir}}/depth/{{bam}}.txt",
    conda: "env/minimap2.yml"
    shell: """
        samtools depth -aa {input.bam} > {output.depth}
    """

# wild card rule for getting _covered_ regions from BAM
# covtobed output is merged with bedtools into a BED file of covered regions.
rule bam_covered_regions_wc:
    input:
        bam = f"{outdir}/{{dir}}/{{bam}}.bam",
    output:
        regions = f"{outdir}/{{dir}}/depth/{{bam}}.regions.bed",
    conda: "env/covtobed.yml"
    shell: """
        covtobed {input.bam} -l 100 -m 1 | \
            bedtools merge -d 5 -c 4 -o mean > {output.regions}
    """

# calculating SNPs/etc.
# Calls variants for one sample/genome mapping: the genome is decompressed to
# a temp file, piled up and called into a BCF, then converted to an indexed,
# bgzipped VCF.
rule mpileup_wc:
    input:
        query = ancient("genbank_genomes/{acc}_genomic.fna.gz"),
        bam = outdir + "/{dir}/{sample}.x.{acc}.bam",
    output:
        bcf = outdir + "/{dir}/{sample}.x.{acc}.bcf",
        vcf = outdir + "/{dir}/{sample}.x.{acc}.vcf.gz",
        vcfi = outdir + "/{dir}/{sample}.x.{acc}.vcf.gz.csi",
    conda: "env/bcftools.yml"
    shell: """
        genomefile=$(mktemp -t grist.genome.XXXXXXX)
        # remove the temp genome (and any .fai index created next to it)
        # however the script exits, including on failure.
        trap 'rm -f "$genomefile" "$genomefile.fai"' EXIT
        gunzip -c {input.query} > $genomefile
        bcftools mpileup -Ou -f $genomefile {input.bam} | bcftools call -mv -Ob -o {output.bcf}
        bcftools view {output.bcf} | bgzip > {output.vcf}
        bcftools index {output.vcf}
    """

# build new consensus
# Builds a consensus for one sample/genome mapping: contig sizes come from a
# faidx index of the decompressed genome, the mask is the complement of the
# covered regions, and bcftools consensus applies the VCF with that mask.
rule build_new_consensus_wc:
    input:
        vcf = outdir + "/{dir}/{sample}.x.{acc}.vcf.gz",
        query = ancient("genbank_genomes/{acc}_genomic.fna.gz"),
        regions = outdir + "/{dir}/depth/{sample}.x.{acc}.regions.bed",
    output:
        mask = outdir + "/{dir}/{sample}.x.{acc}.mask.bed",
        genomefile = outdir + "/{dir}/{sample}.x.{acc}.fna.gz.sizes",
        consensus = outdir + "/{dir}/{sample}.x.{acc}.consensus.fa.gz",
    conda: "env/bcftools.yml"
    shell: """
        genomefile=$(mktemp -t grist.genome.XXXXXXX)
        # remove the temp genome and the .fai index that faidx writes next to
        # it, however the script exits, including on failure.
        trap 'rm -f "$genomefile" "$genomefile.fai"' EXIT
        gunzip -c {input.query} > $genomefile
        samtools faidx $genomefile
        cut -f1,2 ${{genomefile}}.fai > {output.genomefile}
        bedtools complement -i {input.regions} -g {output.genomefile} > {output.mask}
        bcftools consensus -f $genomefile {input.vcf} -m {output.mask} | \
            gzip > {output.consensus}
    """

# summarize depth into a CSV
# Collects one row of coverage statistics per depth file into a CSV.
# Sample and genome IDs are taken from the 1st and 3rd '.'-separated fields
# of the depth file's basename.
rule summarize_samtools_depth_wc:
    input:
        Checkpoint_GatherResults(f"{outdir}/{{dir}}/depth/{{sample}}.x.{{acc}}.txt")
    output:
        outdir + "/{dir}/depth/{sample}.summary.csv"
    run:
        import pandas as pd

        per_genome = {}
        for idx, depth_file in enumerate(input, start=1):
            print(f'reading from {depth_file} - {idx}/{len(input)}...')
            depth = pd.read_table(depth_file, names=["contig", "pos", "coverage"])

            name_fields = depth_file.split("/")[-1].split(".")
            sample_id = name_fields[0]
            genome_id = name_fields[2]

            n_bp = int(len(depth))
            # positions with zero coverage
            n_missed = int(depth['coverage'].value_counts().get(0, 0))
            pct_missed = 100 * n_missed / n_bp
            mean_cov = depth['coverage'].sum() / len(depth)
            # mean coverage over the covered positions only
            if n_missed != 0:
                uniq_cov = mean_cov / (1 - n_missed / n_bp)
            else:
                uniq_cov = mean_cov

            # key order here determines the column order of the CSV
            per_genome[genome_id] = {
                'genome bp': n_bp,
                'missed': n_missed,
                'percent missed': pct_missed,
                'coverage': mean_cov,
                'unique_mapped_coverage': uniq_cov,
                'covered_bp': (1 - pct_missed/100.0) * n_bp,
                'genome_id': genome_id,
                'sample_id': sample_id,
            }

        pd.DataFrame(per_genome).T.to_csv(output[0])

# compute sourmash signature from abundtrim reads
# sketches the abundance-trimmed reads into a signature named after the
# sample; 'translate' is used for protein sketches, 'dna' otherwise.
rule smash_abundtrim_wc:
    input:
        metagenome = ancient(f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz"),
    output:
        sig = f"{outdir}/sigs/{{sample}}.abundtrim.sig"
    params:
        action = "translate" if SOURMASH_COMPUTE_TYPE == "protein" else "dna",
        param_str = make_param_str(SOURMASH_COMPUTE_KSIZES,
                                   SOURMASH_COMPUTE_SCALED),
    conda: "env/sourmash.yml"
    shell: """
        sourmash sketch {params.action} -p {params.param_str} \
           {input} -o {output} \
           --name {wildcards.sample}
    """

# configure ipython kernel for papermill
# registers the 'genome_grist' kernel for the current user and installs the
# plotting/data packages; the flag file records that this has been done.
rule set_kernel:
    output:
        touch(outdir + "/.kernel.set")
    conda: 'env/papermill.yml'
    shell: """
        python -m ipykernel install --user --name genome_grist
        python -m pip install matplotlib numpy pandas
    """


# papermill -> reporting notebook + html
# Executes the report notebook with papermill for one sample, then renders
# the executed notebook to HTML with nbconvert. The CSV inputs are not passed
# on the command line; they are declared so that they exist before the
# notebook runs.
rule make_notebook_wc:
    input:
        nb = srcdir('../notebooks/report-sample.ipynb'),
        all_csv = outdir + "/minimap/depth/{sample}.summary.csv",
        depth_csv = outdir + "/leftover/depth/{sample}.summary.csv",
        gather_csv = outdir + '/genbank/{sample}.x.genbank.gather.csv',
        genomes_info_csv = outdir + "/genbank/{sample}.genomes.info.csv",
        kernel_set = rules.set_kernel.output,
    output:
        nb = f'{outdir}/reports/report-{{sample}}.ipynb',
        html = f'{outdir}/reports/report-{{sample}}.html',
    params:
        outdir = outdir,
        cwd = f'{outdir}/reports/',
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} - -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' -p outdir {outdir:q}\
              --cwd {params.cwd} \
              > {output.nb}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# convert mapped reads to leftover reads
# @CTB update subtract-gather to take sample ID as param, instead of inferring it from filename
# @CTB update for intersected/overlapping reads too
# runs genome_grist.subtract_gather on the gather CSV once all mapped-read
# files exist; only a flag file is declared as output.
rule extract_leftover_reads_wc:
    input:
        csv = outdir + '/genbank/{sample}.x.genbank.gather.csv',
        mapped = Checkpoint_GatherResults(outdir + "/minimap/{sample}.x.{acc}.mapped.fq.gz"),
    output:
        touch(outdir + "/.leftover-reads.{sample}")
    params:
        outdir = outdir,
    conda: "env/sourmash.yml"
    shell: """
        python -Werror -m genome_grist.subtract_gather {input.csv} \
            --outdir={params.outdir:q}
    """

# rule for mapping leftover reads to genomes -> BAM
# The leftover FASTQ read by minimap2 is not a declared input; the
# leftover-reads flag file stands in for it.
rule map_leftover_reads_wc:
    input:
        all_csv = outdir + "/minimap/depth/{sample}.summary.csv",
        query = ancient("genbank_genomes/{acc}_genomic.fna.gz"),
        leftover_reads_flag = outdir + "/.leftover-reads.{sample}",
    output:
        bam = f"{outdir}/leftover/{{sample}}.x.{{acc}}.bam",
    threads: 4
    conda: "env/minimap2.yml"
    shell: """
        minimap2 -ax sr -t {threads} {input.query} \
     {outdir}/minimap/{wildcards.sample}.x.{wildcards.acc}.leftover.fq.gz | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# run sourmash search x genbank and find anything matching.
# Runs `sourmash prefetch` of the sample signature against all configured
# databases, saving the matching signatures (zip) and a CSV summary.
rule sourmash_prefetch_gather_wc:
    input:
        sig = outdir + "/sigs/{sample}.abundtrim.sig",
        db = SOURMASH_DB_LIST,
    output:
        matches = outdir + "/genbank/{sample}.x.genbank.prefetch.zip",
        csv = outdir + "/genbank/{sample}.x.genbank.prefetch.csv",
    resources:
        # PREFETCH_MEMORY is in bytes; mem_mb is in megabytes.
        mem_mb=int(PREFETCH_MEMORY / 1e6)
    conda: "env/sourmash.yml"
    params:
        ksize = SOURMASH_DB_KSIZE,
        # '--dna' or '--protein'
        moltype = f"--{SOURMASH_COMPUTE_TYPE.lower()}",
        threshold_bp = SOURMASH_DATABASE_THRESHOLD_BP,
    shell: """
        echo "DB is {input.db}"
        sourmash prefetch {input.sig} {input.db} \
          --save-matches {output.matches} -k {params.ksize} \
          --threshold-bp={params.threshold_bp} {params.moltype} \
          -o {output.csv}
    """

# split the sample signature into known and unknown hashes, based on the prefetch matches.
# runs genome_grist.gather_split on the sample signature and the prefetch
# matches; writes known/unknown signatures plus a report.
rule split_query_known_unknown_wc:
    input:
        sig = f"{outdir}/sigs/{{sample}}.abundtrim.sig",
        prefetch = f"{outdir}/genbank/{{sample}}.x.genbank.prefetch.zip",
    output:
        known = f"{outdir}/genbank/{{sample}}.x.genbank.known.sig",
        unknown = f"{outdir}/genbank/{{sample}}.x.genbank.unknown.sig",
        report = f"{outdir}/genbank/{{sample}}.abundtrim.fq.gz.prefetch.zip.report.txt",
    params:
        moltype = SOURMASH_COMPUTE_TYPE,
        ksize = SOURMASH_DB_KSIZE,
    conda: "env/sourmash.yml"
    message: """
       Split the sourmash signature for {wildcards.sample} into "known" and "unknown" hashes
    """
    shell: """
        python -W error -m genome_grist.gather_split {input.sig} {input.prefetch} \
          {output.known} {output.unknown}  \
          -k {params.ksize} --moltype {params.moltype} --report {output.report}
    """

# run sourmash gather with known hashes on prefetched matches.
# gather of the known hashes against the prefetched matches; stdout is kept
# in the .gather.out file.
rule sourmash_gather_wc:
    input:
        known = f"{outdir}/genbank/{{sample}}.x.genbank.known.sig",
        db = f"{outdir}/genbank/{{sample}}.x.genbank.prefetch.zip",
    output:
        csv = f"{outdir}/genbank/{{sample}}.x.genbank.gather.csv",
        matches = f"{outdir}/genbank/{{sample}}.x.genbank.matches.sig",
        out = f"{outdir}/genbank/{{sample}}.x.genbank.gather.out",
    params:
        threshold_bp = SOURMASH_DATABASE_THRESHOLD_BP,
    conda: "env/sourmash.yml"
    shell: """
        sourmash gather {input.known} {input.db} -o {output.csv} \
          --threshold-bp {params.threshold_bp} \
          --save-matches {output.matches} > {output.out}
    """

# download genbank genome details; make an info.csv file for entry.
# one info CSV per accession, produced by the genome_grist.genbank_genomes
# module; -Werror turns Python warnings into errors.
rule make_genome_info_csv:
    output:
        csvfile = 'genbank_genomes/{acc}.info.csv'
    shell: """
        python -Werror -m genome_grist.genbank_genomes {wildcards.acc} \
            --output {output.csvfile}
    """

# combined info.csv
# concatenates the per-accession info CSVs of a sample's gather matches.
rule make_combined_info_csv:
    input:
        Checkpoint_GatherResults('genbank_genomes/{acc}.info.csv')
    output:
        genomes_info_csv = outdir + "/genbank/{sample}.genomes.info.csv",
    shell: """
        python -Werror -m genome_grist.combine_csvs {input} > {output}
    """

# download actual genomes!
# Downloads one genome: the URL is read from the accession's info CSV, which
# must contain exactly one row whose 'acc' is a prefix of the requested
# accession.
rule download_matching_genome_wc:
    input:
        csvfile = ancient('genbank_genomes/{acc}.info.csv')
    output:
        genome = "genbank_genomes/{acc}_genomic.fna.gz"
    run:
        # imported here because the top of the file does not import urllib
        import urllib.request

        with open(input.csvfile, 'rt') as infp:
            rows = list(csv.DictReader(infp))
        assert len(rows) == 1
        row = rows[0]
        acc = row['acc']
        assert wildcards.acc.startswith(acc)
        url = row['genome_url']
        name = row['ncbi_tax_name']

        print(f"downloading genome for acc {acc}/{name} from NCBI...")
        # open the URL first so that a failed request does not leave an
        # empty output file behind.
        with urllib.request.urlopen(url) as response:
            content = response.read()
        with open(output.genome, 'wb') as outfp:
            outfp.write(content)
        print(f"...wrote {len(content)} bytes to {output.genome}")


# summarize_reads_info
# Merges the k-mer report, the read/base counts and the known/unknown hash
# report of one sample into a single YAML file.
rule summarize_reads_info_wc:
    input:
        kmers = f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz.kmer-report.txt",
        reads = f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz.reads-report.txt",
        sigs = f"{outdir}/genbank/{{sample}}.abundtrim.fq.gz.prefetch.zip.report.txt",
    output:
        f'{outdir}/{{sample}}.info.yaml',
    run:
        import yaml

        info = {}

        # k-mer count: last whitespace-separated field of the 4th line
        with open(str(input.kmers), 'rt') as fp:
            report_lines = fp.readlines()
        info['kmers'] = int(report_lines[3].strip().split()[-1])

        # read/base counts: first data row of the CSV
        with open(str(input.reads), 'rt') as fp:
            counts = next(csv.DictReader(fp))
        for key in ('n_reads', 'n_bases'):
            info[key] = int(counts[key])

        # hash counts: first data row of the CSV
        with open(str(input.sigs), 'rt') as fp:
            hashes = next(csv.DictReader(fp))
        for key in ('total_hashes', 'known_hashes', 'unknown_hashes'):
            info[key] = int(hashes[key])

        info['sample'] = wildcards.sample

        with open(str(output), 'wt') as fp:
            yaml.dump(info, fp)


# create a spacegraphcats config file
# Writes a spacegraphcats config for one sample: the abundance-trimmed reads
# are the input sequences and every matched genome is a search query.
rule create_sgc_conf_wc:
    input:
        csv = f"{outdir}/genbank/{{sample}}.x.genbank.gather.csv",
        queries = Checkpoint_GatherResults("genbank_genomes/{acc}_genomic.fna.gz"),
    output:
        conf = f"{outdir}/sgc/{{sample}}.conf"
    run:
        # one YAML list item per query genome
        search_entries = "\n- ".join(input.queries)
        conf_text = f"""\
catlas_base: {wildcards.sample}
input_sequences:
- {outdir}/abundtrim/{wildcards.sample}.abundtrim.fq.gz
ksize: 31
radius: 1
search:
- {search_entries}
"""
        with open(output.conf, 'wt') as fp:
            print(conf_text, file=fp)
