#!/usr/bin/python

"""

Snakefile for pyinseq main pipeline

"""

from pathlib import Path
# Setup logger before importing modules
from pyinseq.logger import pyinseq_logger, add_fileHandler, add_streamHandler
pyinseq_logger.setup_logger()
# Module imports
import pyinseq
from pyinseq.settings import Settings
from pyinseq.pipeline import pipeline_summarize
from pyinseq.map_reads import summarize_mapping
from pyinseq.demultiplex import demultiplex_fastq
from pyinseq.gbk_convert import build_fna_and_table_files
from pyinseq.utils import read_gene_file, copy_file, load_pickle
from pyinseq.process_mapping import map_sites, map_genes, build_gene_table

# Directory of conda environment files bundled inside the installed pyinseq
# package; rules reference files in it from their `conda:` directives.
ENVS = Path(pyinseq.__file__).parent / "envs"

# Read back the pickled Settings object for this experiment. The experiment
# name comes from the Snakemake config; the settings supply the sample table
# and the output directory that every rule's file paths are built from.
settings = load_pickle("results/{}/settings.pickle".format(config["experiment"]))
samples_dict = settings.samples_dict
output_dir = str(settings.path)

# Logging wiring (statement order is kept deliberately).
# NOTE(review): add_logfile_paths presumably fills in settings.log, which is
# why it runs before add_fileHandler — confirm in pyinseq.logger.
pyinseq_logger.add_logfile_paths(settings)
# Attach a file handler for settings.log to pyinseq's own logger.
add_fileHandler(pyinseq_logger.logger, settings.log)
# Attach pyinseq's snake_io stream to Snakemake's logger (the `logger` name is
# provided by Snakemake in the Snakefile namespace) so its output is captured.
add_streamHandler(logger.logger, pyinseq_logger.snake_io)
# Header line written through Snakemake's logger.
logger.logger.info("(SNAKEMAKE INFO)")

# Default target: the summary gene table, every per-sample intermediate
# (raw/trimmed reads, bowtie output, sites, genes) and the samples summary.
rule all:
    input:
        f"{output_dir}/summary_gene_table.txt",
        expand(
            [
                output_dir + "/raw_data/{sample}.fastq",
                output_dir + "/{sample}_trimmed.fastq",
                output_dir + "/{sample}_bowtie.txt",
                output_dir + "/{sample}_sites.txt",
                output_dir + "/{sample}_genes.txt",
            ],
            sample=samples_dict.keys(),
        ),
        f"{output_dir}/samples.txt",


# Build the genome lookup files (genome.fna and genome.ftt) from the reference
# genome configured in settings.
# The output paths are f-strings on purpose: a bare "{output_dir}" would be a
# Snakemake wildcard, letting this rule match any */genome_lookup/genome.fna.
# NOTE(review): build_fna_and_table_files is not handed the output paths; it is
# assumed to write them under settings' output directory — confirm.
rule genome_prep:
    input:
        genebank=f"{settings.reference_genome}"
    output:
        fna_file=f"{output_dir}/genome_lookup/genome.fna",
        ftt_file=f"{output_dir}/genome_lookup/genome.ftt",
    threads: settings.threads
    run:
        # Build needed genome tables
        build_fna_and_table_files(input.genebank, settings)
        pyinseq_logger.summarize_step()


# Build the bowtie index from genome.fna inside the bowtie conda environment.
# Only genome.1.ebwt is declared as the completion marker for the index.
# The index prefix lives in params so the shell command and the output path
# are derived from one value (bowtie_mapping uses the same prefix).
# Output is an f-string so "{output_dir}" is the global, not a wildcard.
rule bowtie_index:
    input:
        fna_file=f"{output_dir}/genome_lookup/genome.fna"
    output:
        f"{output_dir}/genome_lookup/genome.1.ebwt",
    params:
        prefix=f"{output_dir}/genome_lookup/genome",
    threads: settings.threads
    conda:
        str(ENVS.joinpath("bowtie.yaml"))
    shell:
        "bowtie-build -q {input.fna_file} {params.prefix}"


# Demultiplex the input reads (settings.reads) into per-sample raw and
# trimmed FASTQ files for every sample in samples_dict, then record the
# counts returned by demultiplex_fastq (presumably per-barcode read counts)
# in the sample YAML.
rule demultiplex:
    input:
        fastq=f"{settings.reads}"
    output:
        expand(
            [
                output_dir + "/raw_data/{sample}.fastq",
                output_dir + "/{sample}_trimmed.fastq",
            ],
            sample=samples_dict.keys(),
        )
    threads: 1
    run:
        read_counts = demultiplex_fastq(input.fastq, samples_dict, settings)
        settings.dump_sample_dict_to_yml(read_counts)
        pyinseq_logger.summarize_step()


# Map one sample's trimmed reads against the bowtie index using the
# map_reads.py script (path relative to this Snakefile). Inputs are consumed
# positionally by the script, so their order is significant.
rule bowtie_mapping:
    input:
        output_dir + "/{sample}_trimmed.fastq",
        f"{output_dir}/genome_lookup/genome.1.ebwt",
    output:
        output_dir + "/{sample}_bowtie.txt",
    params:
        genome=f"{output_dir}/genome_lookup/genome",
        log=settings.log,
        summary_log=settings.summary_log,
    threads: settings.threads
    conda:
        str(ENVS / "bowtie.yaml")
    script:
        "../../map_reads.py"


# Turn one sample's bowtie output into its insertion-site table, then store
# the dictionary returned by map_sites in the sample YAML.
rule map_sites:
    input:
        output_dir + "/{sample}_bowtie.txt",
    output:
        output_dir + "/{sample}_sites.txt",
    threads: settings.threads
    run:
        sample_name = wildcards.sample
        site_summary = map_sites(sample_name, settings)
        settings.dump_sample_dict_to_yml(site_summary)
        pyinseq_logger.summarize_step()


# Assign one sample's insertion sites to genes, then store the dictionary
# returned by map_genes in the sample YAML.
rule map_genes:
    input:
        output_dir + "/{sample}_sites.txt",
    output:
        output_dir + "/{sample}_genes.txt",
    threads: settings.threads
    run:
        sample_name = wildcards.sample
        gene_summary = map_genes(sample_name, settings)
        settings.dump_sample_dict_to_yml(gene_summary)
        pyinseq_logger.summarize_step()


# Combine every sample's "<sample>_genes.txt" into the summary gene table.
# Output is an f-string so "{output_dir}" is the global, not a wildcard.
rule build_gene_table:
    input:
        expand(output_dir + "/{sample}_genes.txt", sample=samples_dict.keys()),
    output:
        f"{output_dir}/summary_gene_table.txt"
    threads: 1
    run:
        # Read each sample's gene file, keyed by the sample name recovered by
        # stripping the "_genes.txt" suffix (only at the end of the file name).
        # Path is already imported at the top of the Snakefile.
        suffix = "_genes.txt"
        gene_mappings = {}
        for sample_gene_table in input:
            file_name = Path(sample_gene_table).name
            if file_name.endswith(suffix):
                sample = file_name[: -len(suffix)]
            else:
                sample = file_name
            gene_mappings[sample] = read_gene_file(sample_gene_table)
        # Build the combined gene table across all samples
        build_gene_table(settings.organism, samples_dict, gene_mappings, settings.experiment)
        pyinseq_logger.summarize_step()


# Final step: once the summary gene table exists, run pipeline_summarize to
# produce samples.txt.
# Output is an f-string so "{output_dir}" is the global, not a wildcard.
rule summarize:
    input:
        f"{output_dir}/summary_gene_table.txt"
    output:
        f"{output_dir}/samples.txt"
    threads: 1
    run:
        pipeline_summarize(samples_dict, settings)
        pyinseq_logger.summarize_step()


