import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Callable

import anomed_utils as utils
import falcon
import numpy as np
import requests

from . import deanonymizer


class EvaluationResource:
    """Falcon resource which evaluates a fitted membership inference attack.

    Handling a ``POST`` request (re)loads the fitted attack from disk, downloads
    the evaluation arrays of the requested data split, infers memberships for
    them and submits the inferred memberships to a utility evaluation service.
    """

    def __init__(
        self,
        anonymizer_identifier: str,
        deanonymizer_identifier: str,
        model_filepath: str | Path,
        model_loader: Callable[[str | Path], deanonymizer.SupervisedLearningMIA],
        default_batch_size: int,
        evaluation_data_url: str,
        utility_evaluation_url: str,
    ) -> None:
        """
        Parameters
        ----------
        anonymizer_identifier : str
            Identifier of the attacked anonymizer; forwarded to the utility
            evaluation service as the ``anonymizer`` query parameter.
        deanonymizer_identifier : str
            Identifier of this attack; forwarded to the utility evaluation
            service as the ``deanonymizer`` query parameter.
        model_filepath : str | Path
            Where the fitted attack is expected on disk.
        model_loader : Callable[[str | Path], deanonymizer.SupervisedLearningMIA]
            Loads a fitted attack from ``model_filepath``.
        default_batch_size : int
            Batch size passed to ``infer_memberships``.
        evaluation_data_url : str
            Where to download the ``X``/``y`` evaluation arrays from.
        utility_evaluation_url : str
            Where to submit the inferred memberships for evaluation.
        """
        self._anon_id = anonymizer_identifier
        self._deanon_id = deanonymizer_identifier
        self._model_filepath = Path(model_filepath)
        self._load_model = model_loader
        self._default_batch_size = default_batch_size
        self._evaluation_data_url = evaluation_data_url
        self._utility_evaluation_url = utility_evaluation_url
        # Timeout (in seconds) used for every outgoing HTTP request.
        self._timeout = 10.0
        # Both stay None until the first successful model load.
        self._loaded_model: deanonymizer.SupervisedLearningMIA = None  # type: ignore
        self._loaded_model_modification_time: datetime = None  # type: ignore
        self._expected_array_labels = ["X", "y"]

    def on_post(self, req: falcon.Request, resp: falcon.Response) -> None:
        """Evaluate the most recently fitted attack on the requested data split.

        Raises
        ------
        falcon.HTTPBadRequest
            If no fitted attack exists on disk yet, or if the ``data_split``
            query parameter is missing or neither ``"tuning"`` nor
            ``"validation"``.
        falcon.HTTPInternalServerError
            If the request to the utility evaluation service fails or is not
            answered with status 201.
        """
        self._load_most_recent_model()
        data_split = self._get_data_split(req)

        arrays = utils.get_named_arrays_or_raise(
            self._evaluation_data_url,
            expected_array_labels=self._expected_array_labels,
            params=dict(data_split=data_split),
            timeout=self._timeout,
        )
        memberships = self._loaded_model.infer_memberships(
            X=arrays[self._expected_array_labels[0]],
            y=arrays[self._expected_array_labels[1]],
            batch_size=self._default_batch_size,
        )

        payload = utils.named_ndarrays_to_bytes(dict(prediction=memberships))
        error_msg = "Utility evaluation failed."
        try:
            evaluation_response = requests.post(
                url=self._utility_evaluation_url,
                data=payload,
                params=dict(
                    anonymizer=self._anon_id,
                    deanonymizer=self._deanon_id,
                    data_split=data_split,
                ),
                # Without a timeout a hung evaluation service would block
                # this worker indefinitely.
                timeout=self._timeout,
            )
        except requests.RequestException as err:
            logging.exception(error_msg)
            raise falcon.HTTPInternalServerError(description=error_msg) from err

        if evaluation_response.status_code != 201:
            # Not inside an `except` block, so `logging.error` (not
            # `logging.exception`) is the right call here.
            logging.error(
                "%s Status code: %s", error_msg, evaluation_response.status_code
            )
            raise falcon.HTTPInternalServerError(description=error_msg)

        resp.text = json.dumps(
            dict(
                message=(
                    f"The deanonymizer has been evaluated based on {data_split} data."
                ),
                evaluation=evaluation_response.json(),
            )
        )
        resp.status = falcon.HTTP_CREATED

    @staticmethod
    def _get_data_split(req: falcon.Request) -> str:
        """Read the ``data_split`` query parameter of `req` and validate it.

        Raises
        ------
        falcon.HTTPBadRequest
            If the parameter is missing or neither ``"tuning"`` nor
            ``"validation"``.
        """
        try:
            data_split = req.get_param("data_split", required=True)
        except falcon.HTTPBadRequest as err:
            error_msg = "Query parameter 'data_split' is missing!"
            logging.exception(error_msg)
            raise falcon.HTTPBadRequest(description=error_msg) from err

        if data_split not in ("tuning", "validation"):
            # Parenthesized so both literals form ONE message; previously the
            # second half was a separate, discarded statement.
            error_msg = (
                "Query parameter data_split needs to be either 'tuning' or "
                "'validation'."
            )
            logging.error(error_msg)
            raise falcon.HTTPBadRequest(description=error_msg)
        return data_split

    def _load_most_recent_model(self) -> None:
        """Ensure `self._loaded_model` reflects the attack stored on disk.

        The model is (re)loaded via the configured model loader whenever
        `_is_older` reports the cached modification time to be older than (or
        equal to) the file's current modification time.

        Raises
        ------
        falcon.HTTPBadRequest
            If there is no file at the model path, i.e. the attack has not been
            fitted yet.
        """
        if not self._model_filepath.exists():
            error_msg = "This deanonymizer is not fitted yet."
            logging.error(error_msg)
            raise falcon.HTTPBadRequest(description=error_msg)

        mod_time_from_disk = datetime.fromtimestamp(
            self._model_filepath.stat().st_mtime
        )
        # NOTE(review): `_is_older` also returns True for equal timestamps, so
        # an unchanged file is reloaded on every request. Presumably this is
        # deliberate, to guard against coarse mtime resolution; confirm.
        if _is_older(self._loaded_model_modification_time, mod_time_from_disk):
            self._loaded_model = self._load_model(self._model_filepath)
            self._loaded_model_modification_time = mod_time_from_disk


def _is_older(dt1: datetime | None, dt2: datetime) -> bool:
    """Tell whether `dt1` is older (i.e. more in the past) than `dt2`. If `dt1`
    is the same as `dt2`, or even if `dt1` is `None`, output `True`."""
    if dt1 is None:
        return True
    else:
        return dt1 <= dt2


def supervised_learning_MIA_server_factory(
    anonymizer_identifier: str,
    deanonymizer_identifier: str,
    deanonymizer_obj: deanonymizer.SupervisedLearningMIA,
    model_filepath: str | Path,
    default_batch_size: int,
    member_url: str,
    nonmember_url: str,
    evaluation_data_url: str,
    utility_evaluation_url: str,
    model_loader: Callable[[str | Path], deanonymizer.SupervisedLearningMIA],
) -> falcon.App:
    """A factory to create a web application object which hosts a
    `deanonymizer.SupervisedLearningMIA`, a basic membership inference attack
    (MIA) on anonymizers.

    By using this factory, you don't have to worry about any web-programming
    issues, as they are hidden from you. The generated web app will feature
    the following routes (more details may be found in this project's OpenAPI
    specification):

    * [GET] `/`
    * [POST] `/fit`
    * [POST] `/evaluate`

    Parameters
    ----------
    anonymizer_identifier : str
        The identifier of the anonymizer under attack. It is forwarded to the
        utility evaluation endpoint when evaluating.
    deanonymizer_identifier : str
        The identifier of `deanonymizer_obj`. It is forwarded to the utility
        evaluation endpoint when evaluating.
    deanonymizer_obj : deanonymizer.SupervisedLearningMIA
        A membership inference attack against an anonymizer, which is based on
        the supervised learning paradigm. Its `validate_input` method is used
        to check the downloaded member and non-member feature arrays.
    model_filepath : str | Path
        Where to write fitted attacks to disk. The `/evaluate` route loads the
        most recently fitted attack from this path.
    default_batch_size : int
        Which batch size to use when inferring memberships during evaluation.
    member_url : str
        Where to download the feature array and target array, which are members
        of the training dataset.
    nonmember_url : str
        Where to download the feature array and target array, which are not
        members of the training dataset.
    evaluation_data_url : str
        Where to download the feature array and target array used to evaluate
        the attack. The `/evaluate` route queries it with a `data_split`
        parameter that is either "tuning" or "validation".
    utility_evaluation_url : str
        Where to submit inferred memberships to, for evaluation.
    model_loader : Callable[[str | Path], deanonymizer.SupervisedLearningMIA]
        A function which loads a fitted attack from `model_filepath`. The
        `/evaluate` route calls it to pick up the most recently fitted attack.

    Returns
    -------
    falcon.App
        A web application object based on the falcon web framework.
    """
    app = falcon.App()

    app.add_route(
        "/", utils.StaticJSONResource(dict(message="Deanonymizer server is alive!"))
    )
    app.add_route(
        "/fit",
        utils.FitResource(
            data_getter=_get_deanonymizer_fit_data(
                deanonymizer=deanonymizer_obj,
                member_url=member_url,
                nonmember_url=nonmember_url,
                timeout=10.0,
            ),
            model=deanonymizer_obj,
            model_filepath=model_filepath,
        ),
    )
    app.add_route(
        "/evaluate",
        EvaluationResource(
            anonymizer_identifier=anonymizer_identifier,
            deanonymizer_identifier=deanonymizer_identifier,
            model_filepath=model_filepath,
            model_loader=model_loader,
            default_batch_size=default_batch_size,
            evaluation_data_url=evaluation_data_url,
            utility_evaluation_url=utility_evaluation_url,
        ),
    )
    return app


def _get_deanonymizer_fit_data(
    deanonymizer: deanonymizer.SupervisedLearningMIA,
    member_url: str,
    nonmember_url: str,
    timeout: float,
    expected_array_labels: list[str] | None = None,
) -> Callable[[], dict[str, np.ndarray]]:
    """Build a zero-argument callable which downloads the attack's fit data.

    The returned callable first fetches the member arrays, then the non-member
    arrays. It stores them under the keys ``X_member``, ``y_member``,
    ``X_nonmember`` and ``y_nonmember``. After each download, the feature
    array is checked with ``deanonymizer.validate_input``. An incompatible
    array results in an HTTP 500 error.

    `expected_array_labels` names the feature array (first entry) and the
    target array (second entry) in the downloaded payload. It defaults to
    ``["X", "y"]``.
    """
    labels = ["X", "y"] if expected_array_labels is None else expected_array_labels
    sources = (("member", member_url), ("nonmember", nonmember_url))

    def fetch_fit_data():
        collected: dict[str, np.ndarray] = {}
        for tag, url in sources:
            downloaded = utils.get_named_arrays_or_raise(
                url,
                labels,
                timeout=timeout,
            )
            features = downloaded[labels[0]]
            collected[f"X_{tag}"] = features
            collected[f"y_{tag}"] = downloaded[labels[1]]
            _validate_array_input(
                features,
                deanonymizer,
                falcon.HTTP_INTERNAL_SERVER_ERROR,
                f"The deanonymizer is not compatible with the {tag} feature array.",
            )
        return collected

    return fetch_fit_data


def _validate_array_input(
    input_array: np.ndarray,
    model: Any,
    error_status: str | int,
    error_msg: str,
) -> None:
    try:
        model.validate_input(input_array)
    except ValueError:
        logging.exception(error_msg)
        raise falcon.HTTPError(status=error_status, description=error_msg)
