from __future__ import annotations
import logging
import os

from types import ModuleType
import typing as t
from typing import TYPE_CHECKING

import attr

import bentoml
from bentoml import Tag
from bentoml._internal.models.model import MODEL_YAML_FILENAME
from bentoml.exceptions import NotFound
from bentoml.models import ModelContext
from bentoml.models import ModelOptions

if TYPE_CHECKING:
    from bentoml.types import ModelSignatureDict
    from bentoml.types import ModelSignature


# Fail early, with an actionable message, when the underlying framework package
# is not installed.
try:
    import FRAMEWORK_PY
except ImportError:  # pragma: no cover (trivial error checking)
    # The exception class is imported here because the module-level imports do
    # not bring it into scope; without this the ``raise`` below would fail with
    # ``NameError`` and hide the installation hint.
    from bentoml.exceptions import MissingDependencyException

    # NOTE(review): the message names `bentoml.lightgbm`, which looks like a
    # leftover from another framework module -- confirm the intended module name.
    raise MissingDependencyException(
        "`FRAMEWORK_PY` is required in order to use module `bentoml.lightgbm`, install "
        "FRAMEWORK_PY with `pip install FRAMEWORK_PY`. For more information, refer to "
        "<FRAMEWORK INSTALL URL>"
    )


# Module identifier passed as ``module=`` when a model is created in ``save_model``
# and compared against ``model.info.module`` in ``get`` and ``load_model``.
MODULE_NAME = "MY_PKG.MY_MODULE"
# File name resolved through ``bento_model.path_of`` when saving and loading the model.
MODEL_FILENAME = "saved_model.EXT"
# Passed as ``api_version=`` when a model is created in ``save_model``.
API_VERSION = "v1"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get(tag_like: str | Tag) -> bentoml.Model:
    """
    Look up a model via ``bentoml.models.get`` and check which module saved it.

    Args:
        tag_like (``str`` ``|`` :obj:`~bentoml.Tag`):
            The tag (or tag-like string) of the model to retrieve.
    Returns:
        :obj:`~bentoml.Model`: The model, when its recorded module is either
        ``MODULE_NAME`` or this module's ``__name__``.
    Raises:
        NotFound: If the model was saved with a different module.
    """
    found = bentoml.models.get(tag_like)
    saved_with = found.info.module

    # Accept the model as soon as its recorded module matches one of ours.
    if saved_with in (MODULE_NAME, __name__):
        return found

    raise NotFound(
        f"Model {found.tag} was saved with module {saved_with}, not loading with {MODULE_NAME}."
    )


@attr.define(frozen=True)
class FrameworkOptions(ModelOptions):
    """Frozen :obj:`ModelOptions` subclass for this framework; declares no fields of its own."""


def load_model(
    bento_model: str | Tag | bentoml.Model,
    # *,
    # ...
) -> FrameworkModelType:
    """
    Load the <FRAMEWORK> model with the given tag from the local BentoML model store.

    Args:
        bento_model (``str`` ``|`` :obj:`~bentoml.Tag` ``|`` :obj:`~bentoml.Model`):
            Either the tag of the model to get from the store, or a BentoML `~bentoml.Model`
            instance to load the model from.
        ...
    Returns:
        <MODELTYPE>:
            The <FRAMEWORK> model loaded from the model store or BentoML :obj:`~bentoml.Model`.
    Raises:
        NotFound:
            If the model was saved with a module other than ``MODULE_NAME`` or this module.
    Example:
    .. code-block:: python
        import bentoml
        <LOAD EXAMPLE>
    """  # noqa
    # Resolve tags and tag-like strings to a stored model first.
    if not isinstance(bento_model, bentoml.Model):
        bento_model = get(bento_model)

    # A ``bentoml.Model`` passed in directly has not gone through ``get``, so the
    # module it was saved with still has to be checked here.
    if bento_model.info.module not in (MODULE_NAME, __name__):
        raise NotFound(
            f"Model {bento_model.tag} was saved with module {bento_model.info.module}, not loading with {MODULE_NAME}."
        )

    # Return the loaded object: callers such as the runnable built by
    # ``get_runnable`` look up prediction methods on it.
    return FRAMEWORK_PY.load(bento_model.path_of(MODEL_FILENAME))


def save_model(
    name: str,
    model: FrameworkModelType,
    *,
    signatures: dict[str, ModelSignatureDict | ModelSignature] | None = None,
    labels: dict[str, str] | None = None,
    custom_objects: dict[str, t.Any] | None = None,
    external_modules: list[ModuleType] | None = None,
    metadata: dict[str, t.Any] | None = None,
    # ...
) -> bentoml.Model:
    """
    Save a <FRAMEWORK> model instance to the BentoML model store.

    Args:
        name (``str``):
            The name to give to the model in the BentoML store. This must be a valid
            :obj:`~bentoml.Tag` name.
        model (<MODELTYPE>):
            The <FRAMEWORK> model to be saved.
        signatures (``dict[str, ModelSignatureDict | ModelSignature]``, optional):
            Signatures of predict methods to be used. If not provided, the signatures default to
            <DEFAULT HERE>. See :obj:`~bentoml.types.ModelSignature` for more details.
        labels (``dict[str, str]``, optional):
            A default set of management labels to be associated with the model. An example is
            ``{"training-set": "data-1"}``.
        custom_objects (``dict[str, Any]``, optional):
            Custom objects to be saved with the model. An example is
            ``{"my-normalizer": normalizer}``.

            Custom objects are currently serialized with cloudpickle, but this implementation is
            subject to change.
        external_modules (:code:`List[ModuleType]`, `optional`, default to :code:`None`):
            user-defined additional python modules to be saved alongside the model or custom objects,
            e.g. a tokenizer module, preprocessor module, model configuration module
        metadata (``dict[str, Any]``, optional):
            Metadata to be associated with the model. An example is ``{"bias": 4}``.

            Metadata is intended for display in a model management UI and therefore must be a
            default Python type, such as ``str`` or ``int``.

        ...
    Returns:
        :obj:`~bentoml.Model`: The model created in the BentoML model store; the saved
        model file is written to ``MODEL_FILENAME`` inside it.
    Raises:
        TypeError:
            If ``model`` is not an instance of the framework's model type.
    Example:
    .. code-block:: python
        <SAVE EXAMPLE>
    """
    # Reject an unsupported model object before doing any other work.
    # NOTE(review): ``FrameworkModelType`` is not defined in the visible part of
    # this file -- confirm it is provided when the template is filled in.
    if not isinstance(model, FrameworkModelType):
        raise TypeError(f"Given model ({model}) is not a FrameworkModelType.")

    context = ModelContext(
        framework_name="FRAMEWORK_PY",
        framework_versions={"FRAMEWORK_PY": FRAMEWORK_PY.__version__},
    )

    if signatures is None:
        # NOTE(review): ``DEFAULT_MODEL_METHOD`` is not defined in the visible part
        # of this file -- confirm it is provided when the template is filled in.
        signatures = {
            DEFAULT_MODEL_METHOD: {"batchable": False},
        }
        # Lazy %-style arguments: the message is only formatted if INFO is enabled.
        logger.info(
            "Using the default model signature for <FRAMEWORK> (%s) for model %s.",
            signatures,
            name,
        )

    options = FrameworkOptions()

    with bentoml.models.create(
        name,
        module=MODULE_NAME,
        api_version=API_VERSION,
        signatures=signatures,
        labels=labels,
        options=options,
        custom_objects=custom_objects,
        external_modules=external_modules,
        metadata=metadata,
        context=context,
    ) as bento_model:
        # Write the framework model into the location managed by the new store entry.
        model.save(bento_model.path_of(MODEL_FILENAME))

        return bento_model


def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
    """
    Private API: use :obj:`~bentoml.Model.to_runnable` instead.

    Builds a :obj:`~bentoml.Runnable` subclass bound to ``bento_model`` and registers
    one runnable method for every entry in ``bento_model.info.signatures``.
    """

    class FrameworkRunnable(bentoml.Runnable):
        SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
        SUPPORTS_CPU_MULTI_THREADING = True  # type: ignore

        def __init__(self):
            super().__init__()
            # check for resources
            self.model = load_model(
                bento_model,
                # bento_model.info.options.*
            )

            visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
            if visible_devices is None or visible_devices in ("", "-1"):
                # assign CPU resources
                ...
            else:
                # assign GPU resources
                ...

            # Map each signature name to the same-named attribute of the loaded model.
            self.predict_fns: dict[str, t.Callable[..., t.Any]] = {
                sig_name: getattr(self.model, sig_name)
                for sig_name in bento_model.info.signatures
            }

    def add_runnable_method(method_name: str, signature: ModelSignature):
        # Template placeholder: the actual inference call goes here.
        def _run(self, input_data) -> OutputType:
            ...

        FrameworkRunnable.add_method(
            _run,
            name=method_name,
            batchable=signature.batchable,
            batch_dim=signature.batch_dim,
            input_spec=signature.input_spec,
            output_spec=signature.output_spec,
        )

    # Register one runnable method per saved signature.
    for sig_name, sig in bento_model.info.signatures.items():
        add_runnable_method(sig_name, sig)

    return FrameworkRunnable
