# see README or genome_grist/__main__.py for top-level targets.
import snakemake
import os, csv, tempfile, shutil, sys

# Refuse configs that still carry the retired singular 'sample' key.
if config.get('sample'):
    print("ERROR: this is an old config file. Please use 'samples' instead of 'sample'.",
          file=sys.stderr)
    sys.exit(-1)

# The sample list drives every per-sample target defined below.
SAMPLES = config['samples']
print(f'samples: {SAMPLES}', file=sys.stderr)
assert isinstance(SAMPLES, list), "config 'samples' must be a list."

# Sample names may not contain a period; report every offending name
# before exiting rather than stopping at the first one.
fail = False
for sample in SAMPLES:
    if '.' not in sample:
        continue
    print(f"sample name '{sample}' contains a period; please remove",
          file=sys.stderr)
    fail = True

if fail:
    sys.exit(-1)

# Base output directory; trailing slashes are dropped so paths can be built
# as f"{outdir}/...".
outdir = config.get('outdir', 'outputs/')
outdir = outdir.rstrip('/')
print('outdir:', outdir, file=sys.stderr)

# Create one scratch directory under the first usable location listed in the
# 'tempdir' config entry. It is removed again by the onsuccess/onerror
# handlers.
base_tempdir = None
try_temp_locations = config.get('tempdir', [])
if isinstance(try_temp_locations, str):
    # a single path given as a plain string: do not iterate over its characters
    try_temp_locations = [try_temp_locations]

for temp_loc in try_temp_locations:
    try:
        base_tempdir = tempfile.mkdtemp(dir=temp_loc)
    except (FileNotFoundError, NotADirectoryError, PermissionError):
        # unusable location; fall through to the next candidate
        continue
    # stop at the first success so that no extra directories are left behind
    break

if not base_tempdir:
    print(f"Could not create a temporary directory in any of {try_temp_locations}", file=sys.stderr)
    print("Please set 'tempdir' in the config.", file=sys.stderr)
    sys.exit(-1)

# Memory value handed to trim-low-abund.py via -M; also converted to mem_mb
# for the k-mer trimming rule.
ABUNDTRIM_MEMORY = float(config.get('metagenome_trim_memory', '1e9'))

# Directory in which downloaded genbank genomes and their info CSVs are kept.
GENBANK_CACHE = config.get('genbank_cache', './genbank_cache/')
GENBANK_CACHE = os.path.normpath(GENBANK_CACHE)

### collect databases

if config.get('sourmash_database_glob_pattern'):
    # the key that is read below is 'sourmash_databases', so point users at it
    print("ERROR: this is an old config file. Please provide a list of 'sourmash_databases' instead of a 'sourmash_database_glob_pattern'",
          file=sys.stderr)
    sys.exit(-1)

sourmash_databases = config.get('sourmash_databases', [])
local_databases_info = config.get('local_databases_info', [])

SOURMASH_DB_LIST = sourmash_databases
if not SOURMASH_DB_LIST:
    print("ERROR: no sourmash_databases specified in config file!",
          file=sys.stderr)
    sys.exit(-1)

SOURMASH_DB_KSIZE = config.get('sourmash_database_ksize', ['31'])
SOURMASH_DATABASE_THRESHOLD_BP = config.get('sourmash_database_threshold_bp',
                                            1e5)
# note: no whitespace inside the ksize strings - they are pasted verbatim
# into the 'k=...' sketch parameter string.
SOURMASH_COMPUTE_KSIZES = config.get('sourmash_compute_ksizes',
                                     ['21', '31', '51'])
SOURMASH_COMPUTE_SCALED = config.get('sourmash_scaled', '1000')
SOURMASH_COMPUTE_TYPE = config.get('sourmash_sigtype', 'DNA')
assert SOURMASH_COMPUTE_TYPE in ('DNA', 'protein'), \
    "config 'sourmash_sigtype' must be 'DNA' or 'protein'."

# Memory value for the prefetch rule; converted to mem_mb there.
PREFETCH_MEMORY = float(config.get('prefetch_memory', '100e9'))

if config.get('database_taxonomy'):
    print("ERROR: this is an old config file. Please use 'taxonomies' instead of 'database_taxonomy'.",
          file=sys.stderr)
    sys.exit(-1)

# The placeholder default is used as an input filename by the taxonomy rule
# when no taxonomies are configured.
TAXONOMY_DB = config.get('taxonomies', ["NO_TAXONOMIES_SPECIFIED_IN_CONFIG"])
if not isinstance(TAXONOMY_DB, list):
    print("ERROR: 'taxonomies' should be a list.", file=sys.stderr)
    sys.exit(-1)

# Optional picklist argument passed to 'sourmash prefetch'.
PICKLIST = config.get('picklist', '')

# build the value for the -p option of `sourmash sketch`.
def make_param_str(ksizes, scaled):
    """Return 'k=<k1>,k=<k2>,...,scaled=<scaled>,abund' for the given values."""
    ksize_part = ",".join("k={}".format(k) for k in ksizes)
    return ksize_part + f",scaled={scaled},abund"

###

# utility function
def load_csv(filename):
    """Lazily yield each row of the CSV file `filename` as a dict keyed by
    the header line. The file stays open until the generator is exhausted."""
    with open(filename, "r", newline="") as fp:
        yield from csv.DictReader(fp)

###

# registry of top-level rules, filled in by the @toplevel decorator:
_toplevel_rules = []
def toplevel(fn):
    """Record `fn` as a top-level rule and hand it back unchanged.

    The name stored is fn.__name__ without its first two characters, which
    must be '__'.
    """
    rule_fn_name = fn.__name__
    assert rule_fn_name.startswith('__')
    _toplevel_rules.append(rule_fn_name[len('__'):])
    return fn

# Remove the scratch directory when the workflow finishes successfully;
# errors during removal are ignored.
onsuccess:
    shutil.rmtree(base_tempdir, ignore_errors=True)

# Remove the scratch directory when the workflow fails, too.
onerror:
    shutil.rmtree(base_tempdir, ignore_errors=True)

wildcard_constraints:
    size="\d+",
    sample='[a-zA-Z0-9._]+'                   # should be everything but /

@toplevel
rule print_rules:
    run:
        print("\nTop level rules are: \n", file=sys.stderr)
        print("* " + "\n* ".join(_toplevel_rules), file=sys.stderr)
        print("\nPlease see documentation for details.\n\n",
              file=sys.stderr)

# Delete the gather, genomes, mapping, leftover and reports directories
# under outdir, plus the .kernel.set flag file.
@toplevel
rule clean_gather:
    # the '|| true' here makes this command succeed even when dirs do not exist
    # (adding -fr to rm seemed too dangerous!)
    shell: """
        (rm -r {outdir}/{{gather,genomes,mapping,leftover,reports}}/ \
            {outdir}/.kernel.set || true)
    """

@toplevel
rule download_reads:
    input:
        expand(f"{outdir}/raw/{{sample}}_1.fastq.gz", sample=SAMPLES),
        expand(f"{outdir}/raw/{{sample}}_2.fastq.gz", sample=SAMPLES),
        expand(f"{outdir}/raw/{{sample}}.raw.sig", sample=SAMPLES),

@toplevel
rule trim_reads:
    input:
        url_file = expand(f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz",
                          sample=SAMPLES)

@toplevel
rule estimate_distinct_kmers:
    input:
        url_file = expand(f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz.kmer-report.txt",
                          sample=SAMPLES)

@toplevel
rule count_trimmed_reads:
    input:
        url_file = expand(f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz.reads-report.txt",
                          sample=SAMPLES)

@toplevel
rule smash_reads:
    input:
        url_file = expand(f"{outdir}/sigs/{{sample}}.abundtrim.sig",
                          sample=SAMPLES)

@toplevel
rule prefetch_reads:
    input:
        expand(outdir + "/gather/{sample}.prefetch.csv",
               sample=SAMPLES)

@toplevel
rule summarize_sample_info:
    input:
        expand(outdir + '/{sample}.info.yaml', sample=SAMPLES)

@toplevel
checkpoint gather_reads:
    input:
        expand(f'{outdir}/gather/{{sample}}.gather.csv',
               sample=SAMPLES)

# Per-sample checkpoint: once the gather CSV for {sample} exists, touch a
# flag file. The Checkpoint_* input classes wait on this checkpoint before
# they read the gather CSV.
checkpoint gather_reads_wc:
    input:
        gather_csv = f'{outdir}/gather/{{sample}}.gather.csv'
    output:
        touch(f"{outdir}/gather/.gather.{{sample}}")   # checkpoints need an output ;)


_gather_csv_cache = {}
class Checkpoint_GatherResults:
    """Input function that expands a filename pattern over gather matches.

    The pattern must contain {ident}; the identifiers substituted for it are
    read from the gather CSV of a sample. The sample comes either from the
    rule's own wildcards, or - when a list of sample names is given as the
    second constructor argument - from that list, in which case the rule
    itself must not have a 'sample' wildcard.
    """
    def __init__(self, pattern, samples=None):
        self.pattern = pattern
        self.samples = samples

    def get_genome_idents(self, sample):
        """Return the first space-separated token of each 'name' value in
        the gather CSV for `sample`, in file order."""
        gather_csv = f'{outdir}/gather/{sample}.gather.csv'
        assert os.path.exists(gather_csv), "gather output does not exist!?"

        with open(gather_csv, 'rt') as fp:
            genome_idents = [row['name'].split(' ')[0]
                             for row in csv.DictReader(fp)]

        print(f'loaded {len(genome_idents)} identifiers from {gather_csv}.',
              file=sys.stderr)
        return genome_idents

    def __call__(self, w):
        """Called with the rule's wildcards; returns the list of filenames."""
        # no explicit sample list: take 'sample' from the wildcards.
        if self.samples is None:
            return self.do_sample(w)

        assert not hasattr(w, 'sample'), "if 'samples' provided to constructor, cannot also be in rule inputs"

        # explicit sample list: build a wildcards object per sample and
        # concatenate the per-sample results.
        collected = []
        for sample in self.samples:
            sample_wc = snakemake.io.Wildcards(fromdict=dict(sample=sample))
            collected.extend(self.do_sample(sample_wc))
        return collected

    def do_sample(self, w):
        """Expand the pattern for the single sample named in `w`."""
        # depend on 'gather_reads_wc'; this raises an exception until that
        # checkpoint has been run.
        checkpoints.gather_reads_wc.get(**w)

        idents = self.get_genome_idents(w.sample)
        return expand(self.pattern, ident=idents, **w)


class ListGatherGenomes(Checkpoint_GatherResults):
    """Provide list of the source genome files for either local or
    genbank databases, depending on the source of the genome info.

    This is used to get the filenames for the various genome files; the
    assumption is that snakemake already knows how to generate those :).

    Note: a key thing here is that the filenames themselves are correct,
    so we are not renaming any files, merely copying them with their
    existing names into the output subdirectory 'genomes/'.
    """
    def __init__(self, samples=None):
        # No filename pattern is needed here: do_sample() builds the
        # paths itself.
        super().__init__(pattern=None, samples=samples)

    def _load_local_database_info(self):
        """Return a dict mapping ident -> CSV row for all local genomes.

        Rows are read from every file in `local_databases_info`. Each row
        gets a normalized 'genome_filename' and an added 'info_filename'
        ('<ident>.info.csv' in the same directory as the genome file).
        """
        local_info = {}
        for filename in local_databases_info:
            for row in load_csv(filename):
                ident = row['ident']

                row['genome_filename'] = os.path.normpath(row['genome_filename'])

                genome_dir = os.path.dirname(row['genome_filename'])
                info_filename = os.path.join(genome_dir, f'{ident}.info.csv')
                row['info_filename'] = os.path.normpath(info_filename)

                local_info[ident] = row

        if local_info:
            # progress messages go to stderr, like the rest of this file.
            print(f"Loaded info on {len(local_info)} local genomes.",
                  file=sys.stderr)
        return local_info

    def do_sample(self, w):
        """Return the genome and info filenames for every gather match of
        the sample named in `w` (two entries per match)."""
        # wait for the results of 'gather_reads_wc'; this will trigger
        # exception until that rule has been run.
        checkpoints.gather_reads_wc.get(**w)

        sample = w.sample

        local_info = self._load_local_database_info()
        genome_filenames = []

        gather_csv = f'{outdir}/gather/{sample}.gather.csv'
        assert os.path.exists(gather_csv), f"gather output does not exist for {sample}!?"

        for row in load_csv(gather_csv):
            ident = row['name'].split(' ')[0]

            # if in local information, use that as genome source.
            if ident in local_info:
                info = local_info[ident]
                genome_filenames.append(info['genome_filename'])
                genome_filenames.append(info['info_filename'])

            # genbank: point at genbank_genomes
            else:
                genome_filenames.append(f'{GENBANK_CACHE}/{ident}_genomic.fna.gz')
                genome_filenames.append(f'{GENBANK_CACHE}/{ident}.info.csv')

        return genome_filenames

class Checkpoint_GenomeFiles(Checkpoint_GatherResults):
    """Like Checkpoint_GatherResults, but waits on the checkpoint
    'copy_sample_genomes_to_output_wc' instead.

    If the rule already has an 'ident' wildcard, the pattern is expanded
    with the rule's wildcards only; otherwise it is expanded over all
    identifiers in the sample's gather CSV.
    """
    def do_sample(self, w):
        checkpoints.copy_sample_genomes_to_output_wc.get(**w)

        # the gather CSV is read in both cases.
        genome_idents = self.get_genome_idents(w.sample)

        extra = {} if 'ident' in dict(w) else {'ident': genome_idents}
        return expand(self.pattern, **extra, **w)


@toplevel
rule gather_to_tax:
    input:
        expand(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv',
               sample=SAMPLES)
@toplevel
rule summarize_gather:
    input:
        expand(f'{outdir}/reports/report-gather-{{sample}}.html',
               sample=SAMPLES)

@toplevel
rule summarize_tax:
    input:
        expand(f'{outdir}/reports/report-taxonomy-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv',
               sample=SAMPLES)

@toplevel
rule combine_genome_info:
    input:
        expand(f"{outdir}/gather/{{sample}}.genomes.info.csv",
               sample=SAMPLES)

@toplevel
rule download_genbank_genomes:
    input:
        Checkpoint_GatherResults(f"{GENBANK_CACHE}/{{ident}}_genomic.fna.gz",
                                 samples=SAMPLES)

@toplevel
rule retrieve_genomes:
    input:
        expand(f"{outdir}/genomes/.genomes.{{sample}}", sample=SAMPLES)

@toplevel
rule map_reads:
    input:
        expand(f"{outdir}/mapping/{{sample}}.summary.csv", sample=SAMPLES),
        expand(f"{outdir}/leftover/{{sample}}.summary.csv", sample=SAMPLES)

@toplevel
rule build_consensus:
    input:
        Checkpoint_GatherResults(outdir + f"/mapping/{{sample}}.x.{{ident}}.consensus.fa.gz"),
        Checkpoint_GatherResults(outdir + f"/leftover/{{sample}}.x.{{ident}}.consensus.fa.gz"),

@toplevel
rule summarize_mapping:
    input:
        expand(f'{outdir}/reports/report-mapping-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/reports/report-gather-{{sample}}.html',
               sample=SAMPLES)

@toplevel
rule make_sgc_conf:
    input:
        expand(f"{outdir}/sgc/{{sample}}.conf", sample=SAMPLES)

# print out the configuration
@toplevel
rule showconf:
    run:
        import yaml
        print('# full aggregated configuration:')
        print(yaml.dump(config).strip())
        print('# END')

# check config files only: the rule itself does nothing, so running it only
# exercises the configuration code at the top of this file.
@toplevel
rule check:
    run:
        pass

# Bundle summary CSVs, YAML files, gather CSVs and the reports directory
# into transfer.zip in the working directory, replacing any existing one.
@toplevel
rule zip:
    shell: """
        rm -f transfer.zip
        zip -r transfer.zip {outdir}/leftover/*.summary.csv \
                {outdir}/mapping/*.summary.csv {outdir}/*.yaml \
                {outdir}/gather/*.csv {outdir}/reports/
    """


# download SRA IDs.
# Runs fasterq-dump for {sample} into a per-sample directory under the
# scratch dir, creates empty placeholder files for whichever of the
# paired/unpaired files were not produced, then rewrites read names ending
# in '.1'/'.2' to '/1'/'/2' and gzips the results into outdir/raw/.
# NOTE(review): the three conversion pipelines run in the background and are
# collected with a bare 'wait' - confirm that a failure in one of them is
# detected.
rule download_sra_wc:
    output:
        r1  = protected(outdir + "/raw/{sample}_1.fastq.gz"),
        r2  = protected(outdir + "/raw/{sample}_2.fastq.gz"),
        unp = protected(outdir + "/raw/{sample}_unpaired.fastq.gz"),
        # temporary output
        temp_dir = temp(directory(f"{base_tempdir}/{{sample}}.d")),
        temp_r1 =  f"{base_tempdir}/{{sample}}.d/{{sample}}_1.fastq",
        temp_r2 =  f"{base_tempdir}/{{sample}}.d/{{sample}}_2.fastq",
        temp_unp = f"{base_tempdir}/{{sample}}.d/{{sample}}.fastq",
    threads: 6
    conda: "env/sra.yml"
    resources:
        mem_mb=40000,
    shell: '''
        echo tmp directory: {output.temp_dir}
        echo running fasterq-dump for {wildcards.sample}

        fasterq-dump {wildcards.sample} -O {output.temp_dir} \
        -t {output.temp_dir} -e {threads} -p --split-files
        ls -h {output.temp_dir}

        # make unpaired file if needed
        if [ -f {output.temp_r1} -a -f {output.temp_r2} -a \! -f {output.temp_unp} ];
          then
            echo "no unpaired; creating empty unpaired file {output.unp} for simplicity"
            touch {output.temp_unp}
          # make r1, r2 files if needed
        elif [ -f {output.temp_unp} -a \! -f {output.temp_r1} -a \! -f {output.temp_r2} ];
          then
            echo "unpaired file found; creating empty r1 ({output.temp_r1}) and r2 ({output.temp_r2}) files for simplicity"
            touch {output.temp_r1}
            touch {output.temp_r2}
        fi

        # now process the files and move to a permanent location
        echo processing R1...
        seqtk seq -C {output.temp_r1} | \
            perl -ne 's/\.([12])$/\/$1/; print $_' | \
            gzip -c > {output.r1} &

        echo processing R2...
        seqtk seq -C {output.temp_r2} | \
            perl -ne 's/\.([12])$/\/$1/; print $_' | \
            gzip -c > {output.r2} &

        echo processing unpaired...
        seqtk seq -C {output.temp_unp} | \
            perl -ne 's/\.([12])$/\/$1/; print $_' | \
            gzip -c > {output.unp} &
        wait
        echo finished downloading raw reads
        '''

# compute sourmash signature from raw reads
# All three raw read files (R1, R2, unpaired) are passed to a single
# 'sourmash sketch' call; the signature is named 'rawreads:<sample>'.
# 'ancient' makes snakemake ignore the timestamps of the raw files.
rule smash_raw_reads_wc:
    input:
        r1 = ancient(outdir + "/raw/{sample}_1.fastq.gz"),
        r2 = ancient(outdir + "/raw/{sample}_2.fastq.gz"),
        unp = ancient(outdir + "/raw/{sample}_unpaired.fastq.gz"),
    output:
        sig = outdir + "/raw/{sample}.raw.sig"
    conda: "env/sourmash.yml"
    params:
        param_str = make_param_str(ksizes=SOURMASH_COMPUTE_KSIZES,
                                   scaled=SOURMASH_COMPUTE_SCALED),
        action = "translate" if SOURMASH_COMPUTE_TYPE == "protein" else "dna"
    shell: """
        sourmash sketch {params.action} -p {params.param_str} \
           {input} -o {output} --name 'rawreads:{wildcards.sample}'
    """

# adapter trimming
# Runs fastp on the paired raw reads; the trimmed reads are written to
# stdout by fastp and gzipped into a single output file, alongside fastp's
# JSON and HTML reports.
rule trim_adapters_wc:
    input:
        r1 = ancient(outdir + "/raw/{sample}_1.fastq.gz"),
        r2 = ancient(outdir + "/raw/{sample}_2.fastq.gz"),
    output:
        interleaved = protected(outdir + '/trim/{sample}.trim.fq.gz'),
        json=outdir + "/trim/{sample}.trim.json",
        html=outdir + "/trim/{sample}.trim.html",
    conda: 'env/trim.yml'
    threads: 4
    resources:
        mem_mb=5000,
        runtime_min=600,
    shadow: "shallow"
    shell: """
        fastp --in1 {input.r1} --in2 {input.r2} \
             --detect_adapter_for_pe  --qualified_quality_phred 4 \
             --length_required 25 --correction --thread {threads} \
             --json {output.json} --html {output.html} \
             --low_complexity_filter --stdout | gzip -9 > {output.interleaved}
    """

# adapter trimming for the singleton reads
# Runs fastp on the unpaired raw reads file and writes the trimmed reads
# plus fastp's JSON and HTML reports; all three outputs are write-protected.
rule trim_unpaired_adapters_wc:
    input:
        unp = ancient(outdir + "/raw/{sample}_unpaired.fastq.gz"),
    output:
        unp = protected(outdir + '/trim/{sample}_unpaired.trim.fq.gz'),
        json = protected(outdir + '/trim/{sample}_unpaired.trim.json'),
        html = protected(outdir + '/trim/{sample}_unpaired.trim.html'),
    threads: 4
    resources:
        mem_mb=5000,
        runtime_min=600,
    shadow: "shallow"
    conda: 'env/trim.yml'
    shell: """
        fastp --in1 {input.unp} --out1 {output.unp} \
            --detect_adapter_for_se  --qualified_quality_phred 4 \
            --low_complexity_filter --thread {threads} \
            --length_required 25 --correction \
            --json {output.json} --html {output.html}
    """

# k-mer abundance trimming
# Runs trim-low-abund.py on the adapter-trimmed reads. ABUNDTRIM_MEMORY is
# passed unchanged to -M and, divided by 1e6, declared as mem_mb.
rule kmer_trim_reads_wc:
    input: 
        interleaved = ancient(outdir + '/trim/{sample}.trim.fq.gz'),
    output:
        protected(outdir + "/abundtrim/{sample}.abundtrim.fq.gz")
    conda: 'env/trim.yml'
    resources:
        mem_mb = int(ABUNDTRIM_MEMORY / 1e6),
    params:
        mem = ABUNDTRIM_MEMORY,
    shell: """
            trim-low-abund.py -C 3 -Z 18 -M {params.mem} -V \
            {input.interleaved} -o {output} --gzip
    """

# count k-mers
# Runs unique-kmers.py on the trimmed reads with k taken from
# SOURMASH_DB_KSIZE; the report file is later parsed by
# summarize_reads_info_wc.
rule estimate_distinct_kmers_wc:
    message: """
        Count distinct k-mers for {wildcards.sample} using 'unique-kmers.py' from the khmer package.
    """
    conda: 'env/trim.yml'
    input:
        outdir + "/abundtrim/{sample}.abundtrim.fq.gz"
    output:
        report = outdir + "/abundtrim/{sample}.abundtrim.fq.gz.kmer-report.txt",
    params:
        ksize = SOURMASH_DB_KSIZE,
    shell: """
        unique-kmers.py {input} -k {params.ksize} -R {output.report}
    """

# count reads and bases
# Writes a two-line CSV (header 'n_reads,n_bases' plus one data row). The
# awk script counts every 4th line starting at line 2 of the decompressed
# file and sums the lengths of those lines.
rule count_trimmed_reads_wc:
    message: """
        Count reads & bp in trimmed file for {wildcards.sample}.
    """
    input:
        outdir + "/abundtrim/{sample}.abundtrim.fq.gz"
    output:
        report = outdir + "/abundtrim/{sample}.abundtrim.fq.gz.reads-report.txt",
    # from Matt Bashton, in https://bioinformatics.stackexchange.com/questions/935/fast-way-to-count-number-of-reads-and-number-of-bases-in-a-fastq-file
    shell: """
        gzip -dc {input} |
             awk 'NR%4==2{{c++; l+=length($0)}}
                  END{{
                        print "n_reads,n_bases"
                        print c","l
                      }}' > {output.report}
    """

# map abundtrim reads and produce a bam
# Maps the sample's trimmed reads to the genome {ident} with minimap2,
# drops records with SAM flag 4 set (-F 4) and writes a sorted BAM.
rule minimap_wc:
    input:
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        metagenome = outdir + "/abundtrim/{sample}.abundtrim.fq.gz",
    output:
        bam = outdir + "/mapping/{sample}.x.{ident}.bam",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.metagenome} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# extract FASTQ from BAM
# Converts any BAM under outdir/mapping/ into a gzipped FASTQ file next to
# it.
rule bam_to_fastq_wc:
    input:
        bam = outdir + "/mapping/{bam}.bam",
    output:
        mapped = outdir + "/mapping/{bam}.mapped.fq.gz",
    conda: "env/minimap2.yml"
    shell: """
        samtools bam2fq {input.bam} | gzip > {output.mapped}
    """

# get per-base depth information from BAM
# Runs 'samtools depth -aa' on a BAM in any subdirectory {dir} of outdir.
rule bam_to_depth_wc:
    input:
        bam = outdir + "/{dir}/{bam}.bam",
    output:
        depth = outdir + "/{dir}/{bam}.depth.txt",
    conda: "env/minimap2.yml"
    shell: """
        samtools depth -aa {input.bam} > {output.depth}
    """

# wild card rule for getting _covered_ regions from BAM
# Pipes covtobed output through 'bedtools merge' to produce a BED file of
# regions; the result is used to build the consensus mask.
rule bam_covered_regions_wc:
    input:
        bam = outdir + "/{dir}/{bam}.bam",
    output:
        regions = outdir + "/{dir}/{bam}.regions.bed",
    conda: "env/covtobed.yml"
    shell: """
        covtobed {input.bam} -l 100 -m 1 | \
            bedtools merge -d 5 -c 4 -o mean > {output.regions}
    """

# calculating SNPs/etc.
# The gzipped genome is decompressed into a temporary file for
# 'bcftools mpileup'; calls are written as BCF, then converted to a
# bgzipped VCF and indexed.
rule mpileup_wc:
    input:
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        bam = outdir + "/{dir}/{sample}.x.{ident}.bam",
    output:
        bcf = outdir + "/{dir}/{sample}.x.{ident}.bcf",
        vcf = outdir + "/{dir}/{sample}.x.{ident}.vcf.gz",
        vcfi = outdir + "/{dir}/{sample}.x.{ident}.vcf.gz.csi",
    conda: "env/bcftools.yml"
    shell: """
        genomefile=$(mktemp -t grist.genome.XXXXXXX)
        gunzip -c {input.query} > $genomefile
        bcftools mpileup -Ou -f $genomefile {input.bam} | bcftools call -mv -Ob -o {output.bcf}
        rm $genomefile
        bcftools view {output.bcf} | bgzip > {output.vcf}
        bcftools index {output.vcf}
    """

# build new consensus
rule build_new_consensus_wc:
    input:
        vcf = outdir + "/{dir}/{sample}.x.{ident}.vcf.gz",
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        regions = outdir + "/{dir}/{sample}.x.{ident}.regions.bed",
    output:
        mask = outdir + "/{dir}/{sample}.x.{ident}.mask.bed",
        genomefile = outdir + "/{dir}/{sample}.x.{ident}.fna.gz.sizes",
        consensus = outdir + "/{dir}/{sample}.x.{ident}.consensus.fa.gz",
    conda: "env/bcftools.yml"
    shell: """
        genomefile=$(mktemp -t grist.genome.XXXXXXX)
        gunzip -c {input.query} > $genomefile
        samtools faidx $genomefile
        cut -f1,2 ${{genomefile}}.fai > {output.genomefile}
        bedtools complement -i {input.regions} -g {output.genomefile} > {output.mask}
        bcftools consensus -f $genomefile {input.vcf} -m {output.mask} | \
            gzip > {output.consensus}
        rm $genomefile
    """

# summarize depth into a CSV
# Input is one depth file per gather match of {sample} in directory {dir};
# all of them are passed to genome_grist.summarize_mapping in one call.
rule summarize_samtools_depth_wc:
    input:
        Checkpoint_GatherResults(outdir + f"/{{dir}}/{{sample}}.x.{{ident}}.depth.txt")
    output:
        csv = f"{outdir}/{{dir}}/{{sample}}.summary.csv"
    shell: """
        python -m genome_grist.summarize_mapping {wildcards.sample} \
             {input} -o {output.csv}
    """

# compute sourmash signature from abundtrim reads
# Same sketch parameters as for the raw reads; the signature is named after
# the sample.
rule smash_abundtrim_wc:
    input:
        metagenome = ancient(outdir + "/abundtrim/{sample}.abundtrim.fq.gz"),
    output:
        sig = outdir + "/sigs/{sample}.abundtrim.sig"
    conda: "env/sourmash.yml"
    params:
        param_str = make_param_str(ksizes=SOURMASH_COMPUTE_KSIZES,
                                   scaled=SOURMASH_COMPUTE_SCALED),
        action = "translate" if SOURMASH_COMPUTE_TYPE == "protein" else "dna"
    shell: """
        sourmash sketch {params.action} -p {params.param_str} \
           {input} -o {output} \
           --name {wildcards.sample}
    """

# configure ipython kernel for papermill
# Installs a user-level kernel named 'genome_grist' from the papermill
# conda environment and touches a flag file; the notebook rules depend on
# that flag.
rule set_kernel:
    input:
        srcdir('env/papermill.yml')
    output:
        touch(f"{outdir}/.kernel.set")
    conda: 'env/papermill.yml'
    shell: """
        python -m ipykernel install --user --name genome_grist
    """


# papermill -> gather reporting notebook + html
# CTB: it's not clear why 'ancient' is needed here, but it prevents this rule
# from being run over and over!
# The notebook is executed by papermill with the parameters sample_id,
# render and outdir, then converted to HTML without input cells.
rule make_gather_notebook_wc:
    input:
        nb = srcdir('../notebooks/report-gather.ipynb'),
        gather_csv = f'{outdir}/gather/{{sample}}.gather.csv',
        genomes_info_csv = ancient(f"{outdir}/gather/{{sample}}.genomes.info.csv"),
        kernel_set = rules.set_kernel.output,
    output:
        nb = outdir + f'/reports/report-gather-{{sample}}.ipynb',
        html = outdir + f'/reports/report-gather-{{sample}}.html',
    params:
        cwd = outdir + '/reports/',
        outdir = outdir,
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} {output.nb} -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' -p outdir {outdir:q}\
              --cwd {params.cwd}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# papermill -> taxonomy reporting notebook + html
# CTB: it's not clear why 'ancient' is needed here, but it prevents this rule
# from being run over and over!
# The notebook is executed by papermill with the parameters sample_id,
# render and outdir, then converted to HTML without input cells.
rule make_taxonomy_notebook_wc:
    input:
        nb = srcdir('../notebooks/report-taxonomy.ipynb'),
        taxcsv = ancient(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv'),
        kernel_set = rules.set_kernel.output,
    output:
        nb = outdir + f'/reports/report-taxonomy-{{sample}}.ipynb',
        html = outdir + f'/reports/report-taxonomy-{{sample}}.html',
    params:
        cwd = outdir + '/reports/',
        outdir = outdir,
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} {output.nb} -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' -p outdir {outdir:q}\
              --cwd {params.cwd}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# papermill -> reporting notebook + html
# CTB: it's not clear why 'ancient' is needed here, but it prevents this rule
# from being run over and over!
# Needs both summary CSVs (mapping and leftover), the gather CSV and the
# combined genome info CSV; the notebook is executed by papermill and then
# converted to HTML without input cells.
rule make_mapping_notebook_wc:
    input:
        nb = srcdir('../notebooks/report-mapping.ipynb'),
        all_csv = ancient(f"{outdir}/mapping/{{sample}}.summary.csv"),
        depth_csv = ancient(f"{outdir}/leftover/{{sample}}.summary.csv"),
        gather_csv = f'{outdir}/gather/{{sample}}.gather.csv',
        genomes_info_csv = ancient(f"{outdir}/gather/{{sample}}.genomes.info.csv"),
        kernel_set = rules.set_kernel.output,
    output:
        nb = outdir + f'/reports/report-mapping-{{sample}}.ipynb',
        html = outdir + f'/reports/report-mapping-{{sample}}.html',
    params:
        cwd = outdir + '/reports/',
        outdir = outdir,
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} {output.nb} -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' \
              -p outdir {outdir:q} --cwd {params.cwd}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# convert mapped reads to leftover reads
# Requires the mapped-reads FASTQ of every gather match, then runs
# genome_grist.subtract_gather for the sample. Only a flag file is declared
# as output; map_leftover_reads_wc depends on that flag.
rule extract_leftover_reads_wc:
    input:
        csv = f'{outdir}/gather/{{sample}}.gather.csv',
        mapped = Checkpoint_GatherResults(f"{outdir}/mapping/{{sample}}.x.{{ident}}.mapped.fq.gz"),
    output:
        touch(f"{outdir}/leftover/.leftover.{{sample}}")
    conda: "env/sourmash.yml"
    params:
        outdir = outdir,
    shell: """
        python -Werror -m genome_grist.subtract_gather \
            {wildcards.sample:q} {input.csv} --outdir={params.outdir:q}
    """

# rule for mapping leftover reads to genomes -> BAM
# The reads file given to minimap2 (…/mapping/{sample}.x.{ident}.leftover.fq.gz)
# is not a declared input; the rule depends on the '.leftover.{sample}' flag
# file instead.
rule map_leftover_reads_wc:
    input:
        all_csv = f"{outdir}/mapping/{{sample}}.summary.csv",
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        leftover_reads_flag = f"{outdir}/leftover/.leftover.{{sample}}",
    output:
        bam=outdir + "/leftover/{sample}.x.{ident}.bam",
    conda: "env/minimap2.yml"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} \
     {outdir}/mapping/{wildcards.sample}.x.{wildcards.ident}.leftover.fq.gz | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# run sourmash search x genbank and find anything matching.
rule sourmash_prefetch_wc:
    message: """
       Find all potentially relevant database matches for {wildcards.sample}
    """
    input:
        sig = outdir + "/sigs/{sample}.abundtrim.sig",
        db = SOURMASH_DB_LIST,
    output:
        csv = outdir + "/gather/{sample}.prefetch.csv",
        known = outdir + "/gather/{sample}.known.sig",
        unknown = outdir + "/gather/{sample}.unknown.sig",
    resources:
        mem_mb=int(PREFETCH_MEMORY / 1e3)
    conda: "env/sourmash.yml"
    params:
        ksize = SOURMASH_DB_KSIZE,
        moltype = f"--{SOURMASH_COMPUTE_TYPE.lower()}",
        threshold_bp = SOURMASH_DATABASE_THRESHOLD_BP,
        picklist = f"--picklist {PICKLIST}" if PICKLIST else ""
    shell: """
        echo "DB is {input.db}"
        sourmash prefetch {input.sig} {input.db} -k {params.ksize} \
          --threshold-bp={params.threshold_bp} {params.moltype} \
          --save-matching-hashes={output.known} \
          --save-unmatched-hashes={output.unknown} \
          {params.picklist} -o {output.csv}
    """

# report on known and unknown hashes
# Runs genome_grist.summarize_gather_hashes on the query signature and the
# known/unknown signatures saved by prefetch; the report is later read as a
# CSV by summarize_reads_info_wc.
rule report_query_known_unknown_wc:
    message: """
       Report on "known" and "unknown" hashes for {wildcards.sample}
    """
    input:
        query = outdir + "/sigs/{sample}.abundtrim.sig",
        known = outdir + "/gather/{sample}.known.sig",
        unknown = outdir + "/gather/{sample}.unknown.sig",
    output:
        report = outdir + "/gather/{sample}.prefetch.report.txt",
    conda: "env/sourmash.yml"
    params:
        ksize = SOURMASH_DB_KSIZE,
        moltype = SOURMASH_COMPUTE_TYPE,
    shell: """
         python -W error -m genome_grist.summarize_gather_hashes \
          {input.query} {input.known} {input.unknown}  \
          -k {params.ksize} --moltype {params.moltype} --report {output.report}
    """

# run sourmash gather with known hashes on prefetched matches.
# The query is the 'known' signature from prefetch, and the prefetch CSV is
# passed as a picklist for the databases. gather's stdout is captured in
# the '.gather.out' file.
rule sourmash_gather_wc:
    message: """
       Run gather for {wildcards.sample}
    """
    input:
        known = outdir + "/gather/{sample}.known.sig",
        picklist = outdir + "/gather/{sample}.prefetch.csv",
        db = SOURMASH_DB_LIST,
    output:
        csv = outdir + "/gather/{sample}.gather.csv",
        matches = outdir + "/gather/{sample}.matches.sig",
        out = outdir + "/gather/{sample}.gather.out",
    conda: "env/sourmash.yml"
    params:
        threshold_bp = SOURMASH_DATABASE_THRESHOLD_BP,
    shell: """
        sourmash gather {input.known} {input.db} -o {output.csv} \
          --threshold-bp {params.threshold_bp} \
          --picklist {input.picklist}::prefetch \
          --save-matches {output.matches} > {output.out}
    """

# run sourmash tax annotate on the gather results
# Only the output directory is passed to sourmash (-o); the declared output
# filename is the gather CSV name with '.with-lineages' inserted.
rule summarize_tax_wc:
    input:
        gather_csv = f"{outdir}/gather/{{sample}}.gather.csv",
        tax_csv = TAXONOMY_DB,
    output:
        f"{outdir}/gather/{{sample}}.gather.with-lineages.csv",
    conda: "env/sourmash.yml"
    params:
        o_param = f"{outdir}/gather/"
    shell: """
        sourmash tax annotate -g {input.gather_csv} -t {input.tax_csv} \
            -o {params.o_param} --fail-on-missing-taxonomy
    """


# download genbank genome details; make an info.csv file for entry.
# The CSV is written into the genbank cache and is later read by
# download_matching_genome_wc (columns ident, genome_url, display_name).
rule make_genbank_info_csv:
    output:
        csvfile = f'{GENBANK_CACHE}/{{ident}}.info.csv'
    shell: """
        python -Werror -m genome_grist.genbank_genomes {wildcards.ident} \
            --output {output.csvfile}
    """

# copy genbank + local genomes for this sample into output dir
# Declared as a checkpoint: Checkpoint_GenomeFiles waits for its flag file
# before expanding genome filenames under outdir/genomes/.
checkpoint copy_sample_genomes_to_output_wc:
    input:
        # note: a key thing here is that the filenames themselves are correct,
        # so we are simply copying from (multiple) directories into one.
        # this is why the genome filenames need to be {acc}_genomic.fna.gz.
        genomes = ListGatherGenomes(),
    output:
        touch(f"{outdir}/genomes/.genomes.{{sample}}")
    shell: """
        mkdir -p {outdir}/genomes/
        cp {input} {outdir}/genomes/
    """

# combined info.csv per sample
# Passes the info CSVs of all gather matches of {sample} to
# genome_grist.combine_csvs, requesting the fields 'ident' and
# 'display_name'.
rule make_combined_info_csv_wc:
    input:
        csvs = Checkpoint_GenomeFiles(f'{outdir}/genomes/{{ident}}.info.csv'),
    output:
        genomes_info_csv = f"{outdir}/gather/{{sample}}.genomes.info.csv",
    shell: """
        python -Werror -m genome_grist.combine_csvs \
             --fields ident,display_name \
             {input.csvs} > {output}
    """

# download actual genomes from genbank!
rule download_matching_genome_wc:
    input:
        csvfile = ancient(f'{GENBANK_CACHE}/{{ident}}.info.csv')
    output:
        genome = f"{GENBANK_CACHE}/{{ident}}_genomic.fna.gz"
    run:
        with open(input.csvfile, 'rt') as infp:
            r = csv.DictReader(infp)
            rows = list(r)
            assert len(rows) == 1
            row = rows[0]
            ident = row['ident']
            assert wildcards.ident.startswith(ident)
            url = row['genome_url']
            name = row['display_name']

            print(f"downloading genome for ident {ident}/{name} from NCBI...",
                file=sys.stderr)
            with open(output.genome, 'wb') as outfp:
                with urllib.request.urlopen(url) as response:
                    content = response.read()
                    outfp.write(content)
                    print(f"...wrote {len(content)} bytes to {output.genome}",
                        file=sys.stderr)

# summarize_reads_info
rule summarize_reads_info_wc:
    input:
        kmers = outdir + "/abundtrim/{sample}.abundtrim.fq.gz.kmer-report.txt",
        reads = outdir + "/abundtrim/{sample}.abundtrim.fq.gz.reads-report.txt",
        gather_report = outdir + "/gather/{sample}.prefetch.report.txt",
    output:
        outdir + '/{sample}.info.yaml',
    run:
        d = {}
        with open(str(input.kmers), 'rt') as fp:
            kmers = fp.readlines()[3].strip()
            kmers = kmers.split()[-1]
            d['kmers'] = int(kmers)

        with open(str(input.reads), 'rt') as fp:
            r = csv.DictReader(fp)
            row = next(iter(r))
            d['n_reads'] = int(row['n_reads'])
            d['n_bases'] = int(row['n_bases'])

        with open(str(input.gather_report), 'rt') as fp:
            r = csv.DictReader(fp)
            row = next(iter(r))
            d['total_hashes'] = int(row['total_hashes'])
            d['known_hashes'] = int(row['known_hashes'])
            d['unknown_hashes'] = int(row['unknown_hashes'])

        d['sample'] = wildcards.sample

        with open(str(output), 'wt') as fp:
            import yaml
            yaml.dump(d, fp)


# create a spacegraphcats config file
rule create_sgc_conf_wc:
    input:
        csv = outdir + "/gather/{sample}.gather.csv",
        queries = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz")
    output:
        conf = outdir + "/sgc/{sample}.conf"
    run:
        query_list = "\n- ".join(input.queries)
        with open(output.conf, 'wt') as fp:
           print(f"""\
catlas_base: {wildcards.sample}
input_sequences:
- {outdir}/abundtrim/{wildcards.sample}.abundtrim.fq.gz
ksize: 31
radius: 1
search:
- {query_list}
""", file=fp)
