# -*- coding: utf-8 -*-
# Copyright (c) Hebes Intelligence Private Company

# This source code is licensed under the Apache License, Version 2.0 found in the
# LICENSE file in the root directory of this source tree.

import copy
from collections import OrderedDict, defaultdict
from functools import reduce
from typing import Dict, Union

import pandas as pd
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

from ..compose import ModelStructure
from ..compose._parse import parse_encoder_definition
from ..encode import CategoricalEncoder
from ..utils import as_list, check_X, check_y
from .linear import LinearPredictor


class GroupedPredictor(RegressorMixin, BaseEstimator):
    """Construct one predictor per data group.

    The predictor splits data by the different values of a single column and fits one
    estimator per group. Since each of the models in the ensemble predicts on a different
    subset of the input data (an observation cannot belong to more than one cluster),
    the final prediction is generated by vertically concatenating all the individual
    models’ predictions.

    Args:
        group_feature (str): The name of the column of the input dataframe to use as
            the grouping set.
        model_conf (Dict[str, Dict]): A dictionary that includes information about the
            base model's structure.
        feature_conf (Dict[str, Dict], optional): A dictionary that maps feature
            generator names to the classes for the generators' validation and
            creation. Defaults to None.
        estimator_params (dict or tuple of tuples, optional): The parameters to use when
            instantiating a new base estimator. If none are given, default parameters are
            used. Defaults to tuple().
        fallback (bool, optional): Whether or not to fall back to a global model in case a
            group value is not found during `.predict()`. Otherwise, an exception will
            be raised. When True, the group value `_global_` is reserved for the global
            model and cannot appear in the grouping column. Defaults to False.
    """

    # Key under which the global (all-groups) fallback estimator is stored.
    _GLOBAL_NAME = "_global_"

    def __init__(
        self,
        *,
        group_feature: str,
        model_conf: Dict[str, Dict],
        feature_conf: Dict[str, Dict] = None,
        estimator_params=tuple(),
        fallback=False,
    ):
        self.group_feature = group_feature
        self.model_conf = model_conf
        self.feature_conf = feature_conf
        self.estimator_params = estimator_params
        self.fallback = fallback
        # NOTE(review): scikit-learn convention is to do no work in `__init__`; here the
        # model structure is parsed eagerly, so invalid configurations fail at construction.
        self.components_ = ModelStructure.from_config(
            model_conf, feature_conf
        ).components
        self.estimators_ = OrderedDict({})
        self.added_features_ = []
        # Fitted categorical encoders: keyed by component name for main effects, and by
        # interaction pair and then component name for interactions.
        self.encoders_ = {
            "main_effects": defaultdict(dict),
            "interactions": defaultdict(dict),
        }

    @property
    def n_parameters(self):
        """Total number of parameters of the per-group estimators.

        The global fallback estimator (if any) is not counted.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        return self._sum_over_group_estimators(
            "n_parameters", "The number of parameters is"
        )

    @property
    def dof(self):
        """Total degrees of freedom of the per-group estimators.

        The global fallback estimator (if any) is not counted.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        return self._sum_over_group_estimators("dof", "The degrees of freedom are")

    def _sum_over_group_estimators(self, attribute, subject):
        """Sum `attribute` over all fitted per-group estimators, skipping the global one.

        Args:
            attribute (str): Name of the estimator attribute to sum.
            subject (str): Beginning of the error message raised when not fitted.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        try:
            check_is_fitted(self, "fitted_")
        except NotFittedError as exc:
            raise ValueError(
                f"{subject} accessible only after the model has been fitted"
            ) from exc
        return sum(
            getattr(est, attribute)
            for name, est in self.estimators_.items()
            if name != self._GLOBAL_NAME
        )

    def _fit_single_group(self, group_name, model_structure, X, y):
        """Fit and return a `LinearPredictor` on the data of a single group.

        Any exception raised while building or fitting the estimator is re-raised with
        the group name prepended to its message (same exception type, chained to the
        original). If the exception type cannot be rebuilt from a single message, the
        original exception is re-raised unchanged.
        """
        try:
            params = (
                dict(self.estimator_params) if self.estimator_params is not None else {}
            )
            estimator = LinearPredictor(model_structure=model_structure, **params)
            estimator = estimator.fit(X, y)
        except Exception as exc:
            try:
                wrapped = type(exc)(f"Exception for group {group_name}: {exc}")
            except Exception:
                wrapped = None
            if wrapped is None:
                # Bare `raise` re-raises `exc`: the inner handler has already exited.
                raise
            raise wrapped from exc
        return estimator

    def _predict_single_group(self, group_name, X, include_components):
        """Predict a single group using its own estimator.

        If the group was not seen during `fit`, the global estimator is used when
        `fallback` is True. Otherwise a ValueError is raised.

        Returns:
            pandas.Series or pandas.DataFrame: The group's predictions. Non-pandas
            outputs are wrapped in a DataFrame with the target name as column.
        """
        try:
            estimator = self.estimators_[group_name]
        except KeyError:
            if not self.fallback:
                raise ValueError(
                    f"Found new group {group_name} during predict"
                ) from None
            estimator = self.estimators_[self._GLOBAL_NAME]

        pred = estimator.predict(X, include_components=include_components)
        if not isinstance(pred, (pd.Series, pd.DataFrame)):
            pred = pd.DataFrame(data=pred, index=X.index, columns=[self.target_name_])
        return pred

    def _encode_categorical(self, props, registry, key, X, y, fitting, suffix):
        """Ordinally encode one categorical component and point `props` to the result.

        When fitting, a new `CategoricalEncoder` is fitted (stratified additionally by the
        group feature if the component is stratified) and stored in `registry[key]`.
        Otherwise the stored encoder is reused. The encoded values are written to `X` as
        the column `<feature>__for__<suffix>`. `props` is updated in place so that the
        downstream model uses that column as an already-encoded feature.
        """
        if fitting:
            stratify_by = (
                None
                if not props["stratify_by"]
                else [self.group_feature] + props["stratify_by"]
            )
            enc = CategoricalEncoder(
                **dict(
                    {k: v for k, v in props.items() if k != "type"},
                    encode_as="ordinal",
                    stratify_by=stratify_by,
                )
            )
            encoded = enc.fit_transform(X, y).squeeze()
            registry[key] = enc
        else:
            encoded = registry[key].transform(X).squeeze()

        new_name = "__for__".join((props.get("feature"), suffix))
        X[new_name] = encoded
        props.update(
            {
                "feature": new_name,
                "max_n_categories": None,
                "stratify_by": None,
            }
        )

    def _update_local_conf(self, conf, X, y=None, fitting=True):
        """Replace categorical components in `conf` with pre-encoded ordinal features.

        Both main effects and the members of interaction pairs are handled. `conf` and
        `X` are modified in place (new encoded columns are added to `X`).

        Returns:
            tuple or pandas.DataFrame: `(conf, X)` when `fitting` is True, else `X`.
        """
        for name, props in conf["main_effects"].items():
            if props["type"] == "categorical":
                self._encode_categorical(
                    props,
                    self.encoders_["main_effects"],
                    name,
                    X,
                    y,
                    fitting,
                    suffix=name,
                )

        for pair_name, pair_props in conf["interactions"].items():
            for name in pair_name:
                props = pair_props[name]
                if props["type"] == "categorical":
                    self._encode_categorical(
                        props,
                        self.encoders_["interactions"][pair_name],
                        name,
                        X,
                        y,
                        fitting,
                        suffix=":".join(pair_name),
                    )
        return (conf, X) if fitting else X

    def _create_new_features(self):
        """Instantiate the feature generators declared under `add_features`.

        The `add_features` entry of `components_` is consumed (replaced by an empty
        dict). Each generator is appended to `added_features_`. String types are
        resolved through `feature_conf`; anything else is used as-is.

        Raises:
            ValueError: If a string type is given but `feature_conf` is None or does
                not contain that type.
        """
        added_features = self.components_.pop("add_features")
        self.components_["add_features"] = {}

        if not added_features:
            return

        for props in added_features.values():
            # Work on a copy so the caller's configuration dicts are not mutated.
            props = dict(props)
            fgen_type = props.pop("type")
            if not isinstance(fgen_type, str):
                self.added_features_.append(fgen_type)
                continue

            if self.feature_conf is None:
                raise ValueError(
                    "A mapping between types and classes has not been provided."
                )
            targets = self.feature_conf.get(fgen_type)
            if targets is None:
                raise ValueError(f"Type {fgen_type} not found in provided mapping")

            class_obj = parse_encoder_definition(targets["generate"])
            self.added_features_.append(class_obj(**props))

    def fit(self, X: pd.DataFrame, y: Union[pd.DataFrame, pd.Series]):
        """Fit the estimator with the available data.

        Args:
            X (pandas.DataFrame): Input data.
            y (pandas.Series or pandas.DataFrame): Target data.

        Raises:
            Exception: If the estimator is re-fitted. An estimator object can only be
                fitted once.
            ValueError: If the input data does not pass the checks of `utils.check_X`.
            ValueError: If `fallback` is True and the grouping column contains the
                reserved value `_global_`.
            ValueError: If the target data does not pass the checks of `utils.check_y`.

        Returns:
            GroupedPredictor: Fitted estimator.
        """
        try:
            check_is_fitted(self, "fitted_")
        except NotFittedError:
            pass
        else:
            raise Exception(
                "Estimator object can only be fit once. Instantiate a new object."
            )

        self._create_new_features()
        # Apply the feature generators
        if self.added_features_:
            X = reduce(
                lambda _df, trans: trans.fit_transform(_df), self.added_features_, X
            )

        X = check_X(X, exists=self.group_feature)
        # `value in series` checks the index, so test the column values explicitly.
        if self.fallback and X[self.group_feature].isin([self._GLOBAL_NAME]).any():
            raise ValueError(
                "Name `_global_` is reserved and cannot be used as a group name"
            )
        y = check_y(y, index=X.index)
        self.target_name_ = y.columns[0]

        local_model_conf = copy.deepcopy(self.components_)
        # NOTE(review): this adds encoded columns to X in place; confirm whether
        # `check_X` returns a copy, otherwise the caller's frame is modified.
        local_model_conf, X = self._update_local_conf(
            local_model_conf, X, y=y, fitting=True
        )

        for group_name, group_data in X.groupby(self.group_feature):
            self.estimators_[group_name] = self._fit_single_group(
                group_name=group_name,
                model_structure=ModelStructure(local_model_conf, self.feature_conf),
                X=group_data.drop(self.group_feature, axis=1),
                y=y.loc[group_data.index],
            )

        if self.fallback:
            # The global model uses the original (non-group-encoded) structure.
            self.estimators_[self._GLOBAL_NAME] = self._fit_single_group(
                group_name=self._GLOBAL_NAME,
                model_structure=ModelStructure(self.components_, self.feature_conf),
                X=X.drop(self.group_feature, axis=1),
                y=y,
            )

        self.groups_ = as_list(self.estimators_.keys())
        self.fitted_ = True
        return self

    def predict(
        self, X: pd.DataFrame, include_clusters=False, include_components=False
    ):
        """Predict given new input data.

        Args:
            X (pandas.DataFrame): Input data.
            include_clusters (bool, optional): Whether to include the group column in the
                returned prediction. Defaults to False.
            include_components (bool, optional): Whether to include the contribution of the
                individual components of the model structure in the returned prediction.
                Defaults to False.

        Raises:
            ValueError: If the input data does not pass the checks of `utils.check_X`.
            ValueError: If a group unseen during fitting is found and `fallback` is False.

        Returns:
            pandas.DataFrame: The predicted values, aligned to the index of `X`. Rows
            with missing values (e.g. a missing group value) are dropped.
        """
        check_is_fitted(self, "fitted_")
        if self.added_features_:
            # NOTE(review): generators are re-fitted on the prediction data here; confirm
            # that this (rather than `.transform`) is intended for stateful generators.
            X = reduce(
                lambda _df, trans: trans.fit_transform(_df), self.added_features_, X
            )
        X = check_X(X, exists=self.group_feature)

        local_model_conf = copy.deepcopy(self.components_)
        X = self._update_local_conf(local_model_conf, X, fitting=False)

        group_preds = []
        for group_name, group_data in X.groupby(self.group_feature):
            group_pred = self._predict_single_group(
                group_name=group_name,
                X=group_data.drop(self.group_feature, axis=1),
                include_components=include_components,
            )
            if include_clusters:
                group_pred = pd.concat(
                    (group_pred, group_data[[self.group_feature]]),
                    axis=1,
                    ignore_index=False,
                )
            group_preds.append(group_pred)

        if not group_preds:
            # No row had a usable group value (e.g. empty input or all-missing groups).
            columns = [self.target_name_]
            if include_clusters:
                columns.append(self.group_feature)
            return pd.DataFrame(index=X.index[:0], columns=columns)

        # Concatenate once instead of inside the loop (avoids quadratic copying).
        pred = pd.concat(group_preds, axis=0, ignore_index=False)
        return pred.reindex(X.index).dropna()
