from __future__ import annotations

import asyncio
import json
import math
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any

import typer

from slingshot import schemas
from slingshot.cli.shared import prompt_confirm
from slingshot.schemas import slingshot_schema as slingshot_yaml_schemas
from slingshot.sdk.errors import SlingshotException
from slingshot.sdk.slingshot_api import JSONType
from slingshot.sdk.utils import yaml
from slingshot.shared.utils import pydantic_to_dict

if TYPE_CHECKING:
    from slingshot.sdk.slingshot_sdk import SlingshotSDK


def format_logline(logline: schemas.LogLine) -> str:
    """Render a single log line for display; this is currently just its raw ``log`` text."""
    text = logline.log
    return text


def merge_loglines(loglines: list[schemas.LogLine]) -> str:
    """Format every log line with ``format_logline`` and join them, one per line."""
    rendered = [format_logline(line) for line in loglines]
    return "\n".join(rendered)


async def follow_run_logs_until_done(sdk: SlingshotSDK, run_id: str) -> schemas.JobStatus:
    task_logs = asyncio.create_task(sdk.print_logs(run_id=run_id, follow=True))
    task_status = asyncio.create_task(_wait_for_run_finished(sdk=sdk, run_id=run_id))
    done, pending = await asyncio.wait([task_logs, task_status], return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()

    # If task_status was cancelled or raised an exception, handle it
    if task_status.cancelled():
        raise SlingshotException("Something went wrong following logs")
    elif (exception := task_status.exception()) is not None:
        raise exception

    return task_status.result()


async def follow_app_logs_until_ready(sdk: SlingshotSDK, app_spec_id: str) -> schemas.AppInstanceStatus:
    task_logs = asyncio.create_task(sdk.print_logs(app_spec_id=app_spec_id, follow=True))
    task_status = asyncio.create_task(_wait_for_app_ready(sdk=sdk, app_spec_id=app_spec_id))
    done, pending = await asyncio.wait([task_logs, task_status], return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()

    # If task_status was cancelled or raised an exception, handle it
    if task_status.cancelled():
        raise SlingshotException("Something went wrong following logs")
    elif (exception := task_status.exception()) is not None:
        raise exception

    return task_status.result()


def datetime_to_human_readable(dt: datetime) -> str:
    """Format a datetime as a UTC timestamp, e.g. ``Mar 02, 2021 01:30:05 AM UTC``.

    Naive datetimes are assumed to already be in UTC and are formatted as-is. Timezone-aware
    datetimes are converted to UTC first, so the ``UTC`` suffix is always accurate.
    """
    # utcoffset() is None exactly when dt is naive (or its tzinfo gives no offset).
    if dt.utcoffset() is not None:
        from datetime import timezone

        dt = dt.astimezone(timezone.utc)
    date_str = dt.strftime('%b %d, %Y %I:%M:%S %p')
    return f'{date_str} UTC'


def seconds_to_human_readable(seconds: float) -> str:
    """Format a duration in seconds, e.g. ``"45s"``, ``"10m 30s"``, ``"2h 5m 0s"``, ``"1d 3h 0s"``.

    The leading (largest) unit and the seconds are always shown. Intermediate units that are
    zero are omitted (``3605`` -> ``"1h 5s"``). The value is rounded to whole seconds.
    """
    if not math.isfinite(seconds):
        # round() rejects NaN/inf; render them plainly instead of raising.
        return f"{seconds:.0f}s"
    # Round once, up front. Rounding only the final remainder produced outputs like "60s"
    # (for 59.6) and "59m 60s" (for 3599.6). Working with an int also keeps "1d", not "1.0d".
    whole = round(seconds)
    if whole < 60:
        return f"{whole}s"
    if whole < 3600:
        # In the format 10m 30s
        return f"{whole // 60}m {seconds_to_human_readable(whole % 60)}"
    if whole < 86400:
        return f"{whole // 3600}h {seconds_to_human_readable(whole % 3600)}"
    return f"{whole // 86400}d {seconds_to_human_readable(whole % 86400)}"


def bytes_to_human_readable_size(size: int | None, precision: int = 1) -> str:
    """Format a byte count using 1024-based units, e.g. ``1536`` -> ``"1.5 KB"``.

    :param size: number of bytes; ``None`` yields an empty string.
    :param precision: number of digits after the decimal point.
    :returns: the formatted size, or ``"0 Bytes"`` for zero.
    """
    if size is None:
        return ""

    if size == 0:
        return "0 Bytes"
    k = 1024
    sizes = ["Bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]

    # Pick the unit with exact comparisons instead of math.log, which:
    #  - raises for negative sizes,
    #  - picks "YB" (index -1) for sizes below one byte, and
    #  - may land just under an integer at exact powers of k.
    # Sizes beyond the largest unit stay in "YB" rather than indexing past the list.
    magnitude = abs(size)
    i = 0
    while i < len(sizes) - 1 and magnitude >= k ** (i + 1):
        i += 1
    return f"{size / k ** i:.{precision}f} {sizes[i]}"


def parse_extra_args(extra_args: list[str]) -> JSONType:
    """Turn ``["--key", "value", "--flag"]``-style CLI args into a dict.

    A ``--key`` followed by a value maps to that value, parsed via ``_infer_type``. A ``--key``
    followed by another ``--`` arg (or by nothing) maps to ``True``. Any value that is not
    preceded by a ``--key`` is rejected with ``typer.BadParameter``.
    """
    # Pad the front with "" so a leading value is detected as loose. Pad the back with "--" so a
    # trailing key is treated as a flag.
    padded = ["", *extra_args, "--"]
    pairs = list(zip(padded, padded[1:]))

    loose_args = [
        current for previous, current in pairs if not (current.startswith("--") or previous.startswith("--"))
    ]
    if loose_args:
        raise typer.BadParameter(f"Extra args must start with -- . Got loose values {loose_args}")

    parsed = {}
    for key, value in pairs:
        if not key.startswith("--"):
            continue
        parsed[key[2:]] = True if value.startswith("--") else _infer_type(value)
    return parsed


async def prompt_push_code(sdk: SlingshotSDK) -> str | None:
    """Offer to push code when ``sdk.has_code_changed()`` reports a change.

    The user is only prompted when a change is reported. Returns the ``source_code_id`` of the
    push if the user confirmed; otherwise returns ``None``.
    """
    if await sdk.has_code_changed():
        wants_push = prompt_confirm("Code has changed since last push. Do you want to push now?", default=True)
        if wants_push:
            pushed = await sdk.push_code(".", and_print=True)
            return pushed.source_code_id
    return None


async def get_hyperparameter_config_from_file(config_file: Path) -> JSONType:
    """Load a hyperparameter config from a JSON file.

    :param config_file: path to the JSON config file.
    :returns: the parsed JSON document.
    :raises SlingshotException: if ``config_file`` does not exist or is not a file.
    :raises json.JSONDecodeError: if the file's contents are not valid JSON.
    """
    if not config_file.is_file():
        raise SlingshotException(
            f"Config file {config_file.name} could not be found or is not a file ({config_file.absolute()})"
        )
    # JSON text is UTF-8 (RFC 8259). Pin the encoding instead of relying on the platform's
    # locale default, which breaks non-ASCII configs on e.g. Windows.
    with open(config_file, "r", encoding="utf-8") as f:
        return json.load(f)


def _infer_type(value: str) -> Any:
    # noinspection GrazieInspection
    """Tries to infer the type of a string. If it can't, just returns the string."""
    try:
        return json.loads(value)
    except ValueError:
        return value


async def _wait_for_run_finished(sdk: SlingshotSDK, run_id: str) -> schemas.JobStatus:
    """Consume ``sdk.api.follow_run_status`` until it yields CANCELLED, ERROR or SUCCESS; return that status.

    :raises SlingshotException: if the status stream is exhausted without yielding one of those
        statuses. The original ``AssertionError("Unreachable")`` was in fact reachable here.
    """
    terminal_statuses = (schemas.JobStatus.CANCELLED, schemas.JobStatus.ERROR, schemas.JobStatus.SUCCESS)
    async for status in sdk.api.follow_run_status(run_id):
        if status in terminal_statuses:
            return status
    raise SlingshotException(f"Status updates for run {run_id} ended before the run finished")


async def _wait_for_app_ready(sdk: SlingshotSDK, app_spec_id: str) -> schemas.AppInstanceStatus:
    """Consume ``sdk.api.follow_app_status`` until it yields READY, ERROR or STOPPED; return that status.

    :raises SlingshotException: if the status stream is exhausted without yielding one of those
        statuses. The original ``AssertionError("Unreachable")`` was in fact reachable here.
    """
    settled_statuses = (
        schemas.AppInstanceStatus.READY,
        schemas.AppInstanceStatus.ERROR,
        schemas.AppInstanceStatus.STOPPED,
    )
    async for status in sdk.api.follow_app_status(app_spec_id):
        if status in settled_statuses:
            return status
    raise SlingshotException(f"Status updates for app {app_spec_id} ended before the app was ready")


def create_empty_project_manifest(manifest_path: Path) -> None:
    """Write a default-constructed ``ProjectManifest`` to ``manifest_path`` as YAML.

    All fields are written, including ones left at their defaults (``exclude_unset=False``).
    """
    manifest_path.touch()

    empty_manifest = slingshot_yaml_schemas.ProjectManifest()
    manifest_dict = pydantic_to_dict(empty_manifest, exclude_unset=False)
    # NOTE(review): this calls indent() on the imported `yaml` object, so the setting presumably
    # persists for any other code using that object — confirm this is intended.
    yaml.indent(mapping=2, sequence=4, offset=2)
    with open(manifest_path, "w") as handle:
        yaml.dump(manifest_dict, handle)
