import os
import shutil
import time

# Fall back to ./config-atacseq.yaml when the config dict is still empty at this
# point (i.e. no configuration was supplied to the workflow).
if not config:
    configfile: os.path.join(os.getcwd(), 'config-atacseq.yaml')

# RUN_ID names the staged base-file directory and the final Done_<RUN_ID>.txt
# marker; TEMPLATES is the directory containing snakefile_base/.
RUN_ID = str(config['run_id'])
TEMPLATES = config['templates']


# Stage a per-run copy of the shared base Snakefile and its helper modules
# (creature.py, functions.py) into ./snakefile_base_<RUN_ID>/ and include it.
# The copy is made only when that directory does not exist yet; an existing
# directory is reused as-is.
orig_base = os.path.join(TEMPLATES, 'snakefile_base', 'snakefile_base_atacseq.py')
orig2 = os.path.join(TEMPLATES, 'snakefile_base', 'creature.py')
orig3 = os.path.join(TEMPLATES, 'snakefile_base', 'functions.py')
target_base_dir = os.path.join(os.getcwd(), 'snakefile_base_' + RUN_ID)
target_base = os.path.join(target_base_dir, 'snakefile_base_atacseq.py')
if not os.path.isdir(target_base_dir):
    # Python file operations instead of an os.system() shell one-liner:
    # paths containing spaces/shell metacharacters are handled safely, and a
    # failed mkdir/copy raises here with a clear error instead of being ignored
    # (os.system's exit status was never checked and ';' kept going on errors).
    os.makedirs(target_base_dir, exist_ok=True)  # same semantics as `mkdir -p`
    for template_file in (orig_base, orig2, orig3):
        shutil.copy(template_file, target_base_dir)
    # Equivalent of the original `touch __init__.py`, which ran in the current
    # working directory (not inside target_base_dir).
    # NOTE(review): confirm whether target_base_dir/__init__.py was intended.
    with open(os.path.join(os.getcwd(), '__init__.py'), 'a'):
        pass
    # Kept from the original; presumably gives a shared/network filesystem time
    # to make the new files visible — TODO confirm it is still needed.
    time.sleep(10)

include: target_base


"""
Rules:
=======
"""

# Default target: the Done_<RUN_ID>.txt marker that rule_10_reports writes once
# all of its inputs (peaks, optional plots/TSS counts, MultiQC report) exist.
rule rule_all:
    input:
        os.path.join(ROOT_OUT_DIR, 'Done_'+ RUN_ID + '.txt'),


# Step 1: trim adapters and low-quality bases from the paired reads with cutadapt.
# - {input[0]} / {input[1]} are the R1 / R2 FASTQ files returned by get_fastq()
#   (defined in the included base file).
# - Quality cutoff 25 (-q), poly-A/poly-T runs of length 10 and the configured
#   adapters CUTADAPT_ADAP1 (R1, -a) / CUTADAPT_ADAP2 (R2, -A) are removed;
#   reads shorter than 30 nt after trimming are filtered (--minimum-length 30).
# - The trimmed FASTQs are temp(); cutadapt's report (stdout) goes to
#   params.out_sum and its stderr to the log. An empty "<output>.deleted" file
#   is touched next to each temp output (NOTE(review): presumably a project
#   convention for recording removed temp files — confirm).
rule rule_1_cutadapt:
    input:
        *get_fastq(paired_end=True) # ATAC-seq is always paired-end: exactly 2 files (R1, R2)
    output:
        out1=temp(os.path.join(ROOT_OUT_DIR, '1_cutadapt', '{sample}_R1.fastq')),
        out2=temp(os.path.join(ROOT_OUT_DIR, '1_cutadapt', '{sample}_R2.fastq'))
    params:
        out_sum = os.path.join(ROOT_OUT_DIR, '1_cutadapt', '{sample}.cutadapt.txt')
    log:
        cut = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '1_cutadapt.{sample}.txt'),
    threads: 1
    resources:
        mem_mb_per_thread=100,
        mem_mb_total=100
    shell:
        '''
        #Running in parallel is not supported on Python 2
        {CUTADAPT_EXE} -q 25 -a "A{{10}}" -a "T{{10}}" -A "A{{10}}" -A "T{{10}}" -a {CUTADAPT_ADAP1} -A {CUTADAPT_ADAP2} --minimum-length 30 -o {output.out1} -p {output.out2} {input[0]} {input[1]} > {params.out_sum} 2> {log.cut};
        touch {output.out1}.deleted {output.out2}.deleted
        '''


# Step 2: FastQC quality report on both trimmed reads of a sample (--extract
# unpacks the report). Only R1's fastqc_data.txt is declared as output (it is
# what rule_3_multiqc requests); R2's report is written to the same per-sample
# directory but is not tracked by Snakemake.
rule rule_2_fastqc:
    input:
        rules.rule_1_cutadapt.output
    output:
        os.path.join(ROOT_OUT_DIR, '2_fastqc', '{sample}', '{sample}_R1_fastqc', 'fastqc_data.txt')
    params:
        output_dir = os.path.join(ROOT_OUT_DIR, '2_fastqc', '{sample}')
    threads: 5
    resources:
        mem_mb_per_thread=200,
        mem_mb_total=1000
    log:
        os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '2_fastqc.{sample}.txt')
    shell:'''
        mkdir -p {params.output_dir}
        {FASTQC_EXE} --extract -o {params.output_dir} -f fastq --threads {threads} {input} > {log} 2>&1
    '''


# Step 3: aggregate the FastQC reports of all SAMPLES into one MultiQC report.
# The input list only forces every rule_2 job to finish; MultiQC itself scans
# the whole 2_fastqc directory (params.input_dir).
rule rule_3_multiqc:
    input:
        expand(os.path.join(ROOT_OUT_DIR, '2_fastqc', '{sample}', '{sample}_R1_fastqc', 'fastqc_data.txt'), sample=SAMPLES)
    output:
        os.path.join(ROOT_OUT_DIR, '3_multiqc','multiqc_report.html')
    params:
        input_dir = os.path.join(ROOT_OUT_DIR, '2_fastqc'),
        output_dir = os.path.join(ROOT_OUT_DIR, '3_multiqc')
    threads: 1
    resources:
        mem_mb_per_thread=200,
        mem_mb_total=200
    log:
        os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME,'3_multiqc.txt')
    # -f: overwrite existing reports. Before a re-run Snakemake deletes only the
    # declared multiqc_report.html; without -f MultiQC sees the leftover
    # multiqc_data/ directory and writes multiqc_report_1.html instead, leaving
    # the declared output missing.
    shell:
        '''
        {MULTIQC_EXE} -f -o {params.output_dir} {params.input_dir} > {log} 2>&1
        '''


# Step 4: paired-end local alignment (--local) of the trimmed reads with bowtie2
# against GENOME, allowing fragments up to params.max_fragment_length (-X).
# The resulting SAM is temp(); an empty "<output>.deleted" file is touched next
# to it.
# NOTE(review): bowtie2 prints its alignment summary on stderr, which is sent
# to {log}; since alignments go to -S, the stdout redirect into
# params.stat_file probably leaves that file empty — confirm the intent.
rule rule_4_mapping:
    input:
        rules.rule_1_cutadapt.output
    output:
        temp(os.path.join(ROOT_OUT_DIR, '4_mapping', '{sample}.sam'))
    params:
        max_fragment_length=2000,
        stat_file=os.path.join(ROOT_OUT_DIR, '4_mapping', '{sample}.stat')
    log:
        os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '4_mapping.{sample}.txt'),
    threads: 20
    resources:
        mem_mb_per_thread=300,
        mem_mb_total=60000
    shell:
        '''
        {BOWTIE2_EXE} -X {params.max_fragment_length} --local -p {threads} -x {GENOME} -1 {input[0]} -2 {input[1]} -S {output} > {params.stat_file} 2> {log}
        touch {output}.deleted
        '''



# Step 5: clean up the alignment and produce the deduplicated, indexed BAM
# (output.rm_dup) used by every downstream step:
#   1. rm_mito: drop every SAM line containing 'chrM'. This is grep output, so
#      despite the .bam name the file is still SAM text.
#   2. rm_not_uniq: keep mapped (-F 4), properly paired (-f 0x2) reads as BAM.
#      NOTE(review): no MAPQ filter is applied despite the "not_uniq" name.
#   3. sorted: coordinate-sort with Picard SortSam.
#   4. rm_dup: Picard MarkDuplicates with REMOVE_DUPLICATES=true, then index.
# samtools flagstat is written after steps 1, 2 and 4 (params.*_statistics).
# NOTE(review): `grep -v 'chrM'` matches anywhere on a line (e.g. a mate on
# chrM in RNEXT, or contig names like chrM_random), not only RNAME — confirm.
rule rule_5_process_alignment:
    input:
        os.path.join(ROOT_OUT_DIR, '4_mapping', '{sample}.sam')
    output:
        rm_mito = temp(os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_mito.bam')),
        rm_not_uniq = temp(os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_not_uniq.bam')),
        sorted = temp(os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_sorted.bam')),
        rm_dup = os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_dup.bam'),
    params:
        rm_dup_metrics = os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_dup_metrics.txt'),
        rm_mito_statistics = os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_mito_statistics.txt'),
        rm_not_uniq_statistics = os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_not_uniq_statistics.txt'),
        rm_dup_statistics = os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_dup_statistics.txt')
    log:
        rm_mito = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_rm_mito.{sample}.txt'),
        rm_mito_statistics = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_rm_mito_stat.{sample}.txt'),
        rm_not_uniq = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_rm_not_uniq.{sample}.txt'),
        rm_not_uniq_statistics = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_rm_not_uniq_stat.{sample}.txt'),
        sort = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_sort.{sample}.txt'),
        rm_dup = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_rm_dup.{sample}.txt'),
        index = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_index.{sample}.txt'),
        rm_dup_statistics = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '5_proc_align_rm_dup_stat.{sample}.txt')
    threads: 5
    resources:
        mem_mb_per_thread=6000,
        mem_mb_total=30000
    shell:
        '''
        grep -v 'chrM' {input} > {output.rm_mito} 2> {log.rm_mito}
        {SAMTOOLS_EXE} flagstat {output.rm_mito} > {params.rm_mito_statistics} 2> {log.rm_mito_statistics}
        {SAMTOOLS_EXE} view -b -h -F 4 -f 0x2 {output.rm_mito} > {output.rm_not_uniq} 2> {log.rm_not_uniq}
        {SAMTOOLS_EXE} flagstat {output.rm_not_uniq} > {params.rm_not_uniq_statistics} 2> {log.rm_not_uniq_statistics}
        {JAVA} -Djava.io.tmpdir={TEMP_DIR} -XX:ParallelGCThreads={threads} -jar {PICARD_EXE} SortSam SO=coordinate I={output.rm_not_uniq} O={output.sorted} > {log.sort}  2>&1
        {JAVA} -Djava.io.tmpdir={TEMP_DIR} -XX:ParallelGCThreads={threads} -jar {PICARD_EXE} MarkDuplicates INPUT={output.sorted} OUTPUT={output.rm_dup} M={params.rm_dup_metrics} REMOVE_DUPLICATES=true > {log.rm_dup} 2>&1
        {SAMTOOLS_EXE} index {output.rm_dup} > {log.index} 2>&1
        {SAMTOOLS_EXE} flagstat {output.rm_dup} > {params.rm_dup_statistics} 2> {log.rm_dup_statistics}
        touch {output.rm_mito}.deleted {output.rm_not_uniq}.deleted {output.sorted}.deleted
        '''

# Step 6 (only when RUN_NGSPLOT is set): per-sample QC plots.
# - ngs.plot gene-body and TSS profiles (-O prefix -> <prefix>.avgprof.pdf)
# - Picard CollectInsertSizeMetrics histogram (params.picard_hist / picard_pdf)
# - ghostscript conversion of the three PDFs to PNG
# Only the gene-body PNG is a declared output; the TSS and Picard PNGs are
# written via params and are not tracked by Snakemake.
# NOTE(review): the shell `cd`s into params.output_dir and later ROOT_OUT_DIR,
# so relative paths in {input}/{log}/params would resolve against those
# directories — presumably ROOT_OUT_DIR is absolute; confirm.
if RUN_NGSPLOT:
    rule rule_6_ngs_plot:
        input:
            os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_dup.bam')
        output:
            genbody_png = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_genbody.png')
        params:
            # ngs.plot output prefixes (it appends .avgprof.pdf etc.)
            genbody_pdf = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_genbody'),
            tss_pdf = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_tss'),
            picard_pdf = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_picard.pdf'),
            picard_hist = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_picard.hist'),
            tss_png = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_tss.png'),
            picard_png = os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_picard.png'),
            output_dir = os.path.join(ROOT_OUT_DIR, '6_ngs_plot')
        log:
            genbody_pdf = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '6_ngs_plot_genbody_pdf.{sample}.txt'),
            tss_pdf = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '6_ngs_plot_tss_pdf.{sample}.txt'),
            picard_pdf = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '6_ngs_plot_picard_pdf.{sample}.txt'),
            genbody_png = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '6_ngs_plot_genbody_png.{sample}.txt'),
            tss_png = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '6_ngs_plot_tss_png.{sample}.txt'),
            picard_png = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '6_ngs_plot_picard_png.{sample}.txt')
        threads: 1
        resources:
            mem_mb_per_thread=25000,
            mem_mb_total=25000
        shell:'''
                cd {params.output_dir}
                {NGS_PLOT_EXE} -G {NGSPLOT_GENOME} -R genebody -C {input} -O {params.genbody_pdf} -T {wildcards.sample} > {log.genbody_pdf} 2>&1
                {NGS_PLOT_EXE} -G {NGSPLOT_GENOME} -R tss -C {input} -O {params.tss_pdf} -T {wildcards.sample} > {log.tss_pdf} 2>&1
                cd {ROOT_OUT_DIR}
                {JAVA} -Djava.io.tmpdir={TEMP_DIR} -jar {PICARD_EXE} CollectInsertSizeMetrics I={input} MINIMUM_PCT=0.5 O={params.picard_hist} H={params.picard_pdf} W=1000 > {log.picard_pdf} 2>&1
                {GS_EXE} -dNOPAUSE -dBATCH -sDEVICE=pngalpha -sOutputFile={output.genbody_png} -r144 {params.genbody_pdf}.avgprof.pdf > {log.genbody_png} 2>&1
                {GS_EXE} -dNOPAUSE -dBATCH -sDEVICE=pngalpha -sOutputFile={params.tss_png} -r144 {params.tss_pdf}.avgprof.pdf > {log.tss_png} 2>&1
                {GS_EXE} -dNOPAUSE -dBATCH -sDEVICE=pngalpha -sOutputFile={params.picard_png} -r144 {params.picard_pdf} > {log.picard_png} 2>&1
            '''


# Step 7: build a BAM of nucleosome-free fragments, i.e. reads whose SAM TLEN
# (column 9) lies strictly between -120 and 120, plus its index and flagstat.
# The SAM header is written first; the filtered records are then appended.
# In the awk program `$_` is field number 0 (the unset variable `_` evaluates
# to 0), i.e. the whole record.
# The intermediate SAM and the BAM are temp(); rule_8 and rule_9 read the BAM.
rule rule_7_nucleosome_free:
    input:
        os.path.join(ROOT_OUT_DIR, '5_process_alignment', '{sample}_rm_dup.bam')
    output:
        free = temp(os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{sample}_nucl_free.sam')),
        sam2bam = temp(os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{sample}_nucl_free.bam')),
        statistics = os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{sample}_nucl_free.statistics'),
    log:
        header = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '7_nucleosome_free_header.{sample}.txt'),
        free = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '7_nucleosome_free.{sample}.txt'),
        sam2bam = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '7_nucleosome_free_sam2bam.{sample}.txt'),
        index = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '7_nucleosome_free_index.{sample}.txt'),
        statistics = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '7_nucleosome_free_statistics.{sample}.txt'),
    threads: 1
    resources:
        mem_mb_per_thread=100,
        mem_mb_total=100
    shell:
        '''
        {SAMTOOLS_EXE} view -H {input} > {output.free} 2> {log.header}
        {SAMTOOLS_EXE} view {input} | awk -F "\t" '{{if (($9 > -120) && ($9 < 120))  print $_}}' >> {output.free} 2> {log.free}
        {SAMTOOLS_EXE} view -h -b {output.free} > {output.sam2bam} 2> {log.sam2bam}
        {SAMTOOLS_EXE} index {output.sam2bam} > {log.index} 2>&1
        {SAMTOOLS_EXE} flagstat {output.sam2bam} > {output.statistics} 2> {log.statistics}
        touch {output.free}.deleted {output.sam2bam}.deleted
        '''

# Step 8 (only when TSS_FILE is set): count nucleosome-free reads per TSS region.
# The step-7 BAM is converted to BED, sorted by chromosome then start, and
# `bedtools coverage -counts -sorted` reports, for each TSS_FILE interval (-a),
# the number of overlapping reads (-b). The BED intermediates are temp().
# NOTE(review): -sorted requires TSS_FILE to be sorted in the same chromosome
# order as the reads; TSS_FILE is assumed to be pre-sorted — confirm.
if TSS_FILE:
    rule rule_8_tss_count:
        input:
            rules.rule_7_nucleosome_free.output.sam2bam
        output:
            bed = temp(os.path.join(ROOT_OUT_DIR, '8_tss_count', '{sample}.bed')),
            sort = temp(os.path.join(ROOT_OUT_DIR, '8_tss_count', '{sample}_sorted.bed')),
            count = os.path.join(ROOT_OUT_DIR, '8_tss_count', '{sample}_tss_counts.txt')
        log:
            bam2bed = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '8_tss_count_bam2bed.{sample}.txt'),
            sort = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '8_tss_count_sort.{sample}.txt'),
            counting = os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '8_tss_count_counting.{sample}.txt')
        threads: 1
        resources:
            mem_mb_per_thread=15000,
            mem_mb_total=15000
        shell:
            '''
            {BEDTOOLS_EXE} bamtobed -i {input} > {output.bed} 2> {log.bam2bed}
            sort -k1,1 -k2,2n {output.bed} > {output.sort} 2> {log.sort}
            {BEDTOOLS_EXE} coverage -counts -sorted -a {TSS_FILE} -b {output.sort} > {output.count} 2> {log.counting}
            touch {output.bed}.deleted {output.sort}.deleted
            '''


# Step 9: MACS2 broad-peak calling on the nucleosome-free BAMs (BAMPE mode,
# --shift -50 --extsize 100, all duplicates kept).
if CONTROL:
    # With controls: one job per (treat, control) pair. A treat/control name
    # that is a key of COMBINE_SAMPLES_DB stands for a group of samples whose
    # BAMs are merged first. All samples' BAMs (and, when RUN_NGSPLOT, the
    # ngs.plot images) are inputs so they exist before any merge/peak call.
    rule rule_9_call_peak_with_control:
        input:
            expand(os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{sample}_nucl_free.bam'), sample=SAMPLES),
            lambda wildcards: expand(os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_genbody.png'), sample=SAMPLES) if RUN_NGSPLOT else []
        output:
            os.path.join(ROOT_OUT_DIR, '9_call_peak', '{treat}_vs_{control}_peaks.broadPeak')
        params:
            output_dir = os.path.join(ROOT_OUT_DIR, '9_call_peak'),
            treat_file = os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{treat}_nucl_free.bam'),
            control_file = os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{control}_nucl_free.bam')
        log:
            # Must carry both output wildcards: Snakemake rejects rules whose
            # log and output wildcards differ, and pairs sharing a treatment
            # would otherwise write to the same log file.
            os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '9_call_peak.{treat}_vs_{control}.txt')
        threads: 1
        resources:
            mem_mb_per_thread=3000,
            mem_mb_total=3000
        run:
            # Merge grouped samples into <group>_nucl_free.bam, which is the
            # path params.treat_file / params.control_file point to. Each side
            # is looked up by its OWN name (the control side previously reused
            # the treatment's sample list).
            for group in (wildcards.treat, wildcards.control):
                if group in COMBINE_SAMPLES_DB:
                    orig_input_files = [os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', sample + '_nucl_free.bam')
                                        for sample in COMBINE_SAMPLES_DB[group]]
                    combined_input_file = os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', group + '_nucl_free.bam')
                    shell("{SAMTOOLS_EXE} merge -f {combined_input_file} {orig_input_files}")
            shell("{MACS2_EXE} callpeak -t {params.treat_file} -c {params.control_file} --bw 120 -B -f BAMPE --SPMR -B -g {MACS_GENOME_SIZE} --nomodel -n {wildcards.treat}_vs_{wildcards.control} --shift -50 --extsize 100 --broad  --keep-dup all --outdir {params.output_dir} > {log} 2>&1")
else:
    # Without controls: one job per sample.
    rule rule_9_call_peak_without_control:
        input:
            os.path.join(ROOT_OUT_DIR, '7_nucleosome_free', '{sample}_nucl_free.bam')
        output:
            os.path.join(ROOT_OUT_DIR, '9_call_peak', '{sample}_peaks.broadPeak')
        params:
            output_dir = os.path.join(ROOT_OUT_DIR, '9_call_peak')
        log:
            os.path.join(ROOT_OUT_DIR, LOG_DIR_NAME, '9_call_peak.{sample}.txt')
        threads: 1
        resources:
            mem_mb_per_thread=3000,
            mem_mb_total=3000
        # --outdir must be the directory: MACS2 writes <name>_peaks.broadPeak
        # inside it, which is exactly the declared output. Passing {output}
        # made MACS2 create a directory at the output *file* path.
        shell:"""
            {MACS2_EXE} callpeak -t {input} --bw 120 -B -f BAMPE --SPMR -B -g {MACS_GENOME_SIZE} --nomodel -n {wildcards.sample} --shift -50 --extsize 100 --broad  --keep-dup all --outdir {params.output_dir} > {log} 2>&1
            """

# Step 10: final aggregation target. Waits for every peak file (per sample, or
# per treat/control pair when CONTROL is set), the ngs.plot images (only
# without CONTROL; with CONTROL rule_9 already depends on them), the TSS count
# tables (when TSS_FILE is set) and the MultiQC report, then writes the
# Done_<RUN_ID>.txt marker requested by rule_all.
rule rule_10_reports:
    input:
        lambda wildcards: expand(os.path.join(ROOT_OUT_DIR, '9_call_peak', '{sample}_peaks.broadPeak'), sample=SAMPLES) if not CONTROL else expand(os.path.join(ROOT_OUT_DIR, '9_call_peak', '{treat}_vs_{control}_peaks.broadPeak'), zip, treat=TREATMENT, control=CONTROL),
        lambda wildcards: expand(os.path.join(ROOT_OUT_DIR, '6_ngs_plot', '{sample}_genbody.png'), sample=SAMPLES) if (RUN_NGSPLOT and not CONTROL) else [],
        # Depend on rule_8's real, persistent product (the count table) rather
        # than '{sample}.bed', which rule_8 declares temp() and is deleted after use.
        lambda wildcards: expand(os.path.join(ROOT_OUT_DIR, '8_tss_count', '{sample}_tss_counts.txt'), sample=SAMPLES) if (TSS_FILE) else [],
        rules.rule_3_multiqc.output
    output:
        os.path.join(ROOT_OUT_DIR, 'Done_'+ RUN_ID + '.txt'),
    params:
        output_dir = os.path.join(ROOT_OUT_DIR, '10_reports')
    threads: 1
    resources:
        mem_mb_per_thread=50,
        mem_mb_total=50
    # Touch the declared output directly instead of rebuilding its path by hand.
    shell:
        '''
        touch {output}
        '''

