"""Snakefile used by Snakemake."""
from pathlib import Path

import snakemake.io
from rich import box
from rich.highlighter import ReprHighlighter
from rich.padding import Padding
from rich.table import Table
from rich.text import Text
from snakemake.logging import logger

from sparv import util
from sparv.core import config as sparv_config
from sparv.core import paths, registry, snake_utils, snake_prints
from sparv.core.console import console

# When Sparv drives the run, drop Snakemake's default text log handler (only if it is the first registered handler)
if config.get("run_by_sparv"):
    if logger.log_handler and logger.log_handler[0] == logger.text_handler:
        logger.log_handler = []

# Don't do anything if no rule was specified.
# This rule has no inputs and no actions. It is the first rule declared in this file, and Snakemake uses the first
# rule as the default target when none is given on the command line.
rule do_nothing:
    input: []

# ==============================================================================
# Dynamic Creation of Snakemake Rules
# ==============================================================================

def make_rules(config_missing: bool) -> None:
    """Create Snakemake rules for all registered annotators and apply rule orders.

    Rules are created for every function of every module in the registry, then for every entry in the
    'custom_annotations' config section. Finally the rule orders reported by snake_utils.check_ruleorder() are
    registered with the workflow (and printed when the 'debug' config flag is set).

    Args:
        config_missing: Passed on unchanged to make_rule() for every rule.
    """
    # One rule per annotation function in every available module
    for module_name in registry.modules:
        module_functions = registry.modules[module_name].functions
        for f_name, annotator in module_functions.items():
            make_rule(module_name, f_name, annotator, config_missing)

    # Custom rules: each entry names its annotator as "<module>:<function>"
    for custom_rule_obj in sparv_config.get("custom_annotations", []):
        custom_module, custom_function = custom_rule_obj["annotator"].split(":")
        make_rule(custom_module, custom_function, registry.modules[custom_module].functions[custom_function],
                  config_missing, custom_rule_obj)

    # Register rule orders, i.e. "ruleorder: preferred > other"
    rule_pairs = snake_utils.check_ruleorder(snake_storage)
    for preferred, other in rule_pairs:
        workflow.ruleorder(preferred.rule_name, other.rule_name)

    # The remainder is debug output only
    if not (config.get("debug") and rule_pairs):
        return
    console.print("\n\n\n[b]ORDERED RULES:[/b]")
    for preferred, other in rule_pairs:
        print(f"    • {preferred.rule_name} > {other.rule_name}")
    print()


def make_rule(module_name: str, f_name: str, annotator_info: dict, config_missing: bool = False,
              custom_rule_obj: dict = None) -> None:
    """Create single Snakemake rule.

    A RuleStorage object is built for the annotator and filled in by snake_utils.rule_helper(). If the helper
    returns a true value, a Snakemake rule is registered for the annotator, followed by a companion rule (see
    make_all_files_rule()) for running the annotation on all input files.

    Args:
        module_name: Name of the module the annotator belongs to.
        f_name: Name of the annotator function within the module.
        annotator_info: The registry entry for the annotator function, handed to RuleStorage.
        config_missing: Handed to snake_utils.rule_helper(); make_rules() passes the value returned by
            snake_utils.load_config().
        custom_rule_obj: Entry from the 'custom_annotations' config section when creating a custom rule,
            otherwise None.
    """
    # Init rule storage
    rule_storage = snake_utils.RuleStorage(module_name, f_name, annotator_info)

    # Process rule parameters and update rule storage
    create_rule = snake_utils.rule_helper(rule_storage, config, snake_storage, config_missing, custom_rule_obj)

    if create_rule:
        # Create a named Snakemake rule for annotator (unfortunately we cannot use the regular snakemake syntax for this)
        # NOTE(review): the decorators below presumably mirror what Snakemake generates for a regular
        # "rule ...:" declaration, so their order should not be changed — confirm against the Snakemake version in use.
        @workflow.rule(name=rule_storage.rule_name)
        @workflow.message(rule_storage.target_name)
        @workflow.input(rule_storage.inputs)
        @workflow.output(rule_storage.outputs)
        @workflow.params(module_name=rule_storage.module_name,
                         f_name=rule_storage.f_name,
                         parameters=snake_utils.get_parameters(rule_storage),
                         export_dirs=rule_storage.export_dirs)
        # We use "script" instead of "run" since with "run" the whole Snakefile would have to be reloaded for every
        # single job, due to how Snakemake creates processes for run-jobs.
        @workflow.script("run_snake.py")
        @workflow.run
        # The parameter list is presumably dictated by what Snakemake passes to a rule's run function (TODO confirm).
        # Not every parameter is forwarded to script(): version, use_singularity, is_shell and edit_notebook are
        # accepted but unused, and 'config' comes from the enclosing Snakefile scope.
        def __rule__(input_, output, params, wildcards, threads, resources, log, version, rule, conda_env, container_img,
                     singularity_args, use_singularity, env_modules, bench_record, jobid, is_shell, bench_iteration,
                     cleanup_scripts, shadow_dir, edit_notebook):
            script("run_snake.py", paths.sparv_path / "core", input_, output, params,
                   wildcards, threads, resources, log, config, rule, conda_env, container_img, singularity_args,
                   env_modules, bench_record, jobid, bench_iteration, cleanup_scripts, shadow_dir)

        # Create rule to run this annotation on all input files
        make_all_files_rule(rule_storage)


def make_all_files_rule(rule_storage: snake_utils.RuleStorage) -> None:
    """Create a named target rule that runs an annotation on all input files.

    The new rule is named after rule_storage.target_name, performs no action itself and takes as input the
    annotator's files expanded over all documents (the annotator's inputs are used instead of its outputs when the
    rule is abstract).

    Args:
        rule_storage: Storage object of an annotator rule that has already been registered with the workflow.
    """
    # When run by Sparv, only create the rule if it was explicitly asked for
    if config.get("run_by_sparv"):
        if rule_storage.target_name not in config.get("targets", []):
            return

    # Snakemake rule object of the already registered annotator rule
    sm_rule = getattr(rules, rule_storage.rule_name).rule

    dependencies = rule_storage.inputs if rule_storage.abstract else rule_storage.outputs

    # Paths that are not already below the work dir or the export dir get the work dir prepended (it is usually
    # included in the {doc} wildcard but here it needs to be explicit)
    anchored_paths = []
    for dependency in dependencies:
        if paths.work_dir in dependency.parents or paths.export_dir in dependency.parents:
            anchored_paths.append(dependency)
        else:
            anchored_paths.append(paths.work_dir / dependency)

    # Fill in the {doc} wildcard with every corpus document, plus any other wildcard values
    expanded_paths = expand(anchored_paths,
                            doc=snake_utils.get_doc_values(config, snake_storage),
                            **snake_utils.get_wildcard_values(config))

    # Wrap paths in IOFile objects tied to the annotator rule, so Snakemake knows which rule they come from (in case
    # of ambiguity)
    rule_outputs = [snakemake.io.IOFile(expanded, rule=sm_rule) for expanded in expanded_paths]

    @workflow.rule(name=rule_storage.target_name)
    @workflow.input(rule_outputs)
    @workflow.norun()
    @workflow.run
    def __rule__(*_args, **_kwargs):
        pass

# Init the storage for some essential variables involving all rules.
# It is read by the rule-creating functions above and by the static rules below.
snake_storage = snake_utils.SnakeStorage()

# Find and load corpus config. The return value is passed on to make_rules() further down.
config_missing = snake_utils.load_config(config)

# Find and load Sparv modules (find_custom receives the value of the 'custom_annotations' config section)
registry.find_modules(find_custom=sparv_config.get("custom_annotations"))

# Let exporters and importers inherit config values from 'export' and 'import' sections.
# inherit_config() is called once for every importer/exporter function found in a module.
for module in registry.modules:
    for a in registry.modules[module].functions.values():
        if a["type"] == registry.Annotator.importer:
            sparv_config.inherit_config("import", module)
        elif a["type"] == registry.Annotator.exporter:
            sparv_config.inherit_config("export", module)

# Collect list of all explicitly used annotations (without class expansion).
# Only the first element of each parsed annotation entry is kept.
for key in registry.annotation_sources:
    registry.explicit_annotations_raw.update(a[0] for a in util.parse_annotation_list(sparv_config.get(key, [])))

# Figure out classes from annotation usage
registry.find_implicit_classes()

# Collect list of all explicitly used annotations (with class expansion).
# This is done after find_implicit_classes(); the first element of expand_variables()'s result is stored.
for key in registry.annotation_sources:
    registry.explicit_annotations.update(
        registry.expand_variables(a[0])[0]
        for a in util.parse_annotation_list(sparv_config.get(key, [])))

# Load modules and create automatic rules
make_rules(config_missing)

# Validate config usage in modules
sparv_config.validate_module_config()

# Validate config
sparv_config.validate_config()

# Get reverse_config_usage dict for look-ups (used by the 'modules' rule below)
reverse_config_usage = snake_utils.get_reverse_config_usage()

# ==============================================================================
# Static Snakemake Rules
# ==============================================================================

# Rule to print config options together with their current values
rule config:
    run:
        requested_options = config.get("options")
        if requested_options:
            # Only show the options that were asked for
            out_conf = {option: sparv_config.get(option) for option in requested_options}
        else:
            # No selection given: show the complete configuration
            out_conf = sparv_config.config
        snake_prints.prettyprint_yaml(out_conf)


# Rule to print information about modules and annotations
rule modules:
    run:
        selected_types = config.get("types", [])
        selected_names = config.get("names", [])
        if selected_types or selected_names:
            # Detailed info for the requested types and/or names
            snake_prints.print_module_info(selected_types, selected_names, snake_storage, reverse_config_usage)
        else:
            # Nothing requested: print a summary
            snake_prints.print_module_summary(snake_storage)


# Rule to list all annotation classes.
# All output is produced by snake_prints.print_annotation_classes(); the rule has no inputs or outputs.
rule classes:
    run:
        snake_prints.print_annotation_classes()


# Rule to print every annotation preset in resolved form
rule presets:
    run:
        # Map each preset name to the result of resolving its definition
        resolved_presets = {
            preset_name: sparv_config.resolve_presets(sparv_config.presets[preset_name])
            for preset_name in sparv_config.presets
        }
        snake_prints.prettyprint_yaml(resolved_presets)


# Rule to print a table of all targets, grouped by kind
rule list_targets:
    run:
        print()
        table = Table(title="Available rules", box=box.SIMPLE, show_header=False, title_justify="left")
        table.add_column(no_wrap=True)
        table.add_column()

        # Each section is a heading plus sorted (target, description) pairs. Export and model targets carry a third
        # (language) element which is not shown here.
        sections = [
            ("Exports", [(target, desc) for target, desc, _lang in sorted(snake_storage.export_targets)]),
            ("Installers", sorted(snake_storage.install_targets)),
            ("Annotators", sorted(snake_storage.named_targets)),
            ("Model Builders", [(target, desc) for target, desc, _lang in sorted(snake_storage.model_targets)]),
            ("Custom Rules", sorted(snake_storage.custom_targets)),
        ]

        for position, (heading, entries) in enumerate(sections):
            if position > 0:
                # Empty row as a spacer between sections
                table.add_row()
            table.add_row(f"[b]{heading}[/b]")
            for target, desc in entries:
                table.add_row("  " + target, desc)
        console.print(table)

        note = Text.from_markup("[i]Note:[/i] Custom rules need to be declared in the 'custom_annotations' section "
                                "of your corpus configuration before they can be used.")
        ReprHighlighter().highlight(note)
        console.print(Padding(note, (0, 4)))


# Rule to print a table of the exports available for the corpus language
rule list_exports:
    run:
        print()
        table = Table(title="Available corpus output formats (exports)", box=box.SIMPLE, show_header=False,
                      title_justify="left")
        table.add_column(no_wrap=True)
        table.add_column()

        for target, desc, language in sorted(snake_storage.export_targets):
            # Skip exports that are restricted to languages other than the configured one
            if language and sparv_config.get("metadata.language") not in language:
                continue
            table.add_row(target, desc)
        console.print(table)
        console.print("  Default: xml_export:pretty")


# Rule to print the available input files in columns
rule files:
    run:
        from rich.columns import Columns

        print("Available input files:\n")
        source_files = sorted(snake_utils.get_source_files(snake_storage.source_files))
        # Fill the columns top-to-bottom before moving on to the next column
        console.print(Columns(source_files, column_first=True, padding=(0, 3)))


# Rule to remove dirs created by Sparv.
# Config flags: 'export' removes the export dir, 'logs' removes the log dir, 'all' removes export, log and work dirs.
# When neither 'export' nor 'logs' is given, the work dir is removed.
rule clean:
    run:
        import shutil

        clean_all = config.get("all")
        clean_export = config.get("export")
        clean_logs = config.get("logs")

        # Pick the directories to remove, in the order export, logs, work
        selected = []
        if clean_export or clean_all:
            selected.append(("Export", paths.export_dir))
        if clean_logs or clean_all:
            selected.append(("Log", paths.log_dir))
        if clean_all or not (clean_export or clean_logs):
            selected.append(("Work", paths.work_dir))

        # Refuse to continue if a directory name is not configured: an empty name would make the path below resolve
        # to the current working directory itself. This is an explicit raise rather than an 'assert' statement so
        # that the check is not stripped when Python runs with -O; the exception type and messages are unchanged.
        to_remove = []
        for label, directory in selected:
            if not directory:
                raise AssertionError(f"{label} dir name not configured.")
            to_remove.append(directory)

        something_removed = False
        for d in to_remove:
            # Directories are resolved relative to the current working directory
            full_path = Path.cwd() / d
            if full_path.is_dir():
                shutil.rmtree(full_path)
                snake_utils.print_sparv_info(f"'{d}' directory removed")
                something_removed = True
        if not something_removed:
            snake_utils.print_sparv_info("Nothing to remove")


# Rule to print the available installations, with the selected ones listed first
rule list_installs:
    run:
        # An installation counts as selected when any of its outputs is among those returned here
        chosen_outputs = set(snake_utils.get_install_outputs(snake_storage))

        # Split the sorted install targets into selected and remaining ones in a single pass
        selected_installations = []
        other_installations = []
        for target, desc in sorted(snake_storage.install_targets):
            if set(snake_storage.install_outputs[target]).intersection(chosen_outputs):
                selected_installations.append((target, desc))
            else:
                other_installations.append((target, desc))

        # The second table only says "Other" when there is a first table to contrast it with
        other_title = "Other available installations" if selected_installations else "Available installations"

        for title, installations in (("Selected installations", selected_installations),
                                     (other_title, other_installations)):
            if not installations:
                continue
            print()
            table = Table(title=title, box=box.SIMPLE, show_header=False, title_justify="left")
            table.add_column(no_wrap=True)
            table.add_column()
            for target, desc in installations:
                table.add_row(target, desc)
            console.print(table)

        console.print("[i]Note:[/i] Use the 'install' section in your corpus configuration to select what "
                      "installations should be performed when running 'sparv install' without arguments.")


# Rule for making exports defined in corpus config.
# The rule has no action of its own; its inputs are computed by snake_utils.get_export_targets() from the rule
# storage, Snakemake's 'rules' object, the document values and the wildcard values derived from the Snakemake config.
rule export_corpus:
    input:
        snake_utils.get_export_targets(snake_storage, rules, doc=snake_utils.get_doc_values(config, snake_storage),
                                       wildcards=snake_utils.get_wildcard_values(config))


# Rule for making installations.
# The rule has no action of its own; its inputs are the install outputs returned by
# snake_utils.get_install_outputs(), which also receives the 'install_types' value from the Snakemake config.
rule install_corpus:
    input:
        snake_utils.get_install_outputs(snake_storage, config.get("install_types"))


# Rule to print the models that can be built/downloaded, split by language dependence
rule list_models:
    run:
        print()
        current_language = sparv_config.get("metadata.language")

        # Walk the sorted model targets once and sort them into the two tables. Models bound to other languages
        # than the current one end up in neither.
        language_models = []
        independent_models = []
        for target, desc, language in sorted(snake_storage.model_targets):
            if not language:
                independent_models.append((target, desc))
            elif current_language in language:
                language_models.append((target, desc))

        for title, models in ((f"Models for current language ({current_language})", language_models),
                              ("Language-independent models", independent_models)):
            table = Table(title=title, box=box.SIMPLE, show_header=False, title_justify="left")
            table.add_column(no_wrap=True)
            table.add_column()
            for target, desc in models:
                table.add_row(target, desc)
            console.print(table)


# Rule to list all annotation files that can be created
rule list_files:
    run:
        # Unique (rule type, file) pairs; abstract rules contribute their inputs instead of their outputs
        outputs = {(stored_rule.type, path)
                   for stored_rule in snake_storage.all_rules
                   for path in (stored_rule.inputs if stored_rule.abstract else stored_rule.outputs)}

        # Printed sections: heading and the rule types whose files belong under it
        sections = (
            ("Annotation files", ("annotator", "importer")),
            ("Export files", ("exporter",)),
            ("Model files", ("modelbuilder",)),
            ("Installation files", ("installer",)),
        )

        print()
        print("This is a list of files that can be created by Sparv. Please note that wildcards must be replaced "
              "with paths.")
        for heading, rule_types in sections:
            print()
            console.print(f"[i]{heading}[/i]\n")
            for path in sorted(o for t, o in outputs if t in rule_types):
                print(f"    {path}")


# Build all models. The rule has no action of its own; its inputs are the files in snake_storage.model_outputs.
# NOTE(review): which models end up in model_outputs is decided elsewhere; presumably optional models are only
# included when force_optional_models = True — confirm.
rule build_models:
    input:
        snake_storage.model_outputs
