"""
Our autogenerated schemas aren't perfect. This file contains extensions to the autogenerated schemas for the things
the schema generator doesn't handle well.
"""
from __future__ import annotations

# Names exported by ``from <this module> import *``.
# NOTE(review): ``RequirementsParsingError`` (raised by ``RequestedRequirement.from_str``) is not listed here —
# confirm whether leaving it out of the public export list is intentional.
__all__ = [
    "Response",
    "Hyperparameter",
    "RequestedRequirement",
    "RequestedAptPackage",
    "Project",
    "ProjectImpl",
    "HasRunId",
    "HasSourceCodeId",
    "HasBlobArtifactId",
    "HasExecutionEnvironmentId",
    "HasAppSpecId",
    "AppInstanceStatus",
    "JobStatus",
    "MountRequestUnion",
    "AuthTokenUnion",
    "ExecEnvStatus",
    "HasExecutionEnvironmentSpecId",
    "Example",
    "Data",
    "Result",
    "Annotation",
    "Prediction",
    "Upsert",
    "ExampleModification",
    "Dataset",
]

import re
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Generic, Literal, Optional, Protocol, TypeVar, Union

from pydantic import BaseModel, Field, validator
from pydantic.generics import GenericModel

from .generated import schemas

# Type alias for a hyperparameter mapping: string keys to arbitrary values.
Hyperparameter = dict[str, Any]

# Type variable for the payload type carried by the generic ``Response`` model.
DataT = TypeVar("DataT")


class Response(GenericModel, Generic[DataT]):
    """
    Generic response envelope holding an optional payload of type ``DataT`` and an optional logical error.

    Both fields are optional; nothing in this class enforces that exactly one of them is set.
    """

    # The payload; its concrete type is chosen when the model is parametrized, e.g. ``Response[SomeModel]``.
    data: Optional[DataT]
    # The logical error, using the autogenerated error schema.
    error: Optional[schemas.SlingshotLogicalError]


class RequirementsParsingError(ValueError):
    """Raised when a requirement line is empty or uses syntax that is not supported."""


class RequestedRequirement(BaseModel):
    """
    A single pip-style requirement: a library name, optionally constrained by a pin operator and a version.

    ``pin`` and ``version`` must be supplied together or both omitted; this is enforced at construction time.
    """

    library: str
    version: Optional[str]
    pin: Optional[Literal["==", "@", ">=", "<=", ">", "<", "~="]]

    # Pin must be present iff version is present.
    # ``always=True`` makes the check run even when ``pin`` is omitted entirely. Without it pydantic skips validators
    # for fields that fall back to their default, so ``RequestedRequirement(library="x", version="1.0")`` would be
    # accepted even though it has a version and no pin.
    @validator("pin", always=True)
    def pin_must_be_present_iff_version_is_present(cls, v: str | None, values: dict[str, Any]) -> object:
        """
        Cross-check ``pin`` against the already-validated ``version``.

        Raises:
            ValueError: if exactly one of ``pin`` and ``version`` is set.
        """
        if "version" not in values:
            # ``version`` failed its own validation, so pydantic left it out of ``values`` and already reports that
            # error. Skip the cross-check instead of raising a bare KeyError.
            return v
        version = values["version"]
        if v is None and version is not None:
            raise ValueError("pin must be present if version is present")
        if v is not None and version is None:
            raise ValueError("version must be present if pin is present")
        return v

    @classmethod
    def from_str(cls, line: str) -> RequestedRequirement:
        """
        Parse one requirement line such as ``"numpy"``, ``"numpy>=1.24"`` or ``"pkg @ https://host/pkg.whl"``.

        The pattern is only anchored at the start of the line, so any text following the version token after
        whitespace is ignored.

        Raises:
            RequirementsParsingError: if the line is empty, starts with ``-`` (pip options such as ``-r``, ``-e``,
                ``-c``, ``--index-url``), does not match the expected pattern, or fails model validation.
        """
        if not line:
            raise RequirementsParsingError("empty requirement")
        # A single leading "-" covers both the long options (--index-url, --extra-index-url) and the short ones
        # (-r, -e, -c); they are all rejected with the same message.
        if line.startswith("-"):
            raise RequirementsParsingError(f"Unsupported requirement {line}")

        match = re.match(r"([^\s>=@<~]+) *(==|@|>=|<=|~=|>|<)? *(\S+)?", line)
        if match is None:
            raise RequirementsParsingError(f"Unsupported requirement {line}")
        # None of the three groups can contain whitespace, so no stripping is needed.
        library, pin, version = match.groups()
        try:
            return cls(library=library, version=version, pin=pin)
        except ValueError as e:
            # Validation failures raised by the model (e.g. a version without a pin) are reported as parsing errors.
            raise RequirementsParsingError(f"Unsupported requirement {line}") from e

    def as_str(self) -> str:
        """Return the requirement as ``<library><pin><version>``, or just the library name when unpinned."""
        if self.pin is None or self.version is None:
            return self.library
        return f"{self.library}{self.pin}{self.version}"

    def __str__(self) -> str:
        return self.as_str()


class Project(Protocol):
    """Structural type for any object that exposes a ``project_id`` and a ``display_name``."""

    project_id: str
    display_name: str


class ProjectImpl(BaseModel):
    """Concrete pydantic model declaring the same attributes as the ``Project`` protocol."""

    project_id: str
    display_name: str


class HasRunId(Protocol):
    """Structural type for any object that exposes a ``run_id`` string attribute."""

    run_id: str


class HasExecutionEnvironmentSpecId(Protocol):
    """Structural type for any object that exposes an ``execution_environment_spec_id`` string attribute."""

    execution_environment_spec_id: str


class HasExecutionEnvironmentId(Protocol):
    """Structural type for any object that exposes an ``execution_environment_id`` string attribute."""

    execution_environment_id: str


class HasAppSpecId(Protocol):
    """Structural type for any object that exposes an ``app_spec_id`` string attribute."""

    app_spec_id: str


class HasSourceCodeId(Protocol):
    """Structural type for any object that exposes a source code ID together with its name."""

    source_code_id: str
    source_code_name: str


class HasBlobArtifactId(Protocol):
    """Structural type for any object that exposes a blob artifact ID together with its name."""

    blob_artifact_id: str
    blob_artifact_name: str


class RequestedAptPackage(BaseModel):
    """A requested apt package, identified only by its name."""

    name: str

    def __str__(self) -> str:
        # The string form is the bare package name.
        return self.name


class AppInstanceStatus(str, Enum):
    """
    Possible statuses of an app instance.

    Members subclass ``str``, so each one compares equal to its raw string value.
    """

    STOPPED = "STOPPED"
    STARTING = "STARTING"
    READY = "READY"
    ERROR = "ERROR"


class JobStatus(str, Enum):
    """
    Possible statuses of a job.

    Members subclass ``str``, so each one compares equal to its raw string value.
    """

    NEW = "NEW"
    ACTIVE = "ACTIVE"
    SUCCESS = "SUCCESS"
    CANCELLING = "CANCELLING"
    CANCELLED = "CANCELLED"
    ERROR = "ERROR"


class ExecEnvStatus(str, Enum):
    """
    Possible statuses of an execution environment.

    Members subclass ``str``, so each one compares equal to its raw string value.
    """

    COMPILING = "COMPILING"
    FAILED = "FAILED"
    READY = "READY"


# Type alias: any one of the autogenerated mount-spec request schemas listed below.
MountRequestUnion = Union[
    schemas.UploadMountSpecRequest,
    schemas.VolumeMountSpecRequest,
    schemas.DownloadByNameMountSpecRequest,
    schemas.DownloadByTagMountSpecRequest,
]


class AuthTokenUnion(BaseModel):
    """
    An auth token owned by either a user or a service account.

    Exactly one of ``user_id`` and ``service_account_id`` must be set; this is enforced at construction time.
    """

    token: str
    user_id: Optional[str]
    service_account_id: Optional[str]

    @classmethod
    def from_auth_token(cls, auth_token: schemas.AuthToken) -> AuthTokenUnion:
        """Build the union from a user auth token; ``service_account_id`` is left unset."""
        return cls(token=auth_token.token, user_id=auth_token.user_id, service_account_id=None)

    @classmethod
    def from_service_account_token(cls, service_account_token: schemas.ServiceAccountToken) -> AuthTokenUnion:
        """Build the union from a service account token; ``user_id`` is left unset."""
        return cls(
            token=service_account_token.token, user_id=None, service_account_id=service_account_token.service_account_id
        )

    @property
    def is_service_account(self) -> bool:
        """True when the token belongs to a service account."""
        return self.service_account_id is not None

    @property
    def is_user(self) -> bool:
        """True when the token belongs to a user."""
        return self.user_id is not None

    # ``always=True`` makes the check run even when ``service_account_id`` is omitted entirely. Without it pydantic
    # skips validators for fields that fall back to their default, so ``AuthTokenUnion(token="t")`` (neither ID set)
    # would be accepted.
    @validator("service_account_id", always=True)
    def validate_xor(cls, v: str | None, values: dict[str, Any]) -> str | None:
        """
        Ensure exactly one of ``user_id`` and ``service_account_id`` is set.

        Raises:
            ValueError: if both are ``None`` or both are set.
        """
        if "user_id" not in values:
            # ``user_id`` failed its own validation, so pydantic left it out of ``values`` and already reports that
            # error. Skip the cross-check instead of raising a bare KeyError.
            return v
        user_id = values["user_id"]
        if v is None and user_id is None:
            raise ValueError("Both service_account_id and user_id cannot be None")
        if v is not None and user_id is not None:
            raise ValueError("Both service_account_id and user_id cannot be set")
        return v


# Note: these are copied from the backend schemas, but should ideally be generated from those.
class Data(BaseModel):
    """
    This can store any data: besides the optional fields declared below, arbitrary extra keys are accepted.
    E.g. if you have a question and a context, you can store them concatenated in text,
    and then split them by key here to get the individual values.
    """

    image_url: Optional[str] = Field(
        None, alias="imageUrl", description="URL to an image, e.g. a signed URL, if relevant"
    )
    image_base64: Optional[str] = Field(None, alias="imageBase64")
    path: Optional[str] = Field(None, description="Path to the file within the dataset artifact directory, if relevant")
    text: Optional[str] = None

    class Config:
        extra = "allow"  # Keep fields that are not declared on the model instead of rejecting or dropping them
        allow_population_by_field_name = True  # Accept the Python field names in addition to the camelCase aliases


class Result(BaseModel):
    """
    This is the "Y", i.e. the thing you want to predict.
    """

    # When the caller does not supply an ID, it defaults to the first 8 hex characters of a random UUID4.
    result_id: str = Field(
        default_factory=lambda: uuid.uuid4().hex[:8], alias="resultId"
    )  # UUID generated by the end-user
    task: Optional[str] = Field(
        None,
        description="The task, for example 'question answering'. Useful when there are multiple results for a single "
        "annotation, e.g. if a single example has multiple questions.",
    )
    # Any string is accepted; the Literal only documents the known built-in value.
    task_type: Union[Literal["classification"], str] = Field(
        ...,
        alias="taskType",
        description="The task type, for example classification. This can be one of the Slingshot-defined types or "
        "something custom",
    )
    # For classification, this is {[class: string]: boolean}
    value: dict[str, Any] = Field(
        ...,
        description="The content of the annotation result. For example, for classification, this would be a "
        "dictionary of class names to booleans",
    )

    class Config:
        allow_population_by_field_name = True  # Accept the Python field names in addition to the camelCase aliases
        extra = "allow"  # E.g. for a model prediction, confidence can go here


class Annotation(BaseModel):
    """
    A list of results together with an ID, creation/update timestamps and an optional annotator.
    """

    # When the caller does not supply an ID, it defaults to the first 8 hex characters of a random UUID4.
    annotation_id: str = Field(
        default_factory=lambda: uuid.uuid4().hex[:8], alias="annotationId"
    )  # UUID generated by the end-user
    result: list[Result] = Field(..., description="The annotation result")
    # NOTE: ``datetime.utcnow`` returns a naive datetime (no tzinfo). The two defaults below come from separate calls,
    # so on a freshly created instance ``created_at`` and ``updated_at`` can differ slightly.
    created_at: Optional[datetime] = Field(
        default_factory=datetime.utcnow, alias="createdAt", description="When the annotation was created"
    )
    updated_at: Optional[datetime] = Field(
        default_factory=datetime.utcnow, alias="updatedAt", description="When the annotation was last updated"
    )
    annotator: Optional[str] = Field(None, alias="annotator", description="The ID or name of the annotator")

    class Config:
        # Accept the Python field names in addition to the camelCase aliases.
        allow_population_by_field_name = True


class Prediction(Annotation):
    """
    An ``Annotation`` extended with an optional ``model`` identifier.
    """

    model: Optional[str] = Field(None, alias="model", description="The ID or name of the model")

    class Config:
        # Same setting as the parent ``Annotation.Config``: accept field names in addition to aliases.
        allow_population_by_field_name = True


class Example(BaseModel):
    """
    One dataset example: its data plus any annotations and predictions attached to it.
    """

    # When the caller does not supply an ID, it defaults to the first 8 hex characters of a random UUID4.
    example_id: str = Field(
        default_factory=lambda: uuid.uuid4().hex[:8], alias="exampleId"
    )  # UUID generated by the end-user
    # NOTE: ``datetime.utcnow`` returns a naive datetime (no tzinfo); each default comes from its own call.
    created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, alias="createdAt")
    updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, alias="updatedAt")
    data: Data
    # Both lists default to a fresh empty list per instance.
    annotations: list[Annotation] = Field(default_factory=list)
    predictions: list[Prediction] = Field(default_factory=list)

    class Config:
        # Accept the Python field names in addition to the camelCase aliases.
        allow_population_by_field_name = True


class ExampleModification(BaseModel):
    """
    A modification targeting the example with the given ``example_id``: optional replacement data plus new
    annotations and predictions. How these are applied to the stored example is not defined in this class.
    """

    example_id: str = Field(..., alias="exampleId")  # Required; no default is generated here
    modified_data: Optional[Data] = Field(None, alias="modifiedData")
    new_annotations: list[Annotation] = Field(default_factory=list, alias="newAnnotations")
    new_predictions: list[Prediction] = Field(default_factory=list, alias="newPredictions")

    class Config:
        # Accept the Python field names in addition to the camelCase aliases.
        allow_population_by_field_name = True


class Upsert(BaseModel):
    """
    Upserts are used to update the dataset: each one carries a list of new examples and a list of modifications to
    existing examples.
    """

    upsert_id: Optional[str] = Field(None, alias="upsertId")  # UUID generated by the end-user
    new_examples: list[Example] = Field(default_factory=list, alias="newExamples")
    modified_examples: list[ExampleModification] = Field(default_factory=list, alias="modifiedExamples")
    # NOTE: ``datetime.utcnow`` returns a naive datetime (no tzinfo).
    updated_at: datetime = Field(
        default_factory=datetime.utcnow,
        alias="updatedAt",
        description="The time at which the upsert was created. Defaults to now.",
    )

    class Config:
        # Accept the Python field names in addition to the camelCase aliases.
        allow_population_by_field_name = True


class Dataset(BaseModel):
    """
    Dataset is a JSONL of examples; this model holds them as a list. Reading or writing the JSONL itself is not
    handled in this class.
    """

    examples: list[Example]
