import asyncio
import copy
from enum import Enum
import hashlib
import io
import json
import logging
import os
from pathlib import Path
import shutil
import sys
import tarfile
import tempfile
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
)
from urllib.parse import urlparse

import aiobotocore  # type: ignore
from filelock import FileLock
import ray._private.runtime_env as ray_runtime_env
from ray.job_config import JobConfig

import anyscale
from anyscale.api import get_api_client
from anyscale.credentials import load_credentials
from anyscale.sdk.anyscale_client.sdk import AnyscaleSDK
from anyscale.utils.ray_utils import _dir_travel, _get_excludes  # type: ignore

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Bucket name passed to every S3 get/put/head call in this module.
# TODO (yic) set it
S3BUCKET = ""

# A file of at least this size is shipped as its own package.
SINGLE_FILE_MINIMAL = 25 * 1024 * 1024  # 25MB
# When the delta against the base is larger than this, a new base is built.
DELTA_PKG_LIMIT = 100 * 1024 * 1024  # 100MB

# This is the key prefix in s3. It is prepended to a package name to form the
# object key; `setup_token` overwrites it with "<org_id>/".
KEY_PREFIX: str = ""


# Name of the file, stored inside the working directory, that records the hash
# of the current base package and the hash of every entry it contains.
DIR_META_FILE_NAME = ".anyscale_runtime_dir"

# Event loop created at import time; the synchronous entry points use it to
# drive the async S3 helpers through `run_until_complete`.
event_loop = asyncio.new_event_loop()


class PackagePrefix(Enum):
    """Name prefixes of the different package types.

    A working directory is composed of packages as:
        base + [add] - del
    """

    # The full snapshot that deltas are computed against.
    BASE_PREFIX: str = "base_"
    # Entries that are new or changed relative to the base (also used for
    # single big files that get a package of their own).
    ADD_PREFIX: str = "add_"
    # Entries that exist in the base but have been removed.
    DEL_PREFIX: str = "del_"


async def _get_object(  # type: ignore
    s3, key: str,
):
    """Fetch `key` from `S3BUCKET` and return the response's "Body" entry.

    Args:
        s3: An S3 client exposing an awaitable `get_object`.
        key: Object key inside the bucket.
    """
    response = await s3.get_object(Bucket=S3BUCKET, Key=key)
    body = response["Body"]
    return body


async def _put_object(s3, key: str, local_path: str,) -> None:  # type: ignore
    """Upload the file at `local_path` to `S3BUCKET` under `key`.

    Args:
        s3: An S3 client exposing an awaitable `put_object`.
        key: Object key to write inside the bucket.
        local_path: Path of the local file whose bytes are uploaded.
    """
    # Open the file in a context manager so the handle is closed once the
    # upload finishes, including when `put_object` raises.
    with open(local_path, "rb") as body:
        await s3.put_object(Body=body, Bucket=S3BUCKET, Key=key)


async def _object_exists(s3, key: str,) -> bool:  # type: ignore
    """Return whether a `head_object` call for `key` in `S3BUCKET` succeeds.

    Any exception raised by the call (not only "not found") is reported as
    `False`.
    """
    try:
        await s3.head_object(Bucket=S3BUCKET, Key=key)
    except Exception:
        return False
    else:
        return True


def _object_exists_sync(key: str,) -> bool:
    """Blocking wrapper around `_object_exists`.

    Opens a fresh S3 client, runs the check on the module's `event_loop`
    and returns its result.
    """

    async def _check() -> bool:
        client_ctx = aiobotocore.get_session().create_client("s3")
        async with client_ctx as s3:
            return await _object_exists(s3, key)

    return event_loop.run_until_complete(_check())


def _hash_file_contents(local_path: Path, hasher: "hashlib._Hash") -> "hashlib._Hash":
    """Feed the bytes of `local_path` into `hasher` and return `hasher`.

    When `local_path` is not a regular file (for example a directory) the
    hasher is returned untouched.
    """
    if not local_path.is_file():
        return hasher
    chunk_size = 4096 * 1024
    with local_path.open("rb") as stream:
        # Read fixed-size chunks until read() yields the empty-bytes sentinel.
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher


def _entry_hash(local_path: Path, tar_path: Path) -> bytes:
    """Return the md5 digest identifying one entry of a package.

    The digest always covers the string form of `tar_path`. When
    `local_path` is a regular file, the file's bytes are hashed after it, so:
        directory: hash(tar_path)
        file:      hash(tar_path + file content)
    """
    seeded = hashlib.md5(str(tar_path).encode())
    filled = _hash_file_contents(local_path, seeded)
    return filled.digest()


def _xor_bytes(left: Optional[bytes], right: Optional[bytes]) -> Optional[bytes]:
    """Combine two hashes with a byte-wise XOR.

    XOR is commutative, so entry hashes can be folded together in any order
    without first collecting and sorting them. If either side is missing or
    empty, the other side is returned as-is. The result is as long as the
    shorter input.
    """
    if not left or not right:
        return left or right
    return bytes(x ^ y for x, y in zip(left, right))


class _PkgURI(object):
    """This class represents an internal concept of URI.

    An URI is composed of: pkg_type + hash_val. pkg_type is an entry of
    PackagePrefix.

    For example, `add_<content_hash>` or `del_<content_hash>`.

    The purpose of this class is to make the manipulation of URIs easier.
    """

    @staticmethod
    def from_uri(pkg_uri: str) -> "_PkgURI":
        """Build a `_PkgURI` from a string of the form `s3://<prefix><hash>`.

        Args:
            pkg_uri: The URI string; the package name is read from the
                network-location part of the parsed URI.

        Raises:
            ValueError: If the scheme is not `s3` or the name does not start
                with one of the `PackagePrefix` values.
        """
        uri = urlparse(pkg_uri)
        # Validate explicitly instead of with `assert`: asserts are stripped
        # under `python -O`, which would let a malformed URI through.
        if uri.scheme != "s3":
            raise ValueError(f"Unsupported scheme in package URI: {pkg_uri}")
        name = uri.netloc
        for pkg_type in PackagePrefix:
            prefix = pkg_type.value
            if name.startswith(prefix):
                return _PkgURI(pkg_type, name[len(prefix) :])
        raise ValueError(f"Unknown package prefix in package URI: {pkg_uri}")

    def __init__(
        self, pkg_type: Optional[PackagePrefix] = None, hash_val: Optional[str] = None
    ):
        """Constructor of _PkgURI.

        Either both arguments are given, or neither. When neither is given
        no attribute is set on the instance.
        """

        if pkg_type is None and hash_val is None:
            return
        assert pkg_type is not None and hash_val is not None
        self._pkg_type = pkg_type
        self._hash_val = hash_val
        # Right now we only support s3. Hard code it here.
        self._scheme = "s3"

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, _PkgURI):
            return NotImplemented
        return (
            self._scheme == other._scheme
            and self._pkg_type == other._pkg_type
            and self._hash_val == other._hash_val
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone makes instances unhashable; hash the same
        # fields that __eq__ compares.
        return hash((self._scheme, self._pkg_type, self._hash_val))

    def uri(self) -> str:
        """Return the full URI, e.g. `s3://base_<hash>`."""
        return self._scheme + "://" + self.name()

    def s3_path(self) -> str:
        """Return the object key: the current `KEY_PREFIX` plus the name."""
        return KEY_PREFIX + self.name()

    def local_path(self) -> Path:
        """Return where this package lives under ray's package directory."""
        return Path(ray_runtime_env.PKG_DIR) / self.name()

    def name(self) -> str:
        """Return `<prefix><hash>`, e.g. `add_<content_hash>`."""
        assert isinstance(self._pkg_type.value, str)
        return self._pkg_type.value + self._hash_val

    def is_base_pkg(self) -> bool:
        return self._pkg_type == PackagePrefix.BASE_PREFIX

    def is_add_pkg(self) -> bool:
        return self._pkg_type == PackagePrefix.ADD_PREFIX

    def is_del_pkg(self) -> bool:
        return self._pkg_type == PackagePrefix.DEL_PREFIX


class _Pkg:
    """Class represent a package.

    A package is composed by pkg_uri + contents, where contents is a list of
    (path relative to the working directory, entry hash) pairs.
    """

    def __init__(
        self,
        working_dir: Path,
        pkg_type: PackagePrefix,
        hash_val: str,
        contents: Optional[List[Tuple[Path, bytes]]] = None,
    ):
        """Create a package description.

        Args:
            working_dir: Directory the relative content paths are resolved
                against.
            pkg_type: Kind of package (base, add or del).
            hash_val: Hex hash that, with the prefix, forms the package name.
            contents: Entries of the package. Defaults to a new empty list.
        """
        self._uri = _PkgURI(pkg_type, hash_val)
        self._working_dir = working_dir
        # A `[]` default would be one list shared by every instance created
        # without `contents`; build a fresh one per instance instead.
        self._contents = [] if contents is None else contents

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, _Pkg):
            return NotImplemented
        return (
            self._uri == other._uri
            and self._working_dir == other._working_dir
            and set(self._contents) == set(other._contents)
        )

    def create_tar_file(self) -> str:
        """Create a physical package.

        It'll go through contents and put every entry into an uncompressed
        tar file. For a del package only zero-sized placeholder entries are
        written; otherwise the entry is added from the working directory.

        Returns:
            The path to the tar file. The caller needs to clean it up once
            finished using it. If building the tar fails, the partial file is
            removed here before the exception propagates.

        TODO (yic): Support streaming way to update the pkg.
        """
        empty_file = io.BytesIO()
        with tempfile.NamedTemporaryFile(delete=False) as f:
            try:
                with tarfile.open(fileobj=f, mode="w:") as tar:
                    for (to_path, _) in self._contents:
                        file_path = self._working_dir / to_path
                        if self._uri.is_del_pkg():
                            info = tarfile.TarInfo(str(to_path))
                            info.size = 0
                            tar.addfile(info, empty_file)
                        else:
                            tar.add(file_path, to_path)
            except BaseException:
                # delete=False means nobody else will remove the file.
                f.close()
                os.unlink(f.name)
                raise
            return f.name

    def name(self) -> str:
        return self._uri.name()

    def s3_path(self) -> str:
        # Delegate so the key layout is defined in one place only.
        return self._uri.s3_path()

    def uri(self) -> str:
        return self._uri.uri()

    def update_meta(self) -> None:
        """Persist the meta info into disk.

        This function is for base package only. It'll transform the contents
        into meta (package hash plus a path -> hex hash mapping) and write it
        as json to `DIR_META_FILE_NAME` inside the working directory.
        """
        assert self._uri.is_base_pkg()
        files = {str(to_path): hash_val.hex() for (to_path, hash_val) in self._contents}
        meta = {"hash_val": self._uri._hash_val, "files": files}
        with (self._working_dir / DIR_META_FILE_NAME).open("w") as meta_file:
            meta_file.write(json.dumps(meta))


def _read_dir_meta(
    work_dir: Path, _skip_check: bool = False
) -> Optional[Dict[str, Any]]:
    """Read meta from the meta file.

    Meta file is composed by json string. The structure of the file is like this:
    {
       "hash_val": "base_pkg_hash",
       "files": {
         "file1": "hash1",
         "file2": "hash2",
       }
    }

    Args:
        work_dir: Directory that may contain `DIR_META_FILE_NAME`.
        _skip_check: When true, return the parsed meta without checking that
            the base package it names exists in s3.

    Returns:
        The parsed meta, or None when there is no meta file, when the base
        package is not in s3, or when the s3 check itself fails.

    Raises:
        ValueError: If the meta path exists but is not a regular file, is not
            valid json, or is not a dict holding "hash_val" and "files".
    """
    meta_file_path = work_dir / DIR_META_FILE_NAME
    if not meta_file_path.exists():
        return None

    meta = None
    if meta_file_path.is_file():
        try:
            meta = json.loads(meta_file_path.read_text())
        except (OSError, ValueError):
            # Unreadable, undecodable or non-json content: reported below.
            meta = None
    if not isinstance(meta, dict) or "hash_val" not in meta or "files" not in meta:
        raise ValueError(
            f"Invalid meta file: {meta_file_path}. This should be a "
            "file managed by anyscale. The content of the file is broken. "
            "Please consider deleting/moving it and retrying."
        )
    if _skip_check:
        return meta

    pkg_uri = _PkgURI(PackagePrefix.BASE_PREFIX, meta["hash_val"])
    try:
        base_exists = _object_exists_sync(pkg_uri.s3_path())
    except Exception:
        # Best effort: with no usable answer we fall back to "no base",
        # which makes the caller build a new one. Keep the traceback.
        logger.exception(
            "Failed to check the meta existence. Treat it as not existing"
        )
        return None
    return meta if base_exists else None


def _get_base_pkg_from_meta(working_dir: Path, meta: Dict[str, Any]) -> _Pkg:
    """Rebuild the base `_Pkg` described by a parsed meta dict.

    Each "files" item (path string -> hex hash) becomes a (Path, bytes)
    content entry; "hash_val" becomes the package hash.
    """
    contents = [
        (Path(rel_path), bytes.fromhex(hex_digest))
        for rel_path, hex_digest in meta["files"].items()
    ]
    return _Pkg(working_dir, PackagePrefix.BASE_PREFIX, meta["hash_val"], contents)


"""
The following functions are related to package splitting and uploading
"""


def _calc_pkg_for_working_dir(
    working_dir: Path, excludes: List[str], meta: Optional[Dict[str, Any]]
) -> Tuple[Optional[_Pkg], Optional[_Pkg], Optional[_Pkg], List[_Pkg]]:
    """Split the working directory and calculate the pkgs.

    The algorithm will go with this way:
       - create delta if we have base.
       - (or) create the base if no base.
       - if delta is too big (DELTA_PKG_LIMIT, default as 100MB), we update the
         base.

    All big files will be put into a separate package.

    Args:
        working_dir (Path): The working directory to split.
        excludes (List[str]): The pattern to exclude from. It follows gitignore.
            The list is not modified.
        meta (Optional[Dict[str, Any]]): This is the base meta we have.

    Returns:
        Tuple (base_pkg, add_pkg, del_pkg, pkgs) where pkgs holds one package
        per big file. add_pkg and del_pkg are None when a new base is built.
    """
    # Work on a copy: the caller passes the list stored in the job's
    # runtime_env, and appending to it would leak the meta file name into it.
    excludes = list(excludes)
    if DIR_META_FILE_NAME not in excludes:
        excludes.append(DIR_META_FILE_NAME)
    # TODO (yic) Try to avoid calling ray internal API
    exclude_spec = _get_excludes(working_dir, excludes)

    pkgs = []
    files = []
    hash_val = None
    # Entries of the base that have not been seen yet; whatever is left after
    # the traversal has been deleted from the working directory.
    base_files = copy.deepcopy(meta["files"]) if meta is not None else {}
    delta_size = 0

    all_files = []
    all_hash_val = None

    def handler(path: Path) -> None:
        # These are the outputs of the traversal.
        #   hash_val: the hash value of delta package
        #   files: entries of the delta package
        #   all_files: all entries for the new base, used if we re-base
        #   all_hash_val: the hash value of new base
        #   pkgs: big files whose size is at least SINGLE_FILE_MINIMAL
        #   delta_size: the size of the delta
        nonlocal hash_val, all_hash_val, delta_size
        # Only leaves are recorded: files and empty directories.
        if path.is_dir() and next(path.iterdir(), None) is not None:
            return
        to_path = path.relative_to(working_dir)
        file_hash = _entry_hash(path, to_path)
        entry = (to_path, file_hash)
        # If it's a big file, put it into a separate pkg
        if path.is_file() and path.stat().st_size >= SINGLE_FILE_MINIMAL:
            pkg = _Pkg(working_dir, PackagePrefix.ADD_PREFIX, file_hash.hex(), [entry])
            pkgs.append(pkg)
        else:  # If it's an empty directory or just a small file
            # New or changed relative to the base -> part of the delta.
            if base_files.pop(str(to_path), None) != file_hash.hex():
                files.append(entry)
                hash_val = _xor_bytes(hash_val, file_hash)
                delta_size += path.stat().st_size
            # We also put it into all_files in case the delta is too big and we'd need
            # to change the base
            all_files.append(entry)
            all_hash_val = _xor_bytes(all_hash_val, file_hash)

    # Travel the dir with ray runtime env's api
    _dir_travel(working_dir, [exclude_spec] if exclude_spec else [], handler)

    # If there is no base or the delta is too big, we'll update the base
    if meta is None or delta_size > DELTA_PKG_LIMIT:
        base_pkg = None
        if all_hash_val is not None:
            base_pkg = _Pkg(
                working_dir, PackagePrefix.BASE_PREFIX, all_hash_val.hex(), all_files
            )
        return (base_pkg, None, None, pkgs)

    # Otherwise reuse base pkg
    base_pkg = _get_base_pkg_from_meta(working_dir, meta)
    add_pkg = None
    if hash_val is not None:
        add_pkg = _Pkg(working_dir, PackagePrefix.ADD_PREFIX, hash_val.hex(), files)
    # If some entries are still left in base_files, they have been deleted.
    # In this case, we need to generate a del pkg.
    del_pkg = None
    if len(base_files) != 0:
        del_hash_val = None
        del_files = []
        for del_file in base_files:
            # A deleted entry is identified by its path only.
            path_digest = hashlib.md5(del_file.encode()).digest()
            del_hash_val = _xor_bytes(del_hash_val, path_digest)
            del_files.append((Path(del_file), bytes()))
        assert del_hash_val is not None
        del_pkg = _Pkg(
            working_dir, PackagePrefix.DEL_PREFIX, del_hash_val.hex(), del_files
        )
    return (base_pkg, add_pkg, del_pkg, pkgs)


def rewrite_runtime_env_uris(job_config: JobConfig) -> None:
    """Compute the packages for the job's working_dir and record their uris.

    Nothing happens when "uris" is already set or when there is no
    "working_dir". Otherwise "uris" is filled in and the packages themselves
    are stored under "_pkg_contents" for the later upload step.
    """
    env = job_config.runtime_env

    # Preset uris are used as they are.
    if env.get("uris") is not None:
        return
    working_dir = env.get("working_dir")
    excludes = env.get("excludes") or []
    if working_dir is None:
        return

    working_path = Path(working_dir).absolute()
    assert working_path.is_dir()
    meta = _read_dir_meta(Path(working_dir))

    # Split the working dir into base / add / del plus one pkg per big file.
    base_pkg, add_pkg, del_pkg, file_pkgs = _calc_pkg_for_working_dir(
        working_path, excludes, meta
    )

    # Big-file packages come first, followed by whichever of base/add/del exist.
    dir_pkgs = [p for p in (base_pkg, add_pkg, del_pkg) if p is not None]
    env["uris"] = [p.uri() for p in file_pkgs]
    env["uris"].extend(p.uri() for p in dir_pkgs)

    # Keep the pkg objects around for later uploading.
    env["_pkg_contents"] = (base_pkg, add_pkg, del_pkg, file_pkgs)


async def _upload_pkg(session: aiobotocore.AioSession, file_pkg: _Pkg) -> bool:
    """Upload the package if it doesn't exist in s3.

    Returns:
        True if the package was uploaded, False if it was already there.
    """
    async with session.create_client("s3") as s3:
        # The content hash is encoded in the path of the package, so we don't
        # need to compare the contents of the package. Although, there might be
        # hash collision, given that it's using md5 128bits and we also put
        # files from different orgs into different paths, it won't happen in
        # the real world.
        if await _object_exists(s3, file_pkg.s3_path()):
            return False
        local_pkg = file_pkg.create_tar_file()
        try:
            await _put_object(s3, file_pkg.s3_path(), local_pkg)
        finally:
            # The tar file is ours to clean up, also when the upload fails.
            os.unlink(local_pkg)
        return True


async def _upload_file_pkgs(
    session: aiobotocore.AioSession, file_pkgs: List[_Pkg]
) -> None:
    """Upload all `file_pkgs` concurrently.

    Every upload is allowed to finish; afterwards the first failure, if any,
    is re-raised so the caller does not treat a failed upload as done.
    An empty list is a no-op.
    """
    # gather() accepts coroutines directly and an empty argument list, and
    # with return_exceptions=True it hands failures back instead of leaving
    # them unretrieved on the tasks.
    results = await asyncio.gather(
        *(_upload_pkg(session, pkg) for pkg in file_pkgs), return_exceptions=True
    )
    for result in results:
        if isinstance(result, BaseException):
            raise result


def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
    """Upload the packages computed by `rewrite_runtime_env_uris`.

    Nothing is uploaded when there are no uris, when "_skip_uploading" is
    set, or when no "_pkg_contents" was recorded (the uris were preset, so
    there is nothing local to upload). After uploading, the meta file of the
    base package is written to the working directory.
    """
    runtime_env = job_config.runtime_env
    uris = runtime_env.get("uris")
    if not uris:
        return
    pkg_contents = runtime_env.get("_pkg_contents")
    # Preset uris skip the packaging step, so "_pkg_contents" may be absent;
    # treat that like an explicit skip instead of failing with a KeyError.
    if "_skip_uploading" in runtime_env or pkg_contents is None:
        logger.info("Skipping uploading for preset uris")
        return
    base_pkg, add_pkg, del_pkg, file_pkgs = pkg_contents
    # Deltas only make sense on top of a base.
    assert base_pkg is not None or (add_pkg is None and del_pkg is None)
    session = aiobotocore.get_session()
    all_pkgs = [p for p in ([base_pkg, add_pkg, del_pkg] + file_pkgs) if p is not None]
    if len(all_pkgs) != 0:
        event_loop.run_until_complete(_upload_file_pkgs(session, all_pkgs))
    if base_pkg is not None:
        base_pkg.update_meta()


"""
The following functions are related to package downloading and construction
"""


async def _fetch_dir_pkg(session: aiobotocore.AioSession, pkg_uri: _PkgURI) -> None:
    """Download and unpack one package unless it is already present locally.

    The package is extracted into `pkg_uri.local_path()`. A file lock next
    to that path serializes concurrent fetches of the same package. If the
    download or extraction fails, the partially created directory is removed
    so that a later call retries instead of seeing it as already fetched.
    """
    local_path = pkg_uri.local_path()
    with FileLock(str(local_path) + ".lock"):
        if local_path.exists():
            assert local_path.is_dir()
            return
        async with session.create_client("s3") as s3:
            local_path.mkdir()
            try:
                streambody = await _get_object(s3, pkg_uri.s3_path())
                # TODO (yic) Use streaming mode instead of downloading everything
                with tempfile.NamedTemporaryFile() as tmp_tar:
                    async for data in streambody.iter_chunks():
                        tmp_tar.write(data)
                    tmp_tar.flush()
                    with tarfile.open(tmp_tar.name, mode="r:*") as tar:
                        # NOTE(review): extractall trusts the member paths in
                        # the archive; confirm packages only ever come from
                        # our own uploads.
                        tar.extractall(local_path)
            except BaseException:
                # The existence of local_path is what marks a package as
                # fetched, so it must not survive a failed fetch.
                shutil.rmtree(local_path, ignore_errors=True)
                raise


async def _fetch_uris(session: aiobotocore.AioSession, pkg_uris: List[_PkgURI]) -> None:
    """Fetch all `pkg_uris` concurrently.

    Every fetch is allowed to finish; afterwards the first failure, if any,
    is re-raised so the caller does not build a working directory from
    packages that were never downloaded. An empty list is a no-op.
    """
    # gather() accepts coroutines directly and an empty argument list, and
    # with return_exceptions=True it hands failures back instead of leaving
    # them unretrieved on the tasks.
    results = await asyncio.gather(
        *(_fetch_dir_pkg(session, pkg_uri) for pkg_uri in pkg_uris),
        return_exceptions=True,
    )
    for result in results:
        if isinstance(result, BaseException):
            raise result


def _is_fs_leaf(path: Path) -> bool:
    """Return True for a regular file or an empty directory."""
    if path.is_file():
        return True
    if not path.is_dir():
        return False
    # A directory is a leaf only when it has no child at all.
    for _child in path.iterdir():
        return False
    return True


def _link_fs_children(from_path: Path, to_path: Path) -> None:
    """Create, inside `to_path`, one symlink per direct child of `from_path`.

    Both arguments must be existing directories.
    """
    assert from_path.is_dir() and to_path.is_dir()
    for child in from_path.glob("*"):
        link = to_path.joinpath(child.name)
        link.symlink_to(child)


def _merge_del(working_dir: Path, del_path: Path) -> None:
    """Walk `del_path` recursively and remove its leaves from `working_dir`.

    `working_dir` must be a real directory (not a symlink), so removing
    entries from it affects only this working directory.
    """
    assert working_dir.is_dir() and not working_dir.is_symlink()
    for marker in del_path.glob("*"):
        target = working_dir / marker.name
        # A leaf in the del package names an entry to drop.
        if _is_fs_leaf(marker):
            target.unlink()
            continue
        # A symlinked directory points at shared resources. For isolation,
        # swap it for a real directory whose children link to the physical
        # ones before touching anything below it.
        if target.is_symlink():
            shared_dir = target.resolve()
            target.unlink()
            target.mkdir()
            _link_fs_children(shared_dir, target)
        # One level deeper: target is working_dir/<name>, marker is
        # del_path/<name>.
        _merge_del(target, marker)


def _merge_add(working_dir: Path, delta_path: Path) -> None:
    """Recursively iterate through delta_path and merge it to working_dir.

    Entries missing from `working_dir` are symlinked in. Existing files are
    replaced by a link to the delta entry, and existing directories are
    merged child by child. `working_dir` must be a real directory (not a
    symlink), so changes are only visible to the current job.
    """
    assert working_dir.is_dir() and not working_dir.is_symlink()
    for f in delta_path.glob("*"):
        to_path = working_dir / f.name
        # We link it to the target directly if the target doesn't exist
        if not to_path.exists():
            to_path.symlink_to(f)
            continue
        # If the target exist, it means we might need to overwrite it
        if to_path.is_file() or (to_path.is_dir() and f.is_file()):
            if to_path.is_dir() and not to_path.is_symlink():
                # A real directory here is one an earlier merge step created
                # to isolate a shared dir. unlink() cannot remove a
                # directory; rmtree removes it and any links inside without
                # following them into the shared packages.
                shutil.rmtree(to_path)
            else:
                to_path.unlink()
            to_path.symlink_to(f)
            continue
        # Both sides are directories. If the target is a symlink, we need to
        # create a folder and link all the children to this new created one.
        if to_path.is_symlink():
            true_path = to_path.resolve()
            to_path.unlink()
            to_path.mkdir()
            _link_fs_children(true_path, to_path)
        _merge_add(to_path, f)


def _construct_from_uris(pkg_uris: List[_PkgURI]) -> Path:
    """Build the working directory out of already-fetched packages.

    At most one base and one del package are allowed, plus any number of add
    packages. Returns the path of the new working directory.
    """
    # Sort the uris into base / adds / del.
    base_uri = None
    add_uris = []
    del_uri = None
    for candidate in pkg_uris:
        if candidate.is_base_pkg():
            assert base_uri is None
            base_uri = candidate
        elif candidate.is_add_pkg():
            add_uris.append(candidate)
        elif candidate.is_del_pkg():
            assert del_uri is None
            del_uri = candidate
        else:
            assert False

    # Reserve a unique name in the package dir; the placeholder file itself
    # is removed right away and replaced by a symlink or a directory.
    placeholder = tempfile.NamedTemporaryFile(
        prefix="_ray_working_dir_", delete=False, dir=ray_runtime_env.PKG_DIR
    )
    placeholder.close()
    working_dir = Path(placeholder.name)
    working_dir.unlink()

    if base_uri is not None and not add_uris and del_uri is None:
        # Nothing to merge: point straight at the base package.
        working_dir.symlink_to(base_uri.local_path())
    else:
        # Otherwise start from a real dir whose children link into the base.
        working_dir.mkdir()
        if base_uri is not None:
            _link_fs_children(base_uri.local_path(), working_dir)

    # Apply deletions first, then every add package.
    if del_uri is not None:
        _merge_del(working_dir, del_uri.local_path())
    for add_uri in add_uris:
        _merge_add(working_dir, add_uri.local_path())
    return working_dir


def ensure_runtime_env_setup(uris: List[str]) -> Optional[str]:
    """Fetch the packages named by `uris` and assemble the working directory.

    Returns None for an empty list. Otherwise the assembled directory is put
    at the front of `sys.path` and its path is returned.
    """
    if len(uris) == 0:
        return None
    parsed = [_PkgURI.from_uri(raw) for raw in uris]
    event_loop.run_until_complete(_fetch_uris(aiobotocore.get_session(), parsed))
    working_dir = str(_construct_from_uris(parsed))
    sys.path.insert(0, working_dir)
    return working_dir


"""
The following functions are Anyscale-related
"""


def register_runtime_env(env: Dict[str, Any]) -> List[str]:
    """Package and upload the runtime env `env`, returning its uris."""
    config = JobConfig(runtime_env=env)
    # First compute the packages, then upload whatever is missing.
    for step in (rewrite_runtime_env_uris, upload_runtime_env_package_if_needed):
        step(config)
    registered = config.runtime_env["uris"]
    assert isinstance(registered, list)
    return registered


def setup_token() -> None:
    """Fetch temporary AWS credentials for the user's organization.

    Looks up the first organization id of the current user, requests
    temporary credentials for it, sets `KEY_PREFIX` to "<org_id>/" and
    exports the credentials through the `AWS_*` environment variables.
    """
    global KEY_PREFIX

    sdk = AnyscaleSDK(
        auth_token=load_credentials(), host=f"{anyscale.conf.ANYSCALE_HOST}/ext"
    )

    # One client is enough; the second `get_api_client()` call that used to
    # follow the lookup produced a client that was never used.
    api_client = get_api_client()
    org_id = api_client.get_user_info_api_v2_userinfo_get().result.organization_ids[0]
    token = sdk.get_organization_temporary_credentials(org_id).result

    KEY_PREFIX = org_id + "/"

    os.environ["AWS_ACCESS_KEY_ID"] = token.aws_access_key_id
    os.environ["AWS_SECRET_ACCESS_KEY"] = token.aws_secret_access_key
    os.environ["AWS_SESSION_TOKEN"] = token.aws_session_token


def runtime_env_setup() -> None:
    """Set up credentials and install this module's hooks into ray.

    The three functions replace the attributes of the same name on ray's
    runtime_env module.
    """
    setup_token()
    hooks = (
        rewrite_runtime_env_uris,
        upload_runtime_env_package_if_needed,
        ensure_runtime_env_setup,
    )
    for hook in hooks:
        # Each hook is installed under its own function name.
        setattr(ray_runtime_env, hook.__name__, hook)
