
import pandas as pd
from lisa.core.data_interface import DataInterface
from lisa import FromCoverage, FromGenes, FromRegions
import os

# Snakemake directive: load workflow settings from config.yaml into the
# global `config` mapping that the rest of this file reads.
configfile: "config.yaml"

# Sample sheet: tab-separated table indexed by its DCid column. The index is
# cast to str so that rows can be looked up with string dataset IDs.
cistrome_metadata = pd.read_csv(config['cistrome_data'], sep='\t')
cistrome_metadata = cistrome_metadata.set_index('DCid')
cistrome_metadata.index = cistrome_metadata.index.astype(str)

# Handle on the existing LISA data store; queried below to find out which
# datasets have already been processed.
existing_data = DataInterface(
    config['species'],
    window_size=config['window_size'],
    download_if_not_exists=False,
    make_new=config['start_new'],
    load_genes=False,
)

# Maps the genome build named in the config to the species label used in the
# metadata table's `species` column.
SPECIES_CONVERSION = {'hg38' : 'Homo sapiens', 'mm10' : 'Mus musculus'}

# Keep only the samples of the configured species. The explicit .copy() makes
# this an independent frame, so adding the `passed_qc` column below is not a
# chained assignment on a slice of `cistrome_metadata` (which triggers pandas'
# SettingWithCopyWarning).
candidate_samples = cistrome_metadata[
    cistrome_metadata.species == SPECIES_CONVERSION[config['species']]
].copy()

# A sample passes QC only if every one of the thresholds below is met.
candidate_samples['passed_qc'] = (candidate_samples.UniquelyMappedRatio > 0.3) \
    & (candidate_samples.UniquelyMappedReadsNumber >= 1e6) \
    & (candidate_samples.PBC > 0.5) \
    & (candidate_samples.FRiP > 0.003) \
    & (candidate_samples.MergedPeaksUnionDHSRatio > 0.3) \
    & (candidate_samples.PeaksNumber > 2000)

# Path templates, filled in by get_dataset_touchfile().
# Touchfile marking one dataset as finished; these are the workflow targets.
DONE_PATH = "data/done/{species}/{window_size}/{technology}_{dataset_id}.done"
# Intermediate file for binding datasets (ChIP-seq / Motifs); same pattern as
# the outputs of the get_chip_hits and get_motif_hits rules.
FACTOR_PRECURSOR_PATH = "data/indices/{species}/{window_size}/{technology}_{dataset_id}.txt"
# Intermediate file for profiles; same pattern as the get_coverage_array output.
PROFILE_PRECURSOR_PATH = "data/coverage_arrays/{species}/{window_size}/{technology}_{dataset_id}.npy"
# Motif metadata table for the configured species (resolved once, at parse time).
motif_metadata_path = "data/{species}/motif_metadata.tsv".format(species = config['species'])

def get_unprocessed_ids(set1, set2):
    """Return the items of ``set1`` that are not present in ``set2``.

    Duplicates in ``set1`` are dropped. The result keeps the order in which
    items first appear in ``set1``, so the returned list is the same on every
    run. A plain set difference iterates in hash order, which for string IDs
    changes between interpreter runs and would reshuffle the rule inputs
    built from this list.

    Args:
        set1: Iterable of hashable candidate IDs.
        set2: Iterable of hashable IDs that have already been processed.

    Returns:
        A list of the unique items of ``set1`` that do not occur in ``set2``.
    """
    already_processed = set(set2)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return [
        item for item in dict.fromkeys(set1)
        if item not in already_processed
    ]

def get_dataset_touchfile(technology, dataset_id, is_profile = False):
    """Build the ".done" touchfile path and the precursor path for a dataset.

    Species and window size are read from the workflow ``config``;
    ``technology`` and ``dataset_id`` are converted to ``str`` before being
    substituted into the templates.

    Args:
        technology: Technology / profile label placed in the file name.
        dataset_id: Dataset identifier placed in the file name.
        is_profile: If true, the precursor path uses PROFILE_PRECURSOR_PATH;
            otherwise it uses FACTOR_PRECURSOR_PATH.

    Returns:
        A ``(done_path, precursor_path)`` tuple of formatted strings.
    """
    path_fields = {
        'species': config['species'],
        'window_size': config['window_size'],
        'technology': str(technology),
        'dataset_id': str(dataset_id),
    }

    precursor_template = PROFILE_PRECURSOR_PATH if is_profile else FACTOR_PRECURSOR_PATH

    done_file = DONE_PATH.format(**path_fields)
    precursor_file = precursor_template.format(**path_fields)
    return done_file, precursor_file

#___ WHICH FACTOR HITS TO PROCESS ____
# Both stay empty unless factor processing is enabled and something is pending.
factor_targets, factor_precursors = [], []

if config['process_factors']:

    # IDs of samples with an accepted factor type, a non-null `summit` value
    # and a passing QC flag.
    factor_hit_samples = candidate_samples.loc[
        candidate_samples['factor_type'].isin(config['factor_hit_names'])
        & candidate_samples['summit'].notna()
        & candidate_samples['passed_qc']
    ].index.values

    # Drop whatever the existing data store already lists as ChIP-seq data.
    processed_factor_hits = existing_data.list_binding_datasets('ChIP-seq')
    unprocessed_factors = get_unprocessed_ids(factor_hit_samples, processed_factor_hits)

    if unprocessed_factors:
        # One (touchfile, precursor) pair per pending dataset, then transposed
        # into two parallel tuples.
        factor_paths = [
            get_dataset_touchfile('ChIP-seq', dataset_id)
            for dataset_id in unprocessed_factors
        ]
        factor_targets, factor_precursors = zip(*factor_paths)
    

#___ WHICH MOTIFS TO PROCESS ____
# Both stay empty unless motif processing is enabled and something is pending.
motif_targs, motif_precursors = [],[]

if config['process_motifs']:

    # The motif list file holds one motif ID per line. Blank lines are skipped:
    # an empty ID would otherwise become a "Motifs_.done" target and make
    # download_motifs request a ".tsv.gz" file with no name.
    with open(config['motif_data'].format(species = config['species']), 'r') as f:
        candidate_motifs = [
            motif_id for motif_id in (line.strip() for line in f) if motif_id
        ]

    # Drop whatever the existing data store already lists as motif data.
    processed_motifs = existing_data.list_binding_datasets('Motifs')

    unprocessed_motifs = get_unprocessed_ids(candidate_motifs, processed_motifs)

    if unprocessed_motifs:
        # Transpose the (touchfile, precursor) pairs into two parallel tuples.
        motif_targs, motif_precursors = list(zip(
            *[get_dataset_touchfile('Motifs', target)
            for target in unprocessed_motifs]
        ))
    

#___ WHICH ACCESSIBILITY PROFILES TO PROCESS___

# Accumulated across every profile type listed in the config.
profile_targets, profile_precursors = [], []

if config['process_profiles']:

    for profile_type in config['chromatin_profiles']:

        # IDs of samples whose `factor` column equals this profile type, that
        # pass QC and that have a non-null `bigwig` value.
        profile_samples = candidate_samples.loc[
            (candidate_samples['factor'] == profile_type)
            & candidate_samples['passed_qc']
            & candidate_samples['bigwig'].notna()
        ].index.values

        # Drop whatever the existing data store already lists for this type.
        processed_profile_samples = existing_data.list_profiles(profile_type)
        unprocessed_samples = get_unprocessed_ids(profile_samples, processed_profile_samples)

        # Record the touchfile and the coverage-array precursor of each
        # pending sample (nothing happens when the list is empty).
        for sample_id in unprocessed_samples:
            done_file, coverage_file = get_dataset_touchfile(
                profile_type, sample_id, is_profile=True
            )
            profile_targets.append(done_file)
            profile_precursors.append(coverage_file)

# Every ".done" touchfile still to be produced: factor hits, motifs and
# profiles, in that order. list() normalizes the tuples produced by zip().
TARGETS = list(factor_targets) + list(motif_targs) + list(profile_targets)

# ___ RULES ____

# Default rule: requests every pending touchfile.
rule all:
    input:
        TARGETS

# Download one motif track from the JASPAR 2020 UCSC_tracks directory.
# The remote file is named <dataset_id>.tsv.gz but is saved locally under a
# .bed.gz name. The output is wrapped in temp(), so Snakemake removes it once
# the jobs that consume it have finished.
rule download_motifs:
    output:
        temp("data/motif_beds/{species}/{dataset_id}.bed.gz")
    shell:
        'wget http://expdata.cmmt.ubc.ca/JASPAR/downloads/UCSC_tracks/2020/{wildcards.species}/{wildcards.dataset_id}.tsv.gz -O {output}'


# Run bin/calc_motif_bins.py on a downloaded motif track. The script writes
# the index file given with -o, and its stdout is appended (>>) to the motif
# metadata table.
# NOTE(review): the metadata path comes from config['species'], not from the
# {species} wildcard, and every motif job appends to the same file. Confirm
# that concurrent jobs cannot interleave their output.
rule get_motif_hits:
    input:
        "data/motif_beds/{species}/{dataset_id}.bed.gz"
    output:
        indices=temp("data/indices/{species}/{window_size}/Motifs_{dataset_id}.txt")
    params:
        motif_metadata=motif_metadata_path
    shell:
        'python ./bin/calc_motif_bins.py {wildcards.species} {input} -w {wildcards.window_size} -o {output.indices} >> {params.motif_metadata}'


# Run bin/map_peaks_to_indices.py on one ChIP-seq dataset.
# The input path is the `summit` value of the metadata row whose (string)
# index equals the {dataset_id} wildcard. The output is a temp() index file.
rule get_chip_hits:
    input:
        lambda wildcards : cistrome_metadata.loc[wildcards.dataset_id].summit
    output:
        temp("data/indices/{species}/{window_size}/ChIP-seq_{dataset_id}.txt")
    shell:
        'python ./bin/map_peaks_to_indices.py {wildcards.species} {wildcards.window_size} "{input}" "{output}"'


# format_h5 is defined in one of two variants, depending on whether chromatin
# profiles are processed. In both, a single job consumes all precursor files
# and then touches every ".done" target.
if config['process_profiles']:

    # Run bin/compute_coverage_array.py for one profile dataset. The input is
    # the `bigwig` value of the metadata row matching the {dataset_id}
    # wildcard; the script receives the output path without its ".npy"
    # extension through -n.
    rule get_coverage_array:
        input:
            lambda wildcards : cistrome_metadata.loc[wildcards.dataset_id].bigwig
        output:
            "data/coverage_arrays/{species}/{window_size}/{technology}_{dataset_id}.npy"
        params:
            name=lambda wildcards : "data/coverage_arrays/{species}/{window_size}/{technology}_{dataset_id}"\
                .format(species = wildcards.species, window_size = wildcards.window_size, technology = wildcards.technology, dataset_id = wildcards.dataset_id)
        shell:
            "python bin/compute_coverage_array.py {wildcards.species} {input} -n {params.name}"

    # Run append_peak_indices.py on the index files, then append_profiles.py
    # on the coverage arrays.
    rule format_h5:
        input:
            indices=list(factor_precursors) + list(motif_precursors),
            arrays=list(profile_precursors)
        output:
            touch(TARGETS)
        params:
            motifs=motif_metadata_path,
            # Fixed: this previously referenced the undefined name
            # `motif_metadata`, which raised NameError while the Snakefile was
            # parsed. The parameter is not used by the shell command below.
            motifs_temp=motif_metadata_path + "_temp"
        shell:
            'python ./bin/append_peak_indices.py {config[species]} {config[window_size]} {config[cistrome_data]} {params.motifs} {input.indices} && '
            'python ./bin/append_profiles.py {config[species]} {config[window_size]} {config[cistrome_data]} {input.arrays}'

else:

    # Profiles disabled: only the index files are appended.
    rule format_h5:
        input:
            indices=list(factor_precursors) + list(motif_precursors)
        output:
            touch(TARGETS)
        params:
            motifs=motif_metadata_path
        shell:
            'python ./bin/append_peak_indices.py {config[species]} {config[window_size]} {config[cistrome_data]} {params.motifs} {input.indices}'