from pathlib import Path
from typing import List, Optional

import click
import functools
import os
from tqdm import tqdm

from grid.cli import rich_click
from grid.cli.utilities import is_experiment
from grid.sdk import env
from grid.sdk.runs import Run
from grid.sdk.experiments import Experiment
from grid.sdk.artifacts import list_artifacts, download_artifacts


def with_regular_certs(f):
    """Decorator that runs ``f`` without the custom-certificate environment variables.

    Unless ``env.SSL_CA_CERT`` is set, ``REQUESTS_CA_BUNDLE`` and ``SSL_CERT_FILE``
    are removed from ``os.environ`` for the duration of the call. They are restored
    afterwards, even if ``f`` raises. When ``env.SSL_CA_CERT`` is set, the
    environment is left untouched.

    The variables are snapshotted on every call rather than once at decoration time.
    So values set or changed after the decorated function was defined are handled
    correctly, and stale values are never written back into the environment.

    Note: ``os.environ`` is process-wide, so concurrent calls from several threads
    can observe each other's temporary removal.
    """
    ssl_custom_cert_envs = ("REQUESTS_CA_BUNDLE", "SSL_CERT_FILE")

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # A configured SSL_CA_CERT means the custom certs must stay in effect.
        if env.SSL_CA_CERT:
            return f(*args, **kwargs)

        # Remove only the variables that are set *right now*, remembering their values.
        removed = {name: os.environ.pop(name) for name in ssl_custom_cert_envs if name in os.environ}
        try:
            return f(*args, **kwargs)
        finally:
            # Put back exactly what this call removed; variables that were unset stay unset.
            os.environ.update(removed)

    return wrapper


def _collect_experiments(runs_or_experiments: List[str]) -> List[Experiment]:
    """Resolve CLI names into existing experiments, skipping missing ones and duplicates."""
    experiments: List[Experiment] = []
    seen_names = set()

    def _append_unique(experiment: Experiment) -> None:
        # Artifacts are listed by experiment name. Passing a run together with one
        # of its experiments would otherwise download the same artifacts twice.
        if experiment.name in seen_names:
            return
        seen_names.add(experiment.name)
        experiments.append(experiment)

    for element in runs_or_experiments:
        if is_experiment(element):
            experiment = Experiment(name=element)
            if not experiment.exists:
                click.echo(f"Experiment {element} does not exist - can not get artifacts")
                continue
            _append_unique(experiment)
            continue

        # Anything that is not an experiment name is treated as a run; expand it
        # into all of its experiments.
        run = Run(name=element)
        if not run.exists:
            click.echo(f"Run {element} does not exist - can not get artifacts")
            continue

        for experiment in run.experiments:
            if not experiment.exists:
                click.echo(f"Experiment {experiment.name} does not exist - can not get artifacts")
                continue
            _append_unique(experiment)

    return experiments


@rich_click.command()
@click.option(
    "--download_dir",
    type=click.Path(exists=False, file_okay=False, dir_okay=True),
    required=False,
    default="./grid_artifacts",
    help="Download directory that will host all artifact files."
)
@click.option(
    "-m",
    "--match_regex",
    type=str,
    default="",
    help="Only download artifacts that match this regex filter. Best if quoted."
)
@rich_click.argument(
    "runs_or_experiments", type=str, required=True, nargs=-1, help="The run or experiment to download artifacts for."
)
def artifacts(runs_or_experiments: List[str], download_dir: Optional[str] = None, match_regex: str = "") -> None:
    """Downloads artifacts for the given runs or experiments.

    This will download artifacts generated by the runs / experiments.
    A run expands to all of its experiments; each experiment is downloaded once.
    Regex filtering is used to determine which artifacts to download.
    """
    click.echo("Downloading artifacts. This command may take a while")
    experiments = _collect_experiments(runs_or_experiments)

    if not experiments:
        # Every name was missing; do not claim a successful download.
        click.echo("No existing runs or experiments found - nothing to download")
        return

    # NOTE(review): download_artifacts is described as rendering its own progress
    # bars (per-experiment artifact progress and per-file MB/s). position=2 presumably
    # puts this experiments-finished bar below those; confirm against grid.sdk.
    exp_iter = tqdm(experiments, unit="experiment", position=2)
    for experiment in exp_iter:
        exp_iter.set_description(experiment.name)
        artifacts_to_download = list_artifacts(experiment.name, artifact_regex=match_regex)
        download_artifacts(artifacts_to_download, destination=download_dir)
    click.echo("Done downloading artifacts!")
