from __future__ import annotations
from typing import TYPE_CHECKING, Any
import torch

if TYPE_CHECKING: from .engine import Engine

class Callback:
    """
    回调基类。
    所有方法都接收 engine 实例，允许修改 engine 状态或读取数据。
    """
    def on_init(self, engine: Engine): ...
    
    def on_train_start(self, engine: Engine): ...
    def on_train_end(self, engine: Engine): ...
    
    def on_epoch_start(self, engine: Engine): ...
    def on_epoch_end(self, engine: Engine): ...
    
    def on_batch_start(self, engine: Engine): ...
    def on_batch_end(self, engine: Engine): ...
    
    def on_eval_start(self, engine: Engine): ...
    def on_eval_end(self, engine: Engine): ...

    def on_requested_stop(self, engine: Engine): ...
    def on_exception(self, engine: Engine): ...

class Forward:
    '''自定义前向传播和 Loss 计算接口。

    实现此接口以接管 Engine 的默认前向传播逻辑。
    '''

    def forward(self, engine: Engine, data: Any, target: Any) -> torch.Tensor:
        '''执行前向传播并返回 Loss。

        Args:
            engine (Engine): 当前 Engine 实例。可以通过 engine.model 访问模型，通过 engine.criterion 访问损失函数。
            data (Any): 当前 Batch 的输入数据。
            target (Any): 当前 Batch 的目标数据（标签）。

        Returns:
            torch.Tensor: 计算得到的 Loss 标量。Engine 将使用此 Loss 进行反向传播。
        '''
        ... # Returns loss
