from contextlib import contextmanager
from datetime import datetime, timezone
from enum import Enum
import time
from typing import Any, Callable, Dict, List, Optional, Set

from click import ClickException
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
)

from anyscale.cli_logger import LogsLogger
from anyscale.client.openapi_client.models.cloud_providers import CloudProviders
from anyscale.client.openapi_client.models.create_experimental_workspace import (
    CreateExperimentalWorkspace,
)
from anyscale.client.openapi_client.models.session_state import SessionState
from anyscale.cluster_env import get_default_cluster_env_build
from anyscale.controllers.base_controller import BaseController
from anyscale.project import get_default_project
from anyscale.sdk.anyscale_client.models.compute_node_type import ComputeNodeType
from anyscale.sdk.anyscale_client.models.compute_template_query import (
    ComputeTemplateQuery,
)
from anyscale.sdk.anyscale_client.models.create_cluster_compute import (
    CreateClusterCompute,
)
from anyscale.sdk.anyscale_client.models.create_cluster_compute_config import (
    CreateClusterComputeConfig,
)
from anyscale.util import confirm, get_endpoint


# Polling cadence and overall deadline used while waiting for a workspace.
POLL_INTERVAL_SECONDS: int = 10
WORKSPACE_VERIFICATION_TIMEOUT_MINUTES: int = 10

# default values for cluster compute config
# Lifetime limits applied to the verification compute config.
MAXIMUM_UPTIME_MINUTES: int = 15
IDLE_TERMINATION_MINUTES: int = 5
# Head node instance type per cloud provider.
HEAD_NODE_TYPE_AWS: str = "m5.xlarge"
HEAD_NODE_TYPE_GCP: str = "n1-standard-2"


class CloudFunctionalVerificationType(str, Enum):
    """Kinds of functional verification that can be run against a cloud.

    Members are ``str`` subclasses, so their values can be joined or
    lower-cased directly when building console text.
    """

    # TODO (congding): add service
    WORKSPACE = "WORKSPACE"


class CloudFunctionalVerificationTask:
    """State of one step row shown in a verification's step-progress panel.

    Holds the rich ``TaskID`` of the row together with the description to
    display for it, whether the step succeeded, and whether it is the step
    that completes the verification.
    """

    def __init__(self, task_id: TaskID, description: str):
        self.task_id, self.description = task_id, description
        # A fresh step is neither successful nor the final one.
        self.succeeded, self.completed = False, False

    def update(self, succeeded: bool, description: str, completed: bool = False):
        """Record the outcome of this step.

        Args:
            succeeded: Whether the step succeeded.
            description: Text to display for the step.
            completed: Whether this step completes the whole verification.
        """
        self.description, self.succeeded, self.completed = (
            description,
            succeeded,
            completed,
        )


class CloudFunctionalVerificationController(BaseController):
    """Run functional verifications (currently only workspaces) against a cloud.

    Progress is rendered on a rich ``Live`` console. For every verification
    type the controller keeps a step-progress panel (one row per step) and an
    overall progress bar. Both are created by ``get_live_console``, which has
    to run before ``verify`` / ``verify_workspace`` touch the console (they
    index ``self.step_progress`` / ``self.overall_progress`` by type).
    """

    def __init__(
        self, log: Optional[LogsLogger] = None, initialize_auth_api_client: bool = True,
    ):
        if log is None:
            log = LogsLogger()

        super().__init__(initialize_auth_api_client=initialize_auth_api_client)
        self.log = log

        # Used for rich live console; all three maps are filled by
        # ``get_live_console``.
        # Per verification type: panel listing the individual steps.
        self.step_progress: Dict[CloudFunctionalVerificationType, Progress] = {}
        # Per verification type: bar counting the steps that succeeded.
        self.overall_progress: Dict[CloudFunctionalVerificationType, Progress] = {}
        # Per verification type: id of the single task inside its overall bar.
        self.task_ids: Dict[CloudFunctionalVerificationType, TaskID] = {}

    @staticmethod
    def get_head_node_type(cloud_provider: CloudProviders) -> str:
        """
        Get the default head node type for the given cloud provider.

        Raises:
            ClickException: if the provider is neither AWS nor GCP.
        """
        if cloud_provider == CloudProviders.AWS:
            return HEAD_NODE_TYPE_AWS
        elif cloud_provider == CloudProviders.GCP:
            return HEAD_NODE_TYPE_GCP
        raise ClickException(f"Unsupported cloud provider: {cloud_provider}")

    def get_or_create_cluster_compute(
        self, cloud_id: str, cloud_provider: CloudProviders
    ) -> str:
        """
        Get or create a cluster compute for cloud functional verification.

        The compute config is identified by the name
        ``functional_verification_<cloud_id>``. If an org-wide search
        (including anonymous configs) finds one, its id is returned unchanged.
        Otherwise a new anonymous config is created with:
        - a single head node and no workers;
        - ``MAXIMUM_UPTIME_MINUTES`` / ``IDLE_TERMINATION_MINUTES`` limits;
        - a ``cloud_functional_verification=<cloud_id>`` instance tag (AWS) or
          label (GCP).

        Returns:
            The id of the found or newly created cluster compute.

        Raises:
            ClickException: if the cloud provider is unsupported.
        """
        cluster_compute_name = f"functional_verification_{cloud_id}"

        cluster_computes = self.api_client.search_compute_templates_api_v2_compute_templates_search_post(
            ComputeTemplateQuery(
                orgwide=True,
                name={"equals": cluster_compute_name},
                include_anonymous=True,
            )
        ).results
        if cluster_computes:
            # Reuse an existing config with this name.
            return cluster_computes[0].id

        head_node_instance_type = self.get_head_node_type(cloud_provider)
        # no cluster compute found, create one
        cluster_compute_config = CreateClusterComputeConfig(
            cloud_id=cloud_id,
            max_workers=0,
            allowed_azs=["any"],
            head_node_type=ComputeNodeType(
                name="head_node_type", instance_type=head_node_instance_type,
            ),
            maximum_uptime_minutes=MAXIMUM_UPTIME_MINUTES,
            idle_termination_minutes=IDLE_TERMINATION_MINUTES,
            worker_node_types=[],
        )
        # Mark the instances so they can be identified as verification
        # resources on the provider side.
        if cloud_provider == CloudProviders.AWS:
            cluster_compute_config.aws_advanced_configurations_json = {
                "TagSpecifications": [
                    {
                        "ResourceType": "instance",
                        "Tags": [
                            {"Key": "cloud_functional_verification", "Value": cloud_id,}
                        ],
                    }
                ]
            }
        elif cloud_provider == CloudProviders.GCP:
            cluster_compute_config.gcp_advanced_configurations_json = {
                "instance_properties": {
                    "labels": {"cloud_functional_verification": cloud_id},
                }
            }

        cluster_compute = self.anyscale_api_client.create_cluster_compute(
            CreateClusterCompute(
                name=cluster_compute_name,
                config=cluster_compute_config,
                anonymous=True,
            )
        ).result
        return cluster_compute.id

    def _prepare_verification(self, cloud_id: str, cloud_provider: CloudProviders):
        """
        Generate the required parameters for cloud functional verification.

        Returns:
            A ``(cluster_compute_id, cluster_env_build_id, project_id)`` tuple:
            - the verification compute config for this cloud;
            - the default cluster environment build;
            - the default project whose parent cloud is ``cloud_id``.
        """
        cluster_compute_id = self.get_or_create_cluster_compute(
            cloud_id, cloud_provider
        )

        cluster_env_build_id = get_default_cluster_env_build(
            self.api_client, self.anyscale_api_client
        ).id

        project_id = get_default_project(
            self.api_client, self.anyscale_api_client, parent_cloud_id=cloud_id
        ).id

        return cluster_compute_id, cluster_env_build_id, project_id

    @contextmanager
    def _create_task(self, function: CloudFunctionalVerificationType, description: str):
        """
        Create a task on the console for cloud functional verification.

        Adds a row showing ``description`` to the step-progress panel of
        ``function``. It then yields a ``CloudFunctionalVerificationTask`` that
        the caller fills in via ``task.update``.

        When the ``with`` block exits (normally, via ``return``, or via an
        exception):
        - the row is finished with the task's final description;
        - the overall progress is updated from ``task.succeeded`` and
          ``task.completed``.

        A task that was never updated is therefore reported as failed.
        """
        task_id = self.step_progress[function].add_task(description)
        task = CloudFunctionalVerificationTask(task_id, description)
        try:
            yield task
        finally:
            self._update_console(
                task.succeeded, function, task_id, task.description, task.completed
            )

    def create_workspace(self, cloud_id: str, cloud_provider: CloudProviders):
        """
        Create a workspace for cloud functional verification.

        The workspace name embeds the cloud id and the current UTC time as
        ``fxnvrf_<cloud_id>_<YYYYmmddHHMMSS>``.

        Returns:
            The created workspace object. ``verify_workspace`` reads its ``id``
            and ``cluster_id``.
        """
        (
            cluster_compute_id,
            cluster_env_build_id,
            project_id,
        ) = self._prepare_verification(cloud_id, cloud_provider)

        # ``datetime.utcnow()`` is deprecated since Python 3.12. An aware UTC
        # datetime formats to exactly the same string.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
        create_workspace_arg = CreateExperimentalWorkspace(
            name=f"fxnvrf_{cloud_id}_{timestamp}",
            description=f"workspace for cloud {cloud_id} functional verification",
            project_id=project_id,
            cloud_id=cloud_id,
            compute_config_id=cluster_compute_id,
            cluster_environment_build_id=cluster_env_build_id,
            idle_timeout_minutes=IDLE_TERMINATION_MINUTES,
        )

        workspace = self.api_client.create_workspace_api_v2_experimental_workspaces_post(
            create_workspace_arg
        ).result

        return workspace

    def verify_workspace(self, cloud_id: str, cloud_provider: CloudProviders) -> bool:
        """
        Verifies that the workspace is setup correctly on the given cloud.

        Runs three console steps and stops at the first one that fails:
        1. create a workspace;
        2. wait up to ``WORKSPACE_VERIFICATION_TIMEOUT_MINUTES`` for it to
           reach ``RUNNING``;
        3. terminate its cluster.

        Returns:
            True if all three steps succeeded, False otherwise. Only
            ``ClickException`` is handled here; other exceptions propagate.
        """
        # Create workspace
        with self._create_task(
            CloudFunctionalVerificationType.WORKSPACE, "Creating workspace..."
        ) as create_workspace_task:
            try:
                workspace = self.create_workspace(cloud_id, cloud_provider)
            except ClickException as e:
                create_workspace_task.update(
                    False, f"[bold red]Failed to create workspace: {e}"
                )
                return False
            url = get_endpoint(f"/workspaces/{workspace.id}/{workspace.cluster_id}")
            create_workspace_task.update(
                True, f"[bold green]Workspace created at {url}"
            )

        # Wait until workspace is active
        def get_workspace_status(workspace_id):
            return self.api_client.get_workspace_api_v2_experimental_workspaces_workspace_id_get(
                workspace_id
            ).result.state

        # States in which the workspace is still considered to be progressing
        # towards RUNNING; any other state aborts the wait.
        allowed_status_set = {
            SessionState.RUNNING,
            SessionState.STARTINGUP,
            SessionState.AWAITINGSTARTUP,
        }

        with self._create_task(
            CloudFunctionalVerificationType.WORKSPACE,
            "Waiting for workspace to become active...",
        ) as wait_task:
            try:
                self.poll_until_active(
                    CloudFunctionalVerificationType.WORKSPACE,
                    workspace.id,
                    get_workspace_status,
                    SessionState.RUNNING,
                    allowed_status_set,
                    wait_task,
                    WORKSPACE_VERIFICATION_TIMEOUT_MINUTES,
                )
            except ClickException as e:
                wait_task.update(
                    False,
                    f"[bold red]Error: {e}. Please click on the URL above to check the logs.",
                )
                # NOTE(review): the workspace's cluster is not terminated on
                # this path. Shutdown presumably relies on the compute config's
                # idle / maximum-uptime limits. Confirm this is intended, since
                # the confirmation prompt promises termination.
                return False
            wait_task.update(True, "[bold green]Workspace is active.")

        # terminate workspace
        with self._create_task(
            CloudFunctionalVerificationType.WORKSPACE, "Terminating workspace..."
        ) as terminate_workspace_task:
            try:
                # terminate the cluster leads to workspace termination
                self.anyscale_api_client.terminate_cluster(workspace.cluster_id, {})
            except ClickException as e:
                terminate_workspace_task.update(
                    False,
                    f"[bold red]Failed to terminate workspace: {e}",
                    completed=True,
                )
                return False
            terminate_workspace_task.update(
                True, "[bold green]Workspace termination initiated.", completed=True
            )
        return True

    def poll_until_active(  # noqa: PLR0913
        self,
        function: CloudFunctionalVerificationType,
        function_id: str,
        get_current_status: Callable[[str], Any],
        goal_status: Any,
        allowed_status_set: Set[Any],
        wait_task: CloudFunctionalVerificationTask,
        timeout_minutes: int,
    ) -> bool:
        """
        Polling until it is active.

        Sleeps ``POLL_INTERVAL_SECONDS`` before every check, including the
        first. Each check calls ``get_current_status(function_id)`` and shows
        the result, with a local timestamp, on ``wait_task``'s console row.

        Returns:
            True as soon as the status equals ``goal_status``.

        Raises:
            ClickException: if fetching the status fails, if the status is
                neither ``goal_status`` nor in ``allowed_status_set``, or if
                ``timeout_minutes`` elapse first.
        """
        start_time = time.time()
        end_time = start_time + timeout_minutes * 60
        while time.time() < end_time:
            time.sleep(POLL_INTERVAL_SECONDS)
            try:
                current_status = get_current_status(function_id)
            except ClickException as e:
                raise ClickException(
                    f"Failed to get {function.lower()} status: {e}"
                ) from None
            self._update_task_in_step_progress(
                function,
                wait_task.task_id,
                f"{wait_task.description} [{time.strftime('%H:%M:%S', time.localtime())}] Current status: {current_status}",
            )
            if current_status == goal_status:
                return True
            if current_status not in allowed_status_set:
                raise ClickException(
                    f"{function.capitalize()} is in an unexpected state: {current_status}"
                )
        raise ClickException(
            f"Timed out waiting for {function.lower()} to become active"
        )

    def verify(
        self,
        function: CloudFunctionalVerificationType,
        cloud_id: str,
        cloud_provider: CloudProviders,
    ) -> bool:
        """
        Kick off a single functional verification given the function type.

        Returns:
            The verification result. False for a type that has no
            verification routine.
        """
        if function == CloudFunctionalVerificationType.WORKSPACE:
            return self.verify_workspace(cloud_id, cloud_provider)
        return False

    def _update_console(
        self,
        succeeded: bool,
        function: CloudFunctionalVerificationType,
        task_id: TaskID,
        description: str,
        completed: bool = False,
    ):
        """
        Update the console based on the verification result.

        Updates the overall progress bar of ``function``, then finishes step
        row ``task_id`` with ``description``.
        """
        self._update_overall_progress(succeeded, function, completed)
        self._finish_task_in_step_progress(function, task_id, description)

    def _update_task_in_step_progress(
        self,
        function: CloudFunctionalVerificationType,
        task_id: TaskID,
        description: str,
    ) -> None:
        """
        Update the task description in step progress (the row keeps running).
        """
        self.step_progress[function].update(task_id, description=description)

    def _finish_task_in_step_progress(
        self,
        function: CloudFunctionalVerificationType,
        task_id: TaskID,
        description: str,
    ) -> None:
        """
        Finish a task in step progress and update the description.
        """
        self.step_progress[function].stop_task(task_id)
        self.step_progress[function].update(task_id, description=description)

    def _update_overall_progress(
        self,
        verification_result: bool,
        function: CloudFunctionalVerificationType,
        completed: bool = False,
    ) -> None:
        """
        Update overall progress based on the verification result.

        On success:
        - advances the bar of ``function`` by one step;
        - if ``completed``, switches its text to the success message.

        On failure, stops the bar and switches its text to the failure message.
        """
        if verification_result:
            self.overall_progress[function].advance(self.task_ids[function], 1)
            if completed:
                self.overall_progress[function].update(
                    self.task_ids[function],
                    description=f"[bold green]{function.capitalize()} verification succeeded!",
                )
        else:
            self.overall_progress[function].stop_task(self.task_ids[function])
            self.overall_progress[function].update(
                self.task_ids[function],
                description=f"[bold red]{function.capitalize()} verification failed.",
            )

    def get_live_console(
        self, functions_to_verify: List[CloudFunctionalVerificationType]
    ) -> Live:
        """
        Get a live console for cloud functional verification.

        Each functional verification contains two progress bars:
        1. step_progress: progress for each functional verification step
        2. overall_progress: progress for overall functional verification

        Also fills ``self.step_progress``, ``self.overall_progress`` and
        ``self.task_ids`` for every requested type.
        """
        progress_group = []
        # Number of console steps per verification type; must match the
        # number of ``_create_task`` steps the verification runs.
        steps = {
            CloudFunctionalVerificationType.WORKSPACE: 3,
            # TODO (congding): add service
        }
        for function in functions_to_verify:
            step_progress = Progress(
                TimeElapsedColumn(), TextColumn("{task.description}"),
            )
            self.step_progress[function] = step_progress
            progress_group.append(
                Panel(step_progress, title=f"{function.lower()} verification")
            )
            overall_progress = Progress(
                TimeElapsedColumn(), BarColumn(), TextColumn("{task.description}")
            )
            self.overall_progress[function] = overall_progress
            progress_group.append(overall_progress)
            task_id = overall_progress.add_task("", total=steps[function])
            overall_progress.update(
                task_id, description=f"[bold #AAAAAA]Verifying {function.lower()}..."
            )
            self.task_ids[function] = task_id
        return Live(Group(*progress_group))

    def start_verification(
        self,
        cloud_id: str,
        cloud_provider: CloudProviders,
        functions_to_verify: List[CloudFunctionalVerificationType],
        yes: bool = False,
    ) -> bool:
        """
        Starts cloud functional verification.

        Asks for confirmation via ``confirm``. ``yes`` is forwarded to it,
        presumably to skip the prompt. It then runs each requested
        verification in order inside a live console.

        Returns:
            True only if every requested verification succeeded.
        """
        self.log.info("Start functional verification...")
        confirm(
            f"Functional verification for {', '.join(functions_to_verify)} is about to begin.\n"
            f"It will spin up one {self.get_head_node_type(cloud_provider)} instance and will incur a small amount of cost.\n"
            "For workspace verification, it takes about 5 minutes.\n"
            "The instances will be terminated after verification. Do you want to continue?",
            yes,
        )

        verification_results: List[bool] = []
        with self.get_live_console(functions_to_verify):
            for function in functions_to_verify:
                verification_results.append(
                    self.verify(function, cloud_id, cloud_provider)
                )
        return all(verification_results)
