from __future__ import annotations

from typing import Optional

from pydantic import BaseModel, Field

from slingshot import schemas

from ..base_graphql import BaseGraphQLQuery
from ..fragments import Run


class RunsForProjectResponse(BaseModel):
    """Response model for queries that select a ``trainingRuns`` list of ``Run`` fragments."""

    # Populated from the ``trainingRuns`` key of the GraphQL result; ``...`` makes the key required.
    runs: list[Run] = Field(..., alias="trainingRuns")


class RunsForProjectQuery(BaseGraphQLQuery[RunsForProjectResponse]):
    """Fetch every training run belonging to a project, newest first.

    ``trainingRuns`` is filtered on ``project.projectId`` and ordered by
    ``createdAt`` descending. Each run is selected through the ``Run`` fragment,
    and the result is parsed into ``RunsForProjectResponse``.
    """

    _query = """
        query RunsForProject($projectId: String!) {
            trainingRuns(where: {project: {projectId: {_eq: $projectId}}}, orderBy: {createdAt: DESC}) {
                ...Run
            }
        } """

    # Fragments referenced by ``_query`` (``...Run``). Presumably the base class
    # appends their definitions to the query text; confirm in BaseGraphQLQuery.
    _depends_on = [Run]

    def __init__(self, project_id: str) -> None:
        """Bind the query to a single project.

        Args:
            project_id: Sent as the ``$projectId`` GraphQL variable.
        """
        super().__init__(variables={"projectId": project_id}, response_model=RunsForProjectResponse)


class RunByNameForProjectQuery(BaseGraphQLQuery[RunsForProjectResponse]):
    """Fetch the training runs in a project whose ``trainingRunName`` equals a given name.

    This is a filter on ``trainingRuns``, not a primary-key lookup, so the parsed
    ``RunsForProjectResponse.runs`` is a list. It presumably holds zero or one
    element if run names are unique per project; confirm against the schema.
    """

    _query = """
        query RunByNameForProject($projectId: String!, $runName: String!) {
            trainingRuns(where: {project: {projectId: {_eq: $projectId}}, trainingRunName: {_eq: $runName}}) {
                ...Run
            }
        } """

    # Fragments referenced by ``_query`` (``...Run``).
    _depends_on = [Run]

    def __init__(self, project_id: str, run_name: str) -> None:
        """Bind the query to a project and a run name.

        Args:
            project_id: Sent as the ``$projectId`` GraphQL variable.
            run_name: Sent as the ``$runName`` GraphQL variable (exact ``_eq`` match).
        """
        super().__init__(
            variables={"projectId": project_id, "runName": run_name}, response_model=RunsForProjectResponse
        )


class RunByIdResponse(BaseModel):
    """Response model for ``RunByIdQuery`` (a ``trainingRunsByPk`` lookup)."""

    # The ``trainingRunsByPk`` key must be present (``...`` marks it required), but its
    # value may be null. Presumably null means no run has the requested ID; confirm with the API.
    run: Optional[Run] = Field(..., alias="trainingRunsByPk")


class RunByIdQuery(BaseGraphQLQuery[RunByIdResponse]):
    """Look up one training run by its primary key, ``trainingRunId``.

    The selected fields come from the ``Run`` fragment. The result is parsed
    into ``RunByIdResponse``, whose ``run`` attribute may be ``None``.
    """

    _query = """
        query RunById($runId: String!) {
          trainingRunsByPk(trainingRunId: $runId) {
            ...Run
          }
        } """

    # The query spreads ``...Run``, so that fragment is declared as a dependency.
    _depends_on = [Run]

    def __init__(self, run_id: str) -> None:
        """Bind the lookup to a specific run.

        Args:
            run_id: Sent as the ``$runId`` GraphQL variable.
        """
        query_variables = {"runId": run_id}
        super().__init__(variables=query_variables, response_model=RunByIdResponse)


class RunWithStatus(BaseModel):
    """Partial training-run record that carries only the run's job status."""

    # Parsed from the camelCase ``jobStatus`` field into ``schemas.JobStatus``; the key is required.
    job_status: schemas.JobStatus = Field(..., alias="jobStatus")


class RunsWithStatusResponse(BaseModel):
    """Response model for ``RunStatusSubscription``.

    Despite the plural class name, this holds a single (possibly null) run
    taken from ``trainingRunsByPk``.
    """

    # The key is required (``...``); its value may be null. Presumably null means
    # no run matches the ID; confirm with the API.
    run: Optional[RunWithStatus] = Field(..., alias="trainingRunsByPk")


class RunStatusSubscription(BaseGraphQLQuery[RunsWithStatusResponse]):
    """GraphQL subscription on the ``jobStatus`` of a single training run.

    It selects only ``jobStatus`` from ``trainingRunsByPk``, and each payload is
    parsed into ``RunsWithStatusResponse``. How the subscription is transported
    and iterated is up to ``BaseGraphQLQuery`` and its callers.
    """

    _query = """
        subscription RunStatusSubscription($runId: String!) {
          trainingRunsByPk(trainingRunId: $runId) {
            jobStatus
          }
        }
    """
    # No fragments are spread in ``_query``, so there are no dependencies.
    _depends_on = []

    def __init__(self, run_id: str) -> None:
        """Bind the subscription to a specific run.

        Args:
            run_id: Sent as the ``$runId`` GraphQL variable.
        """
        super().__init__(variables={"runId": run_id}, response_model=RunsWithStatusResponse)
