from __future__ import annotations

import contextlib
import datetime
import json
import logging
import shutil
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Awaitable,
    BinaryIO,
    Callable,
    Optional,
    Type,
    TypeVar,
    Union,
    overload,
)

import aiohttp
import backoff
import sentry_sdk
from aiohttp import FormData, WSMessage
from pydantic import BaseModel, ValidationError, parse_obj_as

from .. import schemas
from ..schemas import AuthTokenUnion, Hyperparameter, MachineSize
from ..shared.utils import get_data_or_raise
from . import config
from .errors import (
    SlingshotClientHttpException,
    SlingshotConnectionError,
    SlingshotException,
    SlingshotJWSInvalidSignature,
    SlingshotJWTExpiredError,
    SlingshotUnauthenticatedError,
)
from .graphql import BaseGraphQLQuery, base_graphql, fragments, queries
from .graphql.queries import (
    ProjectSecretsQuery,
    ProjectSecretsResponse,
    RunByIdResponse,
    ServiceAccountWithProjectsResponse,
)
from .utils import gql_mount_spec_to_read_mount_spec

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# A JSON object request body (the type of ``make_request``'s ``json_data``).
JSONType = dict[str, Any]
# Query-string parameters (the type of ``make_request``'s ``params``).
ParamsType = dict[str, Union[str, float, int]]
# Result type produced by ``make_request`` / the GraphQL helpers for a given ``response_model``.
T = TypeVar(
    "T",
    bound=Union[BaseModel, str, bytes, tuple[Any, ...], list[Any], dict[str, Any]],
)


class Retry(Exception):
    """Plain ``Exception`` subclass with no extra behavior.

    Not raised within this section of the module; presumably used as a retry signal elsewhere. TODO confirm.
    """


class SlingshotClient:
    """Low-level async client for the Slingshot backend's REST (``/api``) and GraphQL endpoints.

    Builds request URLs, attaches auth headers, optionally reuses one ``aiohttp`` session (see
    ``use_http_session``), maps HTTP and connection failures to Slingshot exceptions, and parses responses with
    pydantic. Endpoint-specific helpers are built on top of this class by ``SlingshotAPI``.
    """

    def __init__(
        self,
        *,
        auth_token: AuthTokenUnion | None = None,
        slingshot_url: str = config.global_config.slingshot_backend_url,
        hasura_admin_secret: str | None = config.global_config.hasura_admin_secret,
        auto_setup_hook: Callable[[], Awaitable[None]] | None = None,
    ) -> None:
        """
        :param auth_token: If set, sent as the ``token`` cookie on every request.
        :param slingshot_url: Backend base URL; the ``/api`` and ``/graphql/v1/graphql`` URLs are derived from it.
            NOTE: this default and the ``hasura_admin_secret`` default are read from ``config.global_config`` once,
            when this module is imported.
        :param hasura_admin_secret: If set, sent as ``x-hasura-admin-secret`` on GraphQL requests.
        :param auto_setup_hook: Awaited before the first request made with ``_setup=True``; once it has completed
            successfully it is not run again.
        """
        self._slingshot_url = slingshot_url.rstrip("/")
        self._api_url = self._slingshot_url + "/api"
        self._graphql_url = self._slingshot_url + "/graphql/v1/graphql"
        self.auth_token: AuthTokenUnion | None = auth_token
        self._hasura_admin_secret = hasura_admin_secret

        # Set while inside `use_http_session`; passed to `_maybe_make_http_session` (defined outside this class),
        # which presumably opens a fresh session per request when this is None.
        self._session: aiohttp.ClientSession | None = None
        self._auto_setup_hook = auto_setup_hook
        self._is_setup = not auto_setup_hook  # False, unless auto_setup_hook is None
        # Not read or written elsewhere in this class; presumably populated by callers. TODO confirm.
        self.project: schemas.Project | None = None

    @property
    def _headers(self) -> dict[str, str]:
        """Default headers for every request. A new dict is built on each access, so callers may mutate it."""
        if self.auth_token is None:
            return {}
        # We need to set Cookie since Hasura only uses this header for auth
        return {"Cookie": f"token={self.auth_token.token}"}

    @contextlib.asynccontextmanager
    async def use_http_session(self) -> AsyncGenerator[SlingshotClient, None]:
        """Optional: Use this to reuse a session across multiple requests.

        On exit the previously active session (``None`` outside any ``use_http_session`` block) is restored, so
        later requests do not keep using a session that ``_maybe_make_http_session`` may already have closed.
        """
        previous_session = self._session
        async with _maybe_make_http_session(self._session) as session:
            self._session = session
            try:
                yield self
            finally:
                self._session = previous_session

    async def _maybe_setup(self) -> None:
        """Await the auto-setup hook once, unless setup already happened or no hook was given."""
        if self._is_setup or self._auto_setup_hook is None:
            return
        await self._auto_setup_hook()
        # Mark setup as done only after the hook succeeded, so a failed setup is retried on the next request
        # instead of the hook running again before every single request.
        self._is_setup = True

    def _report_exception(self, exc: BaseException, **extras: Any) -> None:
        """Send ``exc`` to Sentry with ``extras`` attached, unless this client talks to the local backend."""
        if self._slingshot_url == config.global_config.slingshot_local_url:
            return
        # Extras must be set on a scope: `capture_exception` only accepts known scope kwargs (user, tags, extras,
        # ...), and an unknown one such as `api_url=` fails inside the SDK, so the event is never sent.
        with sentry_sdk.push_scope() as scope:
            for key, value in extras.items():
                scope.set_extra(key, value)
            sentry_sdk.capture_exception(exc)

    async def make_request(
        self,
        url: str,
        *,
        method: str,
        response_model: Type[T] | None,
        params: ParamsType | None = None,
        json_data: JSONType | None = None,
        data: dict[str, Any] | FormData | BinaryIO | bytes | None = None,
        headers: dict[str, str] | None = None,
        timeout: datetime.timedelta | None = None,
        _setup: bool = True,
    ) -> T:
        """Send an HTTP request to the backend and decode the response.

        :param url: Absolute URL, or a path relative to the ``/api`` base URL (anything not starting with ``http``).
        :param method: HTTP method name, e.g. ``"get"`` or ``"post"``.
        :param response_model: How to decode the body. ``aiohttp.ClientResponse`` returns the response object
            itself (NOTE: returned after its ``async with`` has exited, so the body may no longer be readable);
            ``None`` returns the body as text; ``bytes`` returns the raw body; ``str`` returns the body decoded as
            JSON; any other type is parsed from the JSON body with ``pydantic.parse_obj_as``.
        :param params: Query-string parameters.
        :param json_data: JSON request body.
        :param data: Form / binary request body.
        :param headers: Extra headers; they override the default auth headers on key collisions.
        :param timeout: Total request timeout; defaults to 60 seconds.
        :param _setup: Whether to run the pending auto-setup hook before sending.
        :raises SlingshotClientHttpException: For responses with a status between 400 and 600.
        :raises SlingshotConnectionError: If the backend cannot be reached.
        :raises SlingshotException: If the JSON body does not match ``response_model``.
        """
        # If the url is relative, we need to prepend the base url
        if not url.startswith("http"):
            url = f"{self._api_url.rstrip('/')}/{url.lstrip('/')}"
        if _setup:
            await self._maybe_setup()

        timeout = timeout or datetime.timedelta(seconds=60)
        # The order matters here, so the caller can override headers
        headers = {**self._headers, **(headers or {})}
        logger.debug(f"Making a '{method}' request to '{url}'")
        try:
            async with _maybe_make_http_session(self._session) as session:
                async with session.request(
                    url=url,
                    method=method,
                    params=params,
                    json=json_data,
                    data=data,
                    headers=headers,
                    # Keep fractional seconds: int() truncation turned sub-second timeouts into 0, which aiohttp
                    # treats as "no timeout".
                    timeout=aiohttp.ClientTimeout(total=timeout.total_seconds()),
                    max_redirects=0,
                ) as resp:
                    logger.debug(f"Got response from '{url}': {resp.status}")
                    if 400 <= resp.status <= 600:
                        exception = await SlingshotClientHttpException.from_response(resp)
                        if 500 <= resp.status <= 600:
                            self._report_exception(exception)
                        raise exception
                    if response_model == aiohttp.ClientResponse:
                        return resp
                    elif response_model is None:
                        return await resp.text()
                    elif response_model == bytes:
                        return await resp.content.read()
                    elif response_model == str:
                        return await resp.json()
                    resp_data: dict[str, Any] | FormData | None = await resp.json()
        except aiohttp.ClientConnectorError as e:
            self._report_exception(e, api_url=self._api_url)
            raise SlingshotConnectionError(self._api_url) from e
        try:
            return parse_obj_as(response_model, resp_data)
        except ValidationError as e:
            self._report_exception(e, response=resp_data, response_model=response_model, api_url=self._api_url)
            raise SlingshotException(f"Unexpected response format. {e}: {resp_data}") from e

    async def make_graphql_request(
        self, gql_query: BaseGraphQLQuery[T], _setup: bool = True
    ) -> base_graphql.GraphQLResponse[T]:
        """POST ``gql_query`` to the GraphQL endpoint and parse the reply as ``GraphQLResponse[response_model]``.

        Errors reported by the server in the response body end up in the parsed response's ``errors``. HTTP and
        connection failures raise exactly as in ``make_request``.
        """
        if _setup:
            await self._maybe_setup()
        query = gql_query.query
        variables = gql_query.variables
        response_model = gql_query.response_model
        type_: base_graphql.GraphQLResponse[T] = base_graphql.GraphQLResponse[response_model]  # type: ignore
        headers = self._headers  # fresh dict from the property, so adding a key here is safe
        if self._hasura_admin_secret:
            headers["x-hasura-admin-secret"] = self._hasura_admin_secret
        logger.debug(f"Making query request to {query.strip()} with variables {variables}")
        return await self.make_request(
            self._graphql_url,
            method="post",
            _setup=_setup,
            response_model=type_,
            json_data={"query": query, "variables": variables},
            headers=headers,
        )

    # noinspection PyUnresolvedReferences
    async def make_graphql_subscription_request(
        self, gql_query: BaseGraphQLQuery[T]
    ) -> AsyncGenerator[base_graphql.GraphQLSubscriptionResponse[T], None]:
        """Stream a GraphQL subscription over a websocket, reconnecting whenever the socket closes.

        Yields the parsed ``payload`` of every ``"data"`` message. The generator never finishes on its own; stop
        iterating it to end the subscription.

        :raises SlingshotException: On a ``connection_error`` message, a message with top-level ``errors``, or a
            websocket error frame.
        """
        query = gql_query.query
        variables = gql_query.variables
        response_model = gql_query.response_model
        headers = self._headers  # fresh dict from the property, so adding a key here is safe
        if self._hasura_admin_secret:
            headers["x-hasura-admin-secret"] = self._hasura_admin_secret
        # Only log the first line of the query; an empty query must not raise IndexError here.
        first_query_line = next(iter(query.strip().splitlines()), "")
        logger.debug(f"Making query request to {first_query_line}")
        type_: base_graphql.GraphQLSubscriptionResponse[T] = base_graphql.GraphQLSubscriptionResponse[response_model]  # type: ignore
        while True:
            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(self._graphql_url) as ws:
                    await ws.send_json({"type": "connection_init", "payload": {"headers": headers}})
                    await ws.send_json(
                        {"id": "1", "type": "start", "payload": {"query": query, "variables": variables}}
                    )
                    msg: WSMessage
                    async for msg in ws:
                        logger.debug(f"Received message: {msg}")
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            data = json.loads(msg.data)
                            if data["type"] == "connection_error":
                                raise SlingshotException(f"WebSocket connection error: {data['payload']}")
                            elif "errors" in data:
                                raise SlingshotException(f"GraphQL error: {data['errors']}")
                            elif data["type"] == "data":
                                yield parse_obj_as(type_, data).payload
                            else:
                                logger.debug(f"Received unexpected message: {data}")
                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            raise SlingshotException(f"WebSocket error: {msg.data}")

            # The `async for` above ends when the socket closes; loop around and open a new connection.
            logger.info("Websocket connection closed -- retrying")

    @contextlib.asynccontextmanager
    async def async_http_get(self, url: str) -> AsyncIterator[aiohttp.ClientResponse]:
        """
        Make an async HTTP GET request to the given URL. This is useful for streaming downloads for large files.

        No default auth headers are added. Raises ``SlingshotClientHttpException`` for statuses 400-600.
        """
        async with _maybe_make_http_session(self._session) as session:
            async with session.get(url) as resp:
                if 400 <= resp.status <= 600:
                    raise await SlingshotClientHttpException.from_response(resp)
                yield resp


class SlingshotAPI:
    def __init__(self, client: SlingshotClient) -> None:
        self._client = client

    @contextlib.asynccontextmanager
    async def use_http_session(self) -> AsyncGenerator[SlingshotAPI, None]:
        """Optional: Use this to reuse a session across multiple requests."""
        async with self._client.use_http_session():
            yield self

    """
    Misc API methods
    """

    async def get_backend_version(self) -> str:
        """Get the current version of the backend."""
        return await self._client.make_request("version", method="get", response_model=str, _setup=False)

    async def list_machine_types(self) -> list[schemas.MachineTypeListItem]:
        """Return every machine type listed by the backend's ``machine_types`` endpoint."""
        item_list_model = list[schemas.MachineTypeListItem]
        return await self._client.make_request("machine_types", method="get", response_model=item_list_model)

    """
    Auth API methods
    """

    async def user_login(self, auth0_token: str) -> schemas.AuthToken:
        """Exchange an Auth0 CLI token for a Slingshot auth token.

        :param auth0_token: Token from the Auth0 CLI login flow; posted to ``auth/token`` with ``cli=True``.
        :returns: The auth token extracted from the response by ``get_data_or_raise``.
        """
        request_body = {"token": auth0_token, "cli": True}
        token_response: schemas.AuthTokenResponse = await self._client.make_request(
            "auth/token",
            method="post",
            response_model=schemas.AuthTokenResponse,
            _setup=False,
            json_data=request_body,
        )
        token = get_data_or_raise(token_response)
        logger.info("Signed in successfully")
        return token

    async def sa_login(self, api_key: str) -> schemas.ServiceAccountToken:
        """Sign in as a service account using a Slingshot project API key.

        :param api_key: The project API key, sent in the ``token`` request header.
        :returns: The service account token extracted from the backend response.
        :raises SlingshotException: If the backend answers 401 with an ``error`` message (usually an invalid key).
        :raises SlingshotClientHttpException: For any other HTTP error response.
        """
        try:
            sa_token_resp: schemas.Response[schemas.ServiceAccountToken] = await self._client.make_request(
                "/auth/service_account/token",
                method="post",
                _setup=False,
                response_model=schemas.Response[schemas.ServiceAccountToken],
                headers={"token": api_key},
            )
        except SlingshotClientHttpException as e:
            if e.status == 401 and (msg := e.json and e.json.get("error")):
                # Usually this means the API key is invalid; keep the HTTP error as the cause for debugging.
                raise SlingshotException(msg) from e
            raise
        return get_data_or_raise(sa_token_resp)

    async def get_auth0_cli_metadata(self) -> schemas.Auth0MetadataResponse:
        """Fetch the Auth0 CLI metadata from ``auth/auth0_cli`` without triggering auto-setup."""
        metadata_model = schemas.Auth0MetadataResponse
        return await self._client.make_request(
            url="auth/auth0_cli",
            method="get",
            response_model=metadata_model,
            _setup=False,
        )

    """
    User API methods
    """

    async def me_user(self, user_id: str) -> fragments.UserWithProjects:
        """Fetch the user with ``user_id`` (with its projects) via GraphQL.

        :raises SlingshotJWTExpiredError | SlingshotUnauthenticatedError | SlingshotJWSInvalidSignature: When the
            first GraphQL error message contains that exception's ``graphql_message`` (checked in this order).
        :raises SlingshotException: For any other GraphQL error, or when no user is returned.
        """
        response = await self._client.make_graphql_request(queries.UserWithProjectsQuery(user_id=user_id))
        if response.errors:
            first_message = response.errors[0].message
            for known_error in (SlingshotJWTExpiredError, SlingshotUnauthenticatedError, SlingshotJWSInvalidSignature):
                if known_error.graphql_message in first_message:
                    raise known_error()
            raise SlingshotException(first_message)
        user = response.data.users_by_pk if response.data else None
        if not user:
            raise SlingshotException("No user found with given id")
        return user

    async def me_service_account(self, service_account_id: str) -> fragments.ServiceAccountWithProjects:
        """Fetch the service account with ``service_account_id`` (with its projects) via GraphQL.

        :raises SlingshotJWTExpiredError | SlingshotUnauthenticatedError | SlingshotJWSInvalidSignature: When the
            first GraphQL error message contains that exception's ``graphql_message`` (checked in this order).
        :raises SlingshotException: For any other GraphQL error, or when no service account is returned.
        """
        query = queries.ServiceAccountWithProjectsQuery(service_account_id=service_account_id)
        result: base_graphql.GraphQLResponse[ServiceAccountWithProjectsResponse] = (
            await self._client.make_graphql_request(query)
        )
        if result.errors:
            error_text = result.errors[0].message
            for auth_error in (SlingshotJWTExpiredError, SlingshotUnauthenticatedError, SlingshotJWSInvalidSignature):
                if auth_error.graphql_message in error_text:
                    raise auth_error()
            raise SlingshotException(error_text)
        service_account = result.data.service_accounts_by_pk if result.data else None
        if not service_account:
            raise SlingshotException("Service account not found")
        return service_account

    async def update_ssh_public_key(self, key: str) -> None:
        """Set the current user's SSH public key to ``key`` (PUT ``user/me/ssh_public_key``)."""
        payload = {"ssh_public_key": key}
        await self._client.make_request(
            url="user/me/ssh_public_key", method="put", json_data=payload, response_model=schemas.ResponseOK
        )

    """
    Get API methods
    """

    async def get_billing_line_items_by_app_id(self, app_instance_id: str) -> list[fragments.BillingLineItem]:
        """Return the billing line items for the app instance ``app_instance_id``; GraphQL errors are raised."""
        response = await self._client.make_graphql_request(
            queries.BillingLineItemsByAppIdQuery(app_instance_id=app_instance_id)
        )
        return response.get_data_or_raise().billingLineItems

    async def get_billing_line_items_by_deployment_id(self, deployment_id: str) -> list[fragments.BillingLineItem]:
        """Return the billing line items for the deployment ``deployment_id``; GraphQL errors are raised."""
        response = await self._client.make_graphql_request(
            queries.BillingLineItemsByDeploymentIdQuery(deployment_id=deployment_id)
        )
        return response.get_data_or_raise().billingLineItems

    async def get_billing_line_items_by_run_id(self, run_id: str) -> list[fragments.BillingLineItem]:
        """Return the billing line items for the run ``run_id``; GraphQL errors are raised."""
        response = await self._client.make_graphql_request(queries.BillingLineItemsByRunIdQuery(run_id=run_id))
        return response.get_data_or_raise().billingLineItems

    async def get_project_by_id(self, project_id: str, *, _setup: bool = True) -> fragments.ProjectFields | None:
        """Look up a project by id.

        Returns ``None`` when the GraphQL response carries no data. This method does not inspect the response's
        ``errors``. ``_setup`` is forwarded to the client to control whether the auto-setup hook may run.
        """
        response = await self._client.make_graphql_request(
            queries.ProjectByIdQuery(project_id=project_id), _setup=_setup
        )
        return response.data.projects_by_pk if response.data else None

    async def get_latest_app_instance_for_app_spec(self, app_spec_id: str) -> fragments.AppInstance | None:
        """Return the first app instance the "latest app instance" query yields for ``app_spec_id``, or ``None``."""
        response = await self._client.make_graphql_request(
            queries.LatestAppInstanceForAppSpecQuery(app_spec_id=app_spec_id)
        )
        data = response.get_data_or_raise()
        if data and data.app_instances:
            return data.app_instances[0]
        return None

    async def get_app_instance(self, app_instance_id: str, project_id: str) -> fragments.AppInstance | None:
        """Return the app instance ``app_instance_id`` within ``project_id``, or ``None`` if the query finds none."""
        response = await self._client.make_graphql_request(
            queries.AppInstanceQuery(app_instance_id=app_instance_id, project_id=project_id)
        )
        data = response.get_data_or_raise()
        instances = data.app_instances if data else None
        first_instance = instances[0] if instances else None
        return first_instance or None

    async def get_app_spec_by_id(self, app_spec_id: str, project_id: str) -> fragments.AppSpec | None:
        """Get an app spec by id.

        :returns: The matching app spec, or ``None`` if the project has no app spec with this id.
        :raises SlingshotException: If more than one app spec is returned for the id.
        """
        query = queries.AppSpecByIdQuery(app_spec_id=app_spec_id, project_id=project_id)
        resp = await self._client.make_graphql_request(query)
        data = resp.get_data_or_raise()
        # "Not found" is an empty list: honour the `| None` contract instead of asserting (asserts also vanish
        # under `python -O`, which turned this into an IndexError).
        if not data or not data.app_specs:
            return None
        if len(data.app_specs) > 1:
            raise SlingshotException(f"Found more than one app spec with id {app_spec_id}")
        return data.app_specs[0]

    async def get_app_spec_by_name(self, app_spec_name: str, project_id: str) -> fragments.AppSpec | None:
        """Return the first app spec named ``app_spec_name`` in ``project_id``, or ``None`` if none matches."""
        response = await self._client.make_graphql_request(
            queries.AppSpecByNameQuery(app_spec_name=app_spec_name, project_id=project_id)
        )
        data = response.get_data_or_raise()
        # No matching app spec for this name means data.app_specs is an empty list.
        specs = data.app_specs if data else None
        first_spec = specs[0] if specs else None
        return first_spec or None

    @overload
    async def get_run(self, *, run_id: str, project_id: str) -> fragments.Run | None:
        """Typing-only overload: look a run up by its id."""

    @overload
    async def get_run(self, *, run_name: str, project_id: str) -> fragments.Run | None:
        """Typing-only overload: look a run up by its name within ``project_id``."""

    async def get_run(
        self, *, run_id: str | None = None, run_name: str | None = None, project_id: str
    ) -> fragments.Run | None:
        """Get a run by id or name."""
        if run_id is None and run_name is None:
            raise ValueError("Either run_id or run_name must be specified")
        if run_id is not None and run_name is not None:
            raise ValueError("Only one of run_id or run_name can be specified")
        if run_id is not None:
            query_by_id: BaseGraphQLQuery[RunByIdResponse] = queries.RunByIdQuery(run_id=run_id)
            resp_by_id = await self._client.make_graphql_request(query_by_id)
            data_by_id: queries.RunByIdResponse = resp_by_id.get_data_or_raise()
            return data_by_id.run
        if run_name is not None:
            query_by_name = queries.RunByNameForProjectQuery(run_name=run_name, project_id=project_id)
            resp_by_name = await self._client.make_graphql_request(query_by_name)
            data_by_name: queries.RunsForProjectResponse = resp_by_name.get_data_or_raise()
            runs = data_by_name.runs
            if len(runs) == 0:
                return None
            if len(runs) > 1:
                raise SlingshotException(f"Found more than one run with name {run_name}")
            return runs[0]
        return None

    async def get_deployment(self, deployment_name: str, *, project_id: str) -> fragments.AppSpec | None:
        """Return the first deployment spec named ``deployment_name`` in ``project_id``, or ``None``."""
        response = await self._client.make_graphql_request(
            queries.DeploymentSpecByNameQuery(app_spec_name=deployment_name, project_id=project_id)
        )
        specs_data: queries.AppSpecsResponse = response.get_data_or_raise()
        specs = specs_data.app_specs
        return specs[0] if specs else None

    async def get_deployment_latencies(
        self, deployment_id: str, *, project_id: str
    ) -> schemas.UsageBinsLatencyQuantiles:
        """Fetch latency quantiles for a deployment.

        :raises SlingshotException: If the response carries an ``error``.
        """
        latencies_url = f"project/{project_id}/deploy/{deployment_id}/latencies"
        response = await self._client.make_request(
            url=latencies_url, method="get", response_model=schemas.UsageBinsLatencyQuantilesResponse
        )
        error_message = response.error
        if error_message:
            raise SlingshotException(error_message)
        return get_data_or_raise(response)

    async def get_environment_spec(
        self, execution_environment_spec_id: str, *, project_id: str
    ) -> fragments.ExecutionEnvironmentSpec | None:
        """Fetch an execution environment spec by id. ``project_id`` is currently not used by the lookup."""
        spec_query = queries.ExecutionEnvironmentSpecByIdQuery(execution_environment_spec_id=execution_environment_spec_id)
        spec_data = (await self._client.make_graphql_request(spec_query)).get_data_or_raise()
        if not spec_data:
            return spec_data
        return spec_data.execution_environment_specs_by_pk

    async def get_blob_artifact_by_id(self, blob_artifact_id: str) -> fragments.BlobArtifact | None:
        """Fetch a blob artifact by primary key; GraphQL errors are raised."""
        response = await self._client.make_graphql_request(
            queries.BlobArtifactByIdQuery(blob_artifact_id=blob_artifact_id)
        )
        return response.get_data_or_raise().blob_artifacts_by_pk

    async def get_blob_artifact_by_name(
        self, blob_artifact_name: str, *, project_id: str
    ) -> fragments.BlobArtifact | None:
        """Fetch a blob artifact by name within ``project_id``; returns the first match, or ``None``."""
        response = await self._client.make_graphql_request(
            queries.BlobArtifactByNameQuery(blob_artifact_name=blob_artifact_name, project_id=project_id)
        )
        artifacts = response.get_data_or_raise().blob_artifacts
        return artifacts[0] if artifacts else None

    async def get_exec_env(self, exec_env_id: str) -> fragments.ExecutionEnvironment | None:
        """Fetch an execution environment by id; any falsy result is normalised to ``None``."""
        response = await self._client.make_graphql_request(queries.ExecutionEnvironmentByIdQuery(exec_env_id=exec_env_id))
        return response.get_data_or_raise().execution_environments_by_pk or None

    async def get_latest_source_codes_for_project(self, project_id: str) -> fragments.SourceCodeArtifact | None:
        """Return the first source code artifact the "latest source code" query yields for the project, or ``None``."""
        response = await self._client.make_graphql_request(queries.LatestSourceCodeForProjectQuery(project_id=project_id))
        data = response.get_data_or_raise()
        project = data.projects_by_pk if data else None
        source_codes = project.source_codes if project else None
        return source_codes[0] if source_codes else None

    """
    List API methods
    """

    async def list_app_specs(self, project_id: str) -> list[fragments.AppSpec]:
        """Return every app spec of ``project_id``; GraphQL errors are raised."""
        response = await self._client.make_graphql_request(queries.AppSpecsForProjectQuery(project_id=project_id))
        return response.get_data_or_raise().app_specs

    async def list_runs(self, project_id: str) -> list[fragments.Run]:
        """Return every run of ``project_id`` (an empty list when the query returns none)."""
        response = await self._client.make_graphql_request(queries.RunsForProjectQuery(project_id=project_id))
        data = response.get_data_or_raise()
        if not data or not data.runs:
            return []
        return data.runs

    async def list_app_instances_by_type(
        self, app_type: schemas.AppType, *, project_id: str
    ) -> list[fragments.AppInstance]:
        """Return the project's app instances whose type is ``app_type`` (an empty list when none)."""
        instances_query = queries.AppInstancesByAppTypeQuery(app_type=app_type.value, project_id=project_id)
        data = (await self._client.make_graphql_request(instances_query)).get_data_or_raise()
        if not data or not data.app_instances:
            return []
        return data.app_instances

    async def list_environment_specs(self, *, project_id: str) -> list[fragments.ExecutionEnvironmentSpec]:
        """Return every execution environment spec of ``project_id`` (an empty list when none)."""
        specs_query = queries.ExecutionEnvironmentSpecsForProjectQuery(project_id=project_id)
        data = (await self._client.make_graphql_request(specs_query)).get_data_or_raise()
        if not data or not data.execution_environment_specs:
            return []
        return data.execution_environment_specs

    async def list_artifacts(self, tag: str | None = None, *, project_id: str) -> list[fragments.BlobArtifact]:
        """Return the project's latest blob artifacts, restricted to ``tag`` when given (empty list when none)."""
        artifacts_query = (
            queries.LatestBlobArtifactsForProjectByTagQuery(project_id=project_id, tag=tag)
            if tag is not None
            else queries.LatestBlobArtifactsForProjectQuery(project_id=project_id)
        )
        data = (await self._client.make_graphql_request(artifacts_query)).get_data_or_raise()
        project = data.projects_by_pk if data else None
        artifacts = project.blob_artifacts if project else None
        return artifacts or []

    async def list_volumes(self, *, project_id: str) -> list[fragments.Volume]:
        """Return every volume of ``project_id`` (an empty list when none)."""
        response = await self._client.make_graphql_request(queries.VolumesForProjectQuery(project_id=project_id))
        data = response.get_data_or_raise()
        if not data or not data.volumes:
            return []
        return data.volumes

    async def list_secrets(self, *, project_id: str) -> list[fragments.ProjectSecret]:
        """Return the secrets of ``project_id`` (an empty list when the project or its secrets are missing)."""
        response = await self._client.make_graphql_request(ProjectSecretsQuery(project_id=project_id))
        secrets_data: ProjectSecretsResponse = response.get_data_or_raise()
        project = secrets_data.projects_by_pk if secrets_data else None
        return (project and project.project_secrets) or []

    """
    Follow API methods
    """

    async def follow_app_status(self, app_spec_id: str) -> AsyncGenerator[schemas.AppInstanceStatus, None]:
        """Yield the first app instance's status every time the app spec's status subscription pushes an update.

        :raises SlingshotException: If an update contains no app instances.
        """
        subscription = queries.AppSpecStatusSubscription(app_spec_id=app_spec_id)
        async for update in self._client.make_graphql_subscription_request(subscription):
            instances = update.data.app_instances
            if not instances:
                raise SlingshotException("App instances not found")
            yield schemas.AppInstanceStatus(instances[0].app_instance_status)

    async def follow_run_status(self, run_id: str) -> AsyncGenerator[schemas.JobStatus, None]:
        """Yield the run's job status every time its status subscription pushes an update.

        :raises SlingshotException: If an update contains no run.
        """
        subscription = queries.RunStatusSubscription(run_id=run_id)
        async for update in self._client.make_graphql_subscription_request(subscription):
            current_run = update.data.run
            if not current_run:
                raise SlingshotException("Run not found")
            yield schemas.JobStatus(current_run.job_status)

    async def follow_deployment_status(self, app_spec_id: str) -> AsyncGenerator[schemas.AppInstanceStatus, None]:
        """Yield the first deployment instance's status every time the deployment subscription pushes an update.

        :raises SlingshotException: If an update contains no deployment instances.
        """
        subscription = queries.DeploymentStatusSubscription(app_spec_id=app_spec_id)
        async for update in self._client.make_graphql_subscription_request(subscription):
            instances = update.data.deployment_instances
            if not instances:
                raise SlingshotException("Deployment instances not found")
            yield schemas.AppInstanceStatus(instances[0].deployment_instance_status)

    """
    Create API methods
    """

    async def create_project(
        self, project_id: str, project_display_name: Optional[str] = None
    ) -> schemas.Response[schemas.ProjectId]:
        """Create a project with id ``project_id`` and an optional display name (POST ``project``)."""
        new_project = schemas.BodyNewProject(project_id=project_id, display_name=project_display_name)
        return await self._client.make_request(
            url="project",
            method="post",
            response_model=schemas.Response[schemas.ProjectId],
            json_data=new_project.dict(),
        )

    async def create_app(
        self,
        name: str,
        command: str | None,
        app_type: schemas.AppType,
        app_sub_type: schemas.AppSubType | None,
        exec_env_spec_id: str,
        machine_size: schemas.MachineSize,
        mounts: list[schemas.MountSpecUnion],
        attach_project_credentials: bool,
        config_variables: JSONType | None = None,
        app_port: int | None = None,
        *,
        project_id: str,
    ) -> schemas.AppSpecIdResponse:
        """Create an app spec in ``project_id`` (POST ``project/{project_id}/apps``).

        Each mount spec is converted to a ``schemas.MountRequestUnion`` before the request body is built.
        """
        requested_mounts = [parse_obj_as(schemas.MountRequestUnion, mount_spec) for mount_spec in mounts]
        request_body = schemas.CreateAppBody(
            name=name,
            command=command,
            app_type=app_type,
            app_sub_type=app_sub_type,
            exec_env_spec_id=exec_env_spec_id,
            machine_size=machine_size,
            mounts=requested_mounts,
            attach_project_credentials=attach_project_credentials,
            config_variables=config_variables,
            app_port=app_port,
        ).dict()
        # Send the raw enum value for machine_size instead of whatever `.dict()` produced for it.
        request_body["machine_size"] = machine_size.value

        apps_url = f"project/{project_id}/apps"
        return await self._client.make_request(
            url=apps_url, method="post", response_model=schemas.AppSpecIdResponse, json_data=request_body
        )

    async def create_or_update_environment_spec(
        self,
        name: str,
        requested_python_requirements: list[schemas.RequestedRequirement] | None = None,
        requested_apt_requirements: list[schemas.RequestedAptPackage] | None = None,
        gpu_drivers: bool = False,
        force_create_environment: bool = False,
        *,
        project_id: str,
    ) -> schemas.CreateEnvironmentSpecResponse:
        """Create or update the environment spec ``name`` in ``project_id``.

        Omitted (or empty) requirement lists are sent as empty lists.
        """
        spec_request = schemas.ExecutionEnvironmentSpecRequestBody(
            name=name,
            python_packages=requested_python_requirements or [],
            apt_packages=requested_apt_requirements or [],
            gpu_drivers=gpu_drivers,
            force_create_environment=force_create_environment,
        )
        environment_url = f"project/{project_id}/environment"
        return await self._client.make_request(
            url=environment_url,
            method="post",
            response_model=schemas.CreateEnvironmentSpecResponse,
            json_data=spec_request.dict(),
        )

    async def create_volume(self, volume_name: str, *, project_id: str) -> schemas.Response[str]:
        """Create a volume called ``volume_name`` in ``project_id``."""
        volume_body = {"name": volume_name}
        volume_url = f"project/{project_id}/volume"
        return await self._client.make_request(
            url=volume_url, method="post", response_model=schemas.Response[str], json_data=volume_body
        )

    """
    Update API methods
    """

    async def update_app(
        self,
        app_spec_id: str,
        command: str | None,
        exec_env_spec_id: str,
        machine_size: MachineSize,
        mounts: list[schemas.MountSpecUnion],
        attach_project_credentials: bool,
        config_variables: JSONType | None = None,
        app_port: int | None = None,
        batch_size: int | None = None,
        batch_interval: int | None = None,
        *,
        name: str | None = None,
        project_id: str,
    ) -> schemas.Response[bool]:
        """Update the app spec ``app_spec_id`` in ``project_id`` (POST ``project/{project_id}/apps/{app_spec_id}``).

        Each mount spec is converted to a ``schemas.MountRequestUnion`` before the request body is built.
        """
        requested_mounts = [parse_obj_as(schemas.MountRequestUnion, mount_spec) for mount_spec in mounts]
        update_body = schemas.UpdateAppBody(
            command=command,
            name=name,
            exec_env_spec_id=exec_env_spec_id,
            machine_size=machine_size,
            mounts=requested_mounts,
            config_variables=config_variables,
            attach_project_credentials=attach_project_credentials,
            app_port=app_port,
            batch_size=batch_size,
            batch_interval=batch_interval,
        ).dict()
        # TODO: Find a way for pydantic not to convert machine_size to an enum
        update_body["machine_size"] = machine_size.value

        app_url = f"project/{project_id}/apps/{app_spec_id}"
        return await self._client.make_request(
            url=app_url, method="post", response_model=schemas.Response[bool], json_data=update_body
        )

    async def put_secret(self, secret_name: str, secret_value: str, *, project_id: str) -> schemas.PutResultResponse:
        """Create or overwrite the secret ``secret_name`` with ``secret_value``.

        Issues ``PUT project/{project_id}/secret/{secret_name}`` with ``{"secret_value": ...}`` as JSON.
        """
        payload = {"secret_value": secret_value}
        endpoint = f"project/{project_id}/secret/{secret_name}"
        return await self._client.make_request(
            url=endpoint, method="put", response_model=schemas.PutResultResponse, json_data=payload
        )

    """
    Delete API methods
    """

    async def delete_app(self, app_spec_id: str, project_id: str) -> schemas.ResponseOK:
        """Delete the app spec ``app_spec_id`` via ``DELETE project/{project_id}/apps/{app_spec_id}``."""
        endpoint = f"project/{project_id}/apps/{app_spec_id}"
        return await self._client.make_request(
            url=endpoint,
            method="delete",
            response_model=schemas.ResponseOK,
        )

    async def delete_environment_spec(
        self, execution_environment_spec_id: str, *, project_id: str | None = None
    ) -> None:
        """Delete an environment spec by archiving it.

        This runs the ``ArchiveExecutionEnvironmentSpecMutation`` GraphQL mutation with
        ``is_archived=True`` and raises if the response carries no data. ``project_id`` is
        accepted but not used by this method.
        """
        archive_mutation = queries.ArchiveExecutionEnvironmentSpecMutation(
            execution_environment_spec_id=execution_environment_spec_id,
            is_archived=True,
        )
        graphql_response = await self._client.make_graphql_request(archive_mutation)
        graphql_response.get_data_or_raise()

    async def delete_volume(self, volume_name: str, *, project_id: str) -> schemas.Response[bool]:
        """Delete the volume ``volume_name`` via ``DELETE project/{project_id}/volume/{volume_name}``."""
        endpoint = f"project/{project_id}/volume/{volume_name}"
        return await self._client.make_request(
            url=endpoint,
            method="delete",
            response_model=schemas.Response[bool],
        )

    async def delete_secret(self, secret_name: str, *, project_id: str) -> schemas.DeleteResultResponse:
        """Delete the secret ``secret_name`` via ``DELETE project/{project_id}/secret/{secret_name}``."""
        endpoint = f"project/{project_id}/secret/{secret_name}"
        return await self._client.make_request(url=endpoint, method="delete", response_model=schemas.DeleteResultResponse)

    """
    Start API methods
    """

    async def start_app(
        self,
        *,
        app_spec: fragments.AppSpec,
        machine_size: schemas.MachineSize | None = None,
        source_code_id: str | None = None,
        cmd: str | None = None,
        mount_specs: list[schemas.MountSpecUnion] | None = None,
        config_variables: JSONType | None = None,
        exec_env_id: str | None = None,
        attach_project_credentials: bool | None = None,
        app_port: int | None = None,
        project_id: str,
    ) -> schemas.HasAppInstanceIdResponse:
        """Start an app instance from ``app_spec``.

        Overrides fall back to the values stored on ``app_spec``:

        - ``machine_size``, ``cmd``, ``exec_env_id`` and ``app_port`` fall back when falsy.
        - ``mount_specs`` falls back only when ``None``, so ``[]`` means "no mounts".
        - ``config_variables`` falls back only when ``None`` and the spec has config
          variables, which are stored as a JSON string.
        - ``attach_project_credentials`` falls back to ``app_spec.service_account`` when ``None``.

        For session apps, any ``cmd`` is replaced by a placeholder command that exits with
        an error if it is ever executed.

        The request is posted to ``project/{project_id}/apps/{app_spec_id}/start``.
        """
        # TODO: Separate the SDK logic (using existing spec) from the API logic

        if app_spec.app_sub_type == schemas.AppSubType.SESSION:
            # Placeholder that should never be executed for sessions. The closing double quote
            # must be part of the string so that the `bash -c "..."` argument is balanced.
            cmd = "bash -c \"echo 'This command should never run!'; exit 1\""

        machine_size = machine_size or app_spec.machine_size
        command = cmd or app_spec.app_spec_command
        if mount_specs is None:
            mount_specs = [schemas.mount_spec_from_remote(mount_spec) for mount_spec in app_spec.mount_specs]
        if config_variables is None and app_spec.config_variables is not None:
            config_variables = json.loads(app_spec.config_variables)
        exec_env_id = exec_env_id or app_spec.execution_environment_spec.execution_environment.execution_environment_id
        attach_project_credentials = (
            attach_project_credentials if attach_project_credentials is not None else app_spec.service_account
        )
        app_port = app_port or app_spec.app_port
        mount_requests = [parse_obj_as(schemas.MountRequestUnion, i) for i in mount_specs]
        body = schemas.StartAppBody(
            machine_size=machine_size,
            source_code_id=source_code_id,
            cmd=command,
            mounts=mount_requests,
            config_variables=config_variables,
            exec_env_id=exec_env_id,
            attach_project_credentials=attach_project_credentials,
            app_port=app_port,
        )

        return await self._client.make_request(
            url=f"project/{project_id}/apps/{app_spec.app_spec_id}/start",
            method="post",
            response_model=schemas.HasAppInstanceIdResponse,
            json_data=body.dict(),
        )

    async def start_run(
        self,
        run_spec: fragments.AppSpec,
        source_code_id: str | None = None,
        machine_size: Optional[schemas.MachineSize] = None,
        hyperparameters: Hyperparameter | None = None,
        cmd: str | None = None,
        mount_specs: list[schemas.MountSpecUnion] | None = None,
        exec_env_id: str | None = None,
        attach_project_credentials: bool | None = None,
        debug_mode: bool = False,  # TODO: maybe don't expose this to the user
        *,
        project_id: str,
    ) -> schemas.RunCreateResponse:
        """Start a run from ``run_spec``.

        Any falsy override (machine size, hyperparameters, command, mounts, environment id)
        falls back to the value stored on ``run_spec``. Hyperparameters come from the spec's
        JSON-encoded ``config_variables``. ``attach_project_credentials`` falls back to
        ``run_spec.service_account`` only when it is ``None``. ``debug_mode`` is sent as the
        JSON-encoded ``debug_mode`` query parameter.
        """
        chosen_machine_size = machine_size or run_spec.machine_size
        if hyperparameters or not run_spec.config_variables:
            run_config = hyperparameters
        else:
            run_config = json.loads(run_spec.config_variables)
        run_command = cmd or run_spec.app_spec_command
        run_mounts = mount_specs or [gql_mount_spec_to_read_mount_spec(spec_mount) for spec_mount in run_spec.mount_specs]
        run_env_id = exec_env_id or run_spec.execution_environment_spec.execution_environment.execution_environment_id
        if attach_project_credentials is None:
            with_credentials = run_spec.service_account
        else:
            with_credentials = attach_project_credentials

        start_payload = schemas.StartAppBody(
            source_code_id=source_code_id,
            machine_size=chosen_machine_size,
            config_variables=run_config,
            cmd=run_command,
            mounts=run_mounts,  # TODO: Fix this
            exec_env_id=run_env_id,
            attach_project_credentials=with_credentials,
        )

        endpoint = f"project/{project_id}/run/{run_spec.app_spec_id}/start"
        return await self._client.make_request(
            url=endpoint,
            method="post",
            response_model=schemas.RunCreateResponse,
            params={"debug_mode": json.dumps(debug_mode)},
            json_data=start_payload.dict(),
        )

    async def deploy_model(
        self,
        source_code_id: str,
        deployment_spec_id: str,
        machine_size: schemas.MachineSize | None = None,
        config_variables: JSONType | None = None,
        mount_specs: list[schemas.MountSpecUnion] | None = None,
        exec_env_id: str | None = None,
        cmd: str | None = None,
        *,
        project_id: str,
    ) -> schemas.DeploymentInstanceIdResponse:
        """Start a deployment from its stored spec.

        The spec is fetched with ``get_app_spec_by_id``, and ``SlingshotException`` is raised
        if that returns a falsy value. Falsy overrides (machine size, config variables, mounts,
        environment id, command) fall back to the spec's values. Project credentials are
        attached according to the spec's ``service_account`` field.
        """
        spec = await self.get_app_spec_by_id(deployment_spec_id, project_id=project_id)
        if not spec:
            raise SlingshotException(f"Deployment not found: {deployment_spec_id}")

        resolved_machine_size = machine_size or spec.machine_size
        if config_variables or not spec.config_variables:
            resolved_config = config_variables
        else:
            resolved_config = json.loads(spec.config_variables)
        resolved_mounts = mount_specs or [gql_mount_spec_to_read_mount_spec(spec_mount) for spec_mount in spec.mount_specs]
        resolved_env_id = exec_env_id or spec.execution_environment_spec.execution_environment.execution_environment_id
        resolved_cmd = cmd or spec.app_spec_command
        with_credentials = spec.service_account

        start_payload = schemas.StartAppBody(
            source_code_id=source_code_id,
            machine_size=resolved_machine_size,
            config_variables=resolved_config,
            mounts=resolved_mounts,  # TODO: Fix this
            exec_env_id=resolved_env_id,
            cmd=resolved_cmd,
            attach_project_credentials=with_credentials,
        )
        endpoint = f"project/{project_id}/deploy/{deployment_spec_id}/start"
        return await self._client.make_request(
            url=endpoint,
            method="post",
            response_model=schemas.DeploymentInstanceIdResponse,
            json_data=start_payload.dict(),
        )

    async def start_app_code_sync(self, app_spec_id: str, *, project_id: str) -> schemas.Response[schemas.SshPort]:
        """Start code sync for an app; the response wraps a ``schemas.SshPort``."""
        endpoint = f"project/{project_id}/apps/{app_spec_id}/sync/start"
        return await self._client.make_request(
            url=endpoint, method="post", response_model=schemas.Response[schemas.SshPort]
        )

    """
    Stop API methods
    """

    async def stop_app(self, app_spec_id: str, project_id: str) -> schemas.ResponseOK:
        """Stop the app ``app_spec_id`` via ``POST project/{project_id}/apps/{app_spec_id}/stop``."""
        endpoint = f"project/{project_id}/apps/{app_spec_id}/stop"
        return await self._client.make_request(
            url=endpoint,
            method="post",
            response_model=schemas.ResponseOK,
        )

    async def cancel_run(self, run_id: str, *, project_id: str) -> schemas.ResponseOK:
        """Cancel the run ``run_id`` via ``POST project/{project_id}/run/{run_id}/cancel``."""
        endpoint = f"project/{project_id}/run/{run_id}/cancel"
        return await self._client.make_request(
            url=endpoint,
            method="post",
            response_model=schemas.ResponseOK,
        )

    async def stop_deployment(self, deployment_spec_id: str, *, project_id: str) -> schemas.ResponseOK:
        """Stop a deployment via ``POST project/{project_id}/deploy/{deployment_spec_id}/stop``."""
        endpoint = f"project/{project_id}/deploy/{deployment_spec_id}/stop"
        return await self._client.make_request(url=endpoint, method="post", response_model=schemas.ResponseOK)

    """
    Logs API methods
    """

    async def get_app_logs(self, app_spec_id: str, project_id: str) -> schemas.LogsResponse:
        """Fetch the logs of app spec ``app_spec_id`` via ``GET project/{project_id}/apps/{app_spec_id}/logs``."""
        endpoint = f"project/{project_id}/apps/{app_spec_id}/logs"
        return await self._client.make_request(
            url=endpoint,
            method="get",
            response_model=schemas.LogsResponse,
        )

    async def get_run_logs(self, run_id: str, *, project_id: str) -> schemas.LogsResponse:
        """Fetch the logs of run ``run_id`` via ``GET project/{project_id}/run/{run_id}/logs``."""
        endpoint = f"project/{project_id}/run/{run_id}/logs"
        return await self._client.make_request(
            url=endpoint,
            method="get",
            response_model=schemas.LogsResponse,
        )

    """
    Predict API methods
    """

    async def predict(
        self, project_id: str, deployment_name: str, example_bytes: bytes, timeout_seconds: int = 60
    ) -> schemas.PredictionResponse:
        """Send ``example_bytes`` to the named deployment and return its prediction.

        The bytes are passed through ``data=`` under the ``example`` key to
        ``POST predict/{project_id}/{deployment_name}``. The request times out after
        ``timeout_seconds`` seconds.
        """
        request_timeout = datetime.timedelta(seconds=timeout_seconds)
        form_fields = {"example": example_bytes}
        endpoint = f"predict/{project_id}/{deployment_name}"
        return await self._client.make_request(
            url=endpoint,
            method="post",
            response_model=schemas.PredictionResponse,
            timeout=request_timeout,
            data=form_fields,
        )

    @backoff.on_exception(backoff.expo, (Retry,), max_tries=3)
    async def prompt_openai(
        self,
        request: schemas.PromptOpenAIBody,
        timeout: datetime.timedelta = datetime.timedelta(seconds=600),
        *,
        project_id: str,
    ) -> schemas.OpenAIResponse:
        """Send a prompt to an OpenAI model through ``project/{project_id}/prompt/openai``.

        HTTP 429 and 503 responses are converted into ``Retry``. The ``backoff`` decorator
        then re-attempts the call with exponential backoff, up to 3 attempts in total. If every
        attempt fails that way, the final ``Retry`` propagates to the caller with the original
        ``SlingshotClientHttpException`` as its ``__cause__``. Any other HTTP error is re-raised
        unchanged.

        :param request: prompt payload, serialized with ``.dict()`` as the JSON body.
        :param timeout: request timeout passed to the HTTP client (default 600 seconds).
        :param project_id: project whose prompt endpoint is called.
        """
        # TODO: Add idempotence key. If the client wants to cache, they can set idempotence_key to hash of prompt.
        try:
            return await self._client.make_request(
                url=f"project/{project_id}/prompt/openai",
                method="post",
                response_model=schemas.OpenAIResponse,
                json_data=request.dict(),
                timeout=timeout,
            )
        except SlingshotClientHttpException as e:
            if e.status in (429, 503):
                # Rate-limited or temporarily unavailable: hand off to the backoff decorator,
                # keeping the HTTP error as the explicit cause for diagnostics.
                raise Retry(e) from e
            raise

    """
    Artifact API methods
    """

    async def signed_url_blob_artifact_many(
        self, blob_artifact_id: str, expiration: datetime.timedelta = datetime.timedelta(hours=1), *, project_id: str
    ) -> schemas.BlobArtifactSignedURLManyResponse:
        """Get signed URLs from the ``signed_url_many`` endpoint of a blob artifact.

        ``expiration`` is sent as a number of seconds in the ``expiration`` query parameter.
        """
        endpoint = f"project/{project_id}/artifact/{blob_artifact_id}/signed_url_many"
        query: dict[str, float | str] = {"expiration": expiration.total_seconds()}
        return await self._client.make_request(
            url=endpoint, params=query, method="get", response_model=schemas.BlobArtifactSignedURLManyResponse
        )

    async def signed_url_blob_artifact(
        self,
        blob_artifact_id: str,
        file_path: str | None = None,
        expiration: datetime.timedelta = datetime.timedelta(hours=1),
        *,
        project_id: str,
    ) -> schemas.BlobArtifactSignedURLResponse:
        """Get a signed URL for a blob artifact, optionally for a single file inside it.

        ``expiration`` is sent in seconds. ``file_path`` is sent only when it is non-empty.
        """
        endpoint = f"project/{project_id}/artifact/{blob_artifact_id}/signed_url"
        query: dict[str, float | str] = {
            "expiration": expiration.total_seconds(),
            **({"file_path": file_path} if file_path else {}),
        }
        return await self._client.make_request(
            url=endpoint,
            params=query,
            method="get",
            response_model=schemas.BlobArtifactSignedURLResponse,
        )

    async def upsert_dataset_artifact(
        self, *, upsert_artifact_id: str, dataset_artifact_tag: str, project_id: str
    ) -> schemas.BlobArtifactIdResponse:
        """Apply an existing upsert artifact to the latest dataset tagged ``dataset_artifact_tag``.

        Issues ``POST project/{project_id}/artifact/{upsert_artifact_id}/upsert/{dataset_artifact_tag}``.
        """
        endpoint = f"project/{project_id}/artifact/{upsert_artifact_id}/upsert/{dataset_artifact_tag}"
        return await self._client.make_request(url=endpoint, method="post", response_model=schemas.BlobArtifactIdResponse)

    async def upload_signed_url_blob_artifact(
        self,
        filename: str,
        as_zip: bool,  # Sent as `is_zipped_directory`; the caller decides (no default here).
        blob_artifact_tag: str | None = None,
        *,
        project_id: str,
    ) -> schemas.BlobArtifactUploadSignedURLResponse:
        """Request a signed URL for uploading a blob artifact.

        The upload description (``filename``, optional ``tag`` and ``is_zipped_directory``) is
        built from ``schemas.UploadBlobArtifactBody`` and sent as the JSON body.

        NOTE(review): this sends a JSON body on a GET request; confirm the backend expects that.
        """
        upload_description = schemas.UploadBlobArtifactBody(
            filename=filename,
            tag=blob_artifact_tag,
            is_zipped_directory=as_zip,
        )
        endpoint = f"project/{project_id}/artifact/upload_signed_url"
        return await self._client.make_request(
            url=endpoint,
            method="get",
            json_data=upload_description.dict(),
            response_model=schemas.BlobArtifactUploadSignedURLResponse,
        )

    """
    Source code API methods
    """

    async def upload_source_code(
        self, artifact_id: str, code_description: str | None = None, *, project_id: str
    ) -> schemas.Response[schemas.UploadedSourceCode]:
        """Register an uploaded blob artifact as source code for the project.

        ``artifact_id`` is sent as ``blob_artifact_id``. ``code_description``, when not
        ``None``, is sent as ``description``. Both are passed as query parameters.
        """
        query: ParamsType = (
            {"blob_artifact_id": artifact_id}
            if code_description is None
            else {"blob_artifact_id": artifact_id, "description": code_description}
        )
        endpoint = f"project/{project_id}/source_code"
        return await self._client.make_request(
            url=endpoint,
            method="post",
            response_model=schemas.Response[schemas.UploadedSourceCode],
            params=query,
        )

    """
    Environment API methods
    """

    async def rebuild_all_environments(self) -> str:
        """Trigger a rebuild of all environments through the admin endpoint.

        Issues ``POST admin/rebuild_all_environments`` and returns the response parsed as ``str``.
        """
        # Plain string: the path has no placeholders, so an f-string is unnecessary (ruff F541).
        return await self._client.make_request(url="admin/rebuild_all_environments", method="post", response_model=str)


async def _zip_dir(dir_path: Path) -> str:
    """Zip the contents of ``dir_path`` into an archive named ``data.zip`` and return its path.

    The archive base name is the relative ``"data"``, so the file lands in the current working
    directory, and a later call reuses the same name. ``shutil.make_archive`` is called
    synchronously, so the event loop is blocked while the archive is built.
    """
    # TODO: Use a temporary directory instead of the current directory, Then delete the temporary directory.
    archive_path = shutil.make_archive(base_name="data", format="zip", root_dir=dir_path)
    return archive_path


@contextlib.asynccontextmanager
async def _maybe_make_http_session(session: aiohttp.ClientSession | None) -> AsyncIterator[aiohttp.ClientSession]:
    """Yield a usable aiohttp session, opening a temporary one when needed.

    An open ``session`` is yielded as-is and is not closed here. When ``session`` is ``None``
    or already closed, a new ``aiohttp.ClientSession`` is opened for the duration of the
    ``async with`` block and closed when it exits.
    """
    if session is not None and not session.closed:
        yield session
        return
    async with aiohttp.ClientSession() as fresh_session:
        yield fresh_session
