import os
import sys
import platform
import tarfile

from metaflow.datastore import MetaflowDataStore
from metaflow.datastore.datastore import TransformableObject
from metaflow.datastore.util.s3util import get_s3_client
from metaflow.decorators import StepDecorator
from metaflow.metaflow_config import DATASTORE_SYSROOT_LOCAL
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task

from metaflow import util

from .batch import Batch, BatchException
from metaflow.metaflow_config import ECS_S3_ACCESS_IAM_ROLE, BATCH_JOB_QUEUE

try:
    # python2
    from urlparse import urlparse
except:  # noqa E722
    # python3
    from urllib.parse import urlparse


class ResourcesDecorator(StepDecorator):
    """Declare the compute resources (cpu, gpu, memory) a step needs.

    This decorator carries no behavior of its own; it only exposes its
    ``attributes`` so that other decorators (such as ``@batch`` in this
    module) can read them when sizing a task. Every default is a string.
    """

    name = "resources"
    # NOTE(review): memory is presumably expressed in MB -- confirm with the
    # consumers of these attributes.
    defaults = dict(
        cpu="1",
        gpu="0",
        memory="4000",
    )

class BatchDecorator(StepDecorator):
    """Execute a step remotely through the ``batch step`` CLI command.

    Attributes (forwarded verbatim as command options of ``batch step``):
      cpu, gpu, memory -- resource requests. If the step also carries
                          ``@resources``, the larger value of each wins.
      image            -- container image; defaults to
                          ``python:<local Python version>``.
      queue            -- job queue; defaults to ``BATCH_JOB_QUEUE``.
      iam_role         -- IAM role; defaults to ``ECS_S3_ACCESS_IAM_ROLE``.

    The decorator requires the S3 datastore. The code package is saved to
    the datastore at most once per process (class-level cache shared by
    every ``@batch`` step).
    """

    name = "batch"
    defaults = {
        "cpu": "1",
        "gpu": "0",
        "memory": "8000",
        "image": "python:%s" % platform.python_version(),
        "queue": BATCH_JOB_QUEUE,
        "iam_role": ECS_S3_ACCESS_IAM_ROLE
    }
    # Class-level (shared) cache of the uploaded code package; filled in by
    # _save_package_once the first time any @batch task is created.
    package_url = None
    package_sha = None
    run_time_limit = None
    # Datastore root used by task_finished to upload local metadata. It is
    # normally set by task_pre_step; defaulting it here keeps task_finished
    # from raising AttributeError if it is ever invoked without a preceding
    # task_pre_step (presumably when an earlier pre-step hook fails).
    ds_root = None

    def __init__(self, attributes=None, statically_defined=False):
        super(BatchDecorator, self).__init__(attributes, statically_defined)

    def step_init(self, flow, graph, step, decos, environment, datastore, logger):
        """Record step context and reconcile resources with ``@resources``.

        For every attribute of a ``@resources`` decorator on the same step,
        the larger of the two integer values is kept (stored back as a
        string). A ``None`` or empty value counts as 0; if both values are
        ``None`` the attribute is left untouched. Also computes the task's
        run-time limit from the step's decorators.
        """
        self.logger = logger
        self.environment = environment
        self.step = step
        for deco in decos:
            if isinstance(deco, ResourcesDecorator):
                for k, v in deco.attributes.items():
                    # we use the larger of @resources and @batch attributes
                    my_val = self.attributes.get(k)
                    if not (my_val is None and v is None):
                        self.attributes[k] = str(max(int(my_val or 0), int(v or 0)))
        self.run_time_limit = get_run_time_limit_for_task(decos)

    def runtime_init(self, flow, graph, package, run_id):
        """Store the flow, graph, code package and run id for later hooks."""
        self.flow = flow
        self.graph = graph
        self.package = package
        self.run_id = run_id

    def runtime_task_created(
        self, datastore, task_id, split_index, input_paths, is_cloned
    ):
        """Validate the datastore and make sure the code package is saved.

        Raises:
            BatchException: if the datastore is not S3.
        """
        if datastore.TYPE != "s3":
            raise BatchException("The *@batch* decorator requires --datastore=s3.")

        # Cloned tasks are not executed again, so they don't need the package.
        if not is_cloned:
            self._save_package_once(datastore, self.package)

    def runtime_step_cli(self, cli_args, retry_count, max_user_code_retries):
        """Redirect the step's command line to ``batch step``.

        While user-code attempts remain, the command becomes ``batch step``
        with the package sha and url as arguments, the decorator attributes
        and run-time limit as options, and the current interpreter as the
        entrypoint. Otherwise ``cli_args`` is left untouched.
        """
        if retry_count <= max_user_code_retries:
            # after all attempts to run the user code have failed, we don't need
            # Batch anymore. We can execute possible fallback code locally.
            cli_args.commands = ["batch", "step"]
            cli_args.command_args.append(self.package_sha)
            cli_args.command_args.append(self.package_url)
            cli_args.command_options.update(self.attributes)
            cli_args.command_options["run-time-limit"] = self.run_time_limit
            cli_args.entrypoint[0] = sys.executable

    def task_pre_step(
            self, step_name, ds, meta, run_id, task_id, flow, graph, retry_count, max_retries):
        """Remember the datastore root when metadata is tracked locally.

        A local metadata provider stores its records on the local disk only,
        so task_finished must ship them to the datastore; any other provider
        needs no upload and ``ds_root`` is cleared.
        """
        self.ds_root = ds.root if meta.TYPE == 'local' else None

    def task_finished(self, step_name, flow, graph, is_task_ok, retry_count, max_retries):
        """Upload the local metadata tree to the datastore, if required.

        Does nothing unless task_pre_step recorded a datastore root. Otherwise
        the local metadata directory is archived as ``metadata.tgz`` and
        uploaded under ``ds_root`` using the attempt prefix for
        ``retry_count``.
        """
        if not self.ds_root:
            return
        # The datastore is *always* s3 (see runtime_task_created), so the
        # archive can be pushed directly with an S3 client.
        with util.TempDir() as td:
            tar_file_path = os.path.join(td, 'metadata.tgz')
            with tarfile.open(tar_file_path, 'w:gz') as tar:
                # The local metadata is stored in the local datastore.
                # NOTE(review): confirm DATASTORE_SYSROOT_LOCAL is always set
                # here; tar.add would fail if it were None.
                tar.add(DATASTORE_SYSROOT_LOCAL)
            s3, _ = get_s3_client()
            path = os.path.join(
                self.ds_root,
                MetaflowDataStore.filename_with_attempt_prefix(
                    'metadata.tgz', retry_count))
            # The URL's netloc is the bucket; its path (minus the leading
            # slash) is the object key.
            url = urlparse(path)
            with open(tar_file_path, 'rb') as f:
                s3.upload_fileobj(f, url.netloc, url.path.lstrip('/'))

    @classmethod
    def _save_package_once(cls, datastore, package):
        """Save the code package to the datastore once per process.

        The resulting URL and sha are cached on the class so every
        ``@batch`` step reuses the same upload.
        """
        if cls.package_url is None:
            cls.package_url = datastore.save_data(package.sha, TransformableObject(package.blob))
            cls.package_sha = package.sha
            cls.package_sha = package.sha

