"""
Simple scikit-learn interface for Emb-GAM.


Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models
Chandan Singh & Jianfeng Gao
https://arxiv.org/abs/2209.11799
"""
from numpy.typing import ArrayLike
import numpy as np
from scipy.special import softmax
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.linear_model import LogisticRegressionCV, RidgeCV
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted
from spacy.lang.en import English
from sklearn.preprocessing import StandardScaler
import transformers
import imodelsx.embgam.embed
from tqdm import tqdm
import os
import os.path
import warnings
import pickle as pkl
import torch
from sklearn.exceptions import ConvergenceWarning
# Device the transformer is moved to: the GPU when torch can see one, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class LinearFinetune(BaseEstimator):
    def __init__(
        self,
        checkpoint: str = 'bert-base-uncased',
        layer: str = 'last_hidden_state',
        random_state=None,
        normalize_embs=False,
    ):
        '''LinearFinetune Class - use either LinearFinetuneClassifier or LinearFinetuneRegressor rather than initializing this class directly.

        Extracts one embedding per input text from a pre-trained huggingface
        model, then fits a linear model on those embeddings
        (LogisticRegressionCV for classifiers, RidgeCV for regressors).

        Parameters
        ----------
        checkpoint: str
            Name of model checkpoint (i.e. to be fetched by huggingface)
        layer: str
            Name of the model output to extract embeddings from
            (e.g. 'last_hidden_state' or 'pooler_output'). Outputs that still
            carry a sequence dimension are mean-pooled over tokens.
        random_state
            random seed for fitting
        normalize_embs
            whether to standardize embeddings (StandardScaler) before fitting
            the linear model
        '''
        self.checkpoint = checkpoint
        self.layer = layer
        self.random_state = random_state
        self.normalize_embs = normalize_embs
        # NOTE(review): loading weights in __init__ departs from the sklearn
        # convention (sklearn.base.clone will reload them); kept because
        # callers may rely on `model` / `tokenizer` existing before fit.
        self.model = transformers.AutoModel.from_pretrained(
            self.checkpoint).to(device)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            self.checkpoint)

    def fit(
        self,
        X: ArrayLike,
        y: ArrayLike,
        verbose=True,
        cache_embs_dir: str = None,
    ):
        '''Extract embeddings then fit linear model

        Parameters
        ----------
        X: ArrayLike[str]
            input texts; one embedding is extracted per element
        y: ArrayLike
            targets: class labels for LinearFinetuneClassifier,
            continuous values for LinearFinetuneRegressor
        verbose: bool
            whether to print progress messages and show a progress bar
        cache_embs_dir: str, optional
            if not None, directory to save the embeddings (after optional
            normalization) into, as 'embs.pkl'

        Returns
        -------
        self

        Raises
        ------
        TypeError
            if the instance is neither a ClassifierMixin nor a RegressorMixin
        '''
        # Pick the linear head up front so a misconfigured estimator fails
        # before the (slow) embedding pass rather than after it.
        if isinstance(self, ClassifierMixin):
            self.classes_ = unique_labels(y)
            linear = LogisticRegressionCV(random_state=self.random_state)
        elif isinstance(self, RegressorMixin):
            linear = RidgeCV()
        else:
            raise TypeError(
                'use LinearFinetuneClassifier or LinearFinetuneRegressor '
                'rather than LinearFinetune directly')
        if self.random_state is not None:
            # seeds numpy's global RNG (process-wide); kept for backward compat
            np.random.seed(self.random_state)

        # get embs
        if verbose:
            print('calculating embeddings...')
        embs = self._get_embs(X, verbose=verbose)
        if self.normalize_embs:
            self.normalizer = StandardScaler()
            embs = self.normalizer.fit_transform(embs)
        if cache_embs_dir is not None:
            os.makedirs(cache_embs_dir, exist_ok=True)
            with open(os.path.join(cache_embs_dir, 'embs.pkl'), 'wb') as f:
                pkl.dump(embs, f)

        # Train linear. ConvergenceWarning is silenced only for this call
        # instead of installing a filter that outlives fit().
        if verbose:
            print('training linear model...')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=ConvergenceWarning)
            linear.fit(embs, y)
        # assigned only after a successful fit, so check_is_fitted stays honest
        self.linear = linear

        return self

    def _get_embs(self, X, verbose=True):
        '''Embed each text in X separately and stack the results.

        Parameters
        ----------
        X: ArrayLike[str]
            input texts (iterated directly, so pandas Series work regardless
            of their index)
        verbose: bool
            whether to show a tqdm progress bar

        Returns
        -------
        np.ndarray of shape (num_examples, embedding_size)

        Raises
        ------
        ValueError
            if X is empty
        '''
        embs = []
        # inference only: no_grad avoids building an autograd graph per example
        with torch.no_grad():
            for text in tqdm(X, disable=not verbose):
                inputs = self.tokenizer(
                    [text], padding=True, truncation=True, return_tensors="pt")
                inputs = inputs.to(self.model.device)
                output = self.model(**inputs)
                emb = output[self.layer].cpu().numpy()
                if emb.ndim == 3:  # includes seq_len: mean-pool over tokens
                    emb = emb.mean(axis=1)
                embs.append(emb)
        if not embs:
            raise ValueError('X must contain at least one text')
        # concatenate rather than squeeze so a single example stays 2-D
        return np.concatenate(embs, axis=0)

    def _get_normalized_embs(self, X):
        '''Embed X and apply the normalizer fitted in fit(), if enabled.'''
        embs = self._get_embs(X)
        if self.normalize_embs:
            embs = self.normalizer.transform(embs)
        return embs

    def predict(self, X):
        '''For regression returns continuous output.
        For classification, returns discrete output.
        '''
        # check the attribute fit() sets for both heads; the bare check looks
        # for trailing-underscore attributes, which regressors never set
        check_is_fitted(self, 'linear')
        return self.linear.predict(self._get_normalized_embs(X))

    def predict_proba(self, X, warn=True):
        '''Return class probabilities from the fitted linear model.

        Only meaningful for classifiers (RidgeCV has no predict_proba).
        `warn` is accepted for API compatibility and is currently unused.
        '''
        check_is_fitted(self, 'linear')
        return self.linear.predict_proba(self._get_normalized_embs(X))

class LinearFinetuneRegressor(RegressorMixin, LinearFinetune):
    '''LinearFinetune that fits a RidgeCV head for continuous targets.

    RegressorMixin is listed before LinearFinetune (and hence BaseEstimator)
    so the mixin's estimator type / tags take precedence in the MRO, as
    scikit-learn's estimator guidelines require.
    '''


class LinearFinetuneClassifier(ClassifierMixin, LinearFinetune):
    '''LinearFinetune that fits a LogisticRegressionCV head for class labels.

    ClassifierMixin is listed before LinearFinetune (and hence BaseEstimator)
    so the mixin's estimator type / tags take precedence in the MRO, as
    scikit-learn's estimator guidelines require.
    '''

if __name__ == '__main__':
    # Smoke test: fit a classifier on a subsample of rotten_tomatoes,
    # check prediction shapes agree, and report train/test accuracy.
    import imodelsx.data

    dset, _ = imodelsx.data.load_huggingface_dataset(
        'rotten_tomatoes',
        binary_classification=False,
        subsample_frac=0.1,
    )
    train_split = dset['train']
    test_split = dset['test']
    print(dset)
    print(train_split)
    print(np.unique(train_split['label']))

    model = LinearFinetuneClassifier()
    model.fit(train_split['text'], train_split['label'])

    print('predicting')
    test_preds = model.predict(test_split['text'])
    print(test_preds.shape)

    print('predicting proba')
    test_probs = model.predict_proba(test_split['text'])
    print(test_probs.shape)

    assert test_probs.shape[0] == test_preds.shape[0]
    train_acc = np.mean(
        model.predict(train_split['text']) == train_split['label'])
    print('acc_train', train_acc)
    test_acc = np.mean(test_preds == test_split['label'])
    print('acc_test', test_acc)