from typing import Optional, Union
from enum import Enum

from vocode.models.message import BaseMessage
from .model import TypedModel, BaseModel

# Default for FillerAudioConfig.silence_threshold_seconds (in seconds).
FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS: float = 0.5
# Sampling defaults shared by LLMAgentConfig, ChatGPTAlphaAgentConfig and
# ChatGPTAgentConfig in this module.
LLM_AGENT_DEFAULT_TEMPERATURE: float = 1.0
LLM_AGENT_DEFAULT_MAX_TOKENS: int = 256
# Default for LLMAgentConfig.model_name; the ChatGPT configs in this module do
# not declare a model_name field.
# NOTE(review): "text-curie-001" is a legacy OpenAI completion model — confirm
# it is still available before relying on this default.
LLM_AGENT_DEFAULT_MODEL_NAME: str = "text-curie-001"


class AgentType(str, Enum):
    """Type tags for the agent configuration classes in this module.

    Each ``AgentConfig`` subclass passes one of these members as its ``type=``
    class keyword. Because the enum mixes in ``str``, every member compares
    equal to its raw string value (e.g. ``AgentType.ECHO == "agent_echo"``).
    """

    BASE = "agent_base"  # used by the AgentConfig base class
    LLM = "agent_llm"  # LLMAgentConfig
    CHAT_GPT_ALPHA = "agent_chat_gpt_alpha"  # ChatGPTAlphaAgentConfig
    CHAT_GPT = "agent_chat_gpt"  # ChatGPTAgentConfig
    ECHO = "agent_echo"  # EchoAgentConfig
    INFORMATION_RETRIEVAL = "agent_information_retrieval"  # InformationRetrievalAgentConfig
    RESTFUL_USER_IMPLEMENTED = "agent_restful_user_implemented"  # RESTfulUserImplementedAgentConfig
    WEBSOCKET_USER_IMPLEMENTED = "agent_websocket_user_implemented"  # WebSocketUserImplementedAgentConfig


class FillerAudioConfig(BaseModel):
    """Filler-audio settings.

    Accepted by ``AgentConfig.send_filler_audio`` as an alternative to a
    plain bool.
    """

    # Threshold in seconds; defaults to FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS.
    # NOTE(review): presumably the silence duration before filler audio is
    # played — confirm against the code that consumes this config.
    silence_threshold_seconds: float = FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS


class AgentConfig(TypedModel, type=AgentType.BASE):
    """Base configuration shared by every agent config in this module.

    Tagged with ``AgentType.BASE`` through the ``type=`` class keyword
    handled by ``TypedModel``; each subclass passes its own ``AgentType``.
    """

    # Optional message for the agent to start with.
    # NOTE(review): presumably emitted before any human input — confirm.
    initial_message: Optional[BaseMessage] = None
    # True by default; ChatGPTAgentConfig and both user-implemented configs
    # in this module override the default to False.
    generate_responses: bool = True
    # Optional idle timeout in seconds; None by default.
    # NOTE(review): presumably None disables the idle timeout — confirm.
    allowed_idle_time_seconds: Optional[float] = None
    end_conversation_on_goodbye: bool = False
    # Either a bool flag or a FillerAudioConfig carrying custom settings.
    send_filler_audio: Union[bool, FillerAudioConfig] = False


class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
    """Configuration for the LLM agent (``AgentType.LLM``).

    Unlike the ChatGPT configs in this module, this one exposes ``model_name``.
    """

    # Required prompt text. NOTE(review): presumably prepended to the
    # conversation when prompting the model — confirm.
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    model_name: str = LLM_AGENT_DEFAULT_MODEL_NAME
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS


class ChatGPTAlphaAgentConfig(AgentConfig, type=AgentType.CHAT_GPT_ALPHA):
    """Configuration for the ``AgentType.CHAT_GPT_ALPHA`` agent.

    Declares the same prompt and sampling fields as LLMAgentConfig, but has
    no ``model_name`` field.
    """

    # Required prompt text. NOTE(review): presumably used as the system /
    # preamble prompt — confirm.
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS


class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT):
    """Configuration for the ``AgentType.CHAT_GPT`` agent.

    Same prompt and sampling fields as ChatGPTAlphaAgentConfig, but overrides
    the inherited ``generate_responses`` default from True to False.
    """

    # Required prompt text. NOTE(review): presumably used as the system /
    # preamble prompt — confirm.
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    # Overrides AgentConfig's default of True.
    generate_responses: bool = False
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS


class InformationRetrievalAgentConfig(
    AgentConfig, type=AgentType.INFORMATION_RETRIEVAL
):
    """Configuration for the ``AgentType.INFORMATION_RETRIEVAL`` agent.

    All four fields are required. Their descriptions below are inferred from
    the field names — NOTE(review): confirm against the agent implementation.
    """

    # Presumably a natural-language description of who is being called.
    recipient_descriptor: str
    # Presumably a natural-language description of who the agent represents.
    caller_descriptor: str
    # Presumably what the call is meant to accomplish.
    goal_description: str
    # Names of the pieces of information to collect (a list of strings).
    fields: list[str]
    # TODO: add fields for IVR and voicemail handling.


class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
    """Configuration for the echo agent; declares no fields beyond AgentConfig.

    Tagged with ``AgentType.ECHO`` through the ``type=`` class keyword.
    """


class RESTfulUserImplementedAgentConfig(
    AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED
):
    """Configuration for an agent whose logic lives behind a user-provided HTTP endpoint."""

    class EndpointConfig(BaseModel):
        """Location of one HTTP endpoint."""

        url: str
        # HTTP method as a free-form string (not validated here); defaults to "POST".
        method: str = "POST"

    # Required endpoint for responses. NOTE(review): presumably called with a
    # RESTfulAgentInput body and expected to return a RESTfulAgentOutput — confirm.
    respond: EndpointConfig
    # Overrides AgentConfig's default of True.
    generate_responses: bool = False
    # Not yet supported; possible future endpoints:
    # generate_response: Optional[EndpointConfig]
    # update_last_bot_message_on_cut_off: Optional[EndpointConfig]


class RESTfulAgentInput(BaseModel):
    """Input model for a RESTful user-implemented agent.

    NOTE(review): presumably the request body sent to
    ``RESTfulUserImplementedAgentConfig.respond`` — confirm in the client.
    """

    conversation_id: str
    # The human's utterance, as text.
    human_input: str


class RESTfulAgentOutputType(str, Enum):
    """Type tags for the RESTfulAgentOutput subclasses in this module.

    Members mix in ``str`` and so compare equal to their raw string values.
    """

    BASE = "restful_agent_base"  # RESTfulAgentOutput
    TEXT = "restful_agent_text"  # RESTfulAgentText
    END = "restful_agent_end"  # RESTfulAgentEnd


class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE):
    """Base model for replies from a RESTful user-implemented agent.

    Declares no fields; concrete replies subclass it with their own
    ``RESTfulAgentOutputType`` tag.
    """


class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT):
    """RESTful agent reply carrying a text response."""

    # Required response text. NOTE(review): presumably what the bot says next — confirm.
    response: str


class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END):
    """Field-less RESTful agent reply tagged ``RESTfulAgentOutputType.END``.

    NOTE(review): presumably signals that the conversation should end — confirm.
    """


class WebSocketUserImplementedAgentConfig(
    AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED
):
    """Configuration for an agent whose logic lives behind a user-provided WebSocket route."""

    class RouteConfig(BaseModel):
        """Location of one WebSocket route."""

        url: str

    # Required route. NOTE(review): presumably the socket over which
    # WebSocketAgentMessage objects are exchanged — confirm.
    respond: RouteConfig
    # Overrides AgentConfig's default of True.
    generate_responses: bool = False
    # Not yet supported; possible future options:
    # generate_response: Optional[RouteConfig]
    # send_message_on_cut_off: bool = False


class WebSocketAgentMessageType(str, Enum):
    """Type tags for the WebSocketAgentMessage subclasses in this module.

    Members mix in ``str`` and so compare equal to their raw string values.
    """

    BASE = "websocket_agent_base"  # WebSocketAgentMessage
    START = "websocket_agent_start"  # WebSocketAgentStartMessage
    TEXT = "websocket_agent_text"  # WebSocketAgentTextMessage
    READY = "websocket_agent_ready"  # WebSocketAgentReadyMessage
    STOP = "websocket_agent_stop"  # WebSocketAgentStopMessage


class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
    """Base model for messages exchanged with a WebSocket user-implemented agent."""

    # Optional conversation identifier; None when not provided.
    conversation_id: Optional[str] = None


class WebSocketAgentTextMessage(
    WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT
):
    """WebSocket agent message whose ``data`` payload carries one text string."""

    class Payload(BaseModel):
        """Body of a text message."""

        text: str

    data: Payload

    @classmethod
    def from_text(cls, text: str, conversation_id: Optional[str] = None):
        """Build a text message wrapping ``text``.

        Args:
            text: Value stored in ``data.text``.
            conversation_id: Optional conversation identifier for the message.

        Returns:
            A new ``cls`` instance whose payload holds ``text``.
        """
        payload = cls.Payload(text=text)
        return cls(conversation_id=conversation_id, data=payload)


class WebSocketAgentStartMessage(
    WebSocketAgentMessage, type=WebSocketAgentMessageType.START
):
    """Field-less WebSocket agent message tagged ``WebSocketAgentMessageType.START``."""


class WebSocketAgentReadyMessage(
    WebSocketAgentMessage, type=WebSocketAgentMessageType.READY
):
    """Field-less WebSocket agent message tagged ``WebSocketAgentMessageType.READY``."""


class WebSocketAgentStopMessage(
    WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
):
    """Field-less WebSocket agent message tagged ``WebSocketAgentMessageType.STOP``."""
