# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/10_training.plsr.ipynb (unless otherwise specified).

__all__ = ['compute_valid_curve', 'PLS_model', 'Evaluator']

# Cell
#nbdev_comment from __future__ import annotations

# Python utils
from collections import OrderedDict
from tqdm.auto import tqdm

# mirzai utils
from ..data.loading import load_kssl
from ..data.selection import (select_y, select_tax_order, select_X)
from ..data.transform import (log_transform_y, SNV, TakeDerivative,
                                   DropSpectralRegions, CO2_REGION)
from .metrics import eval_reg

# Data science stack
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import mean_squared_error

from fastcore.test import *
from fastcore.transform import compose

# Cell
def compute_valid_curve(X:np.ndarray, # Spectra with shape (n_samples, n_wavenumbers)
                        y:np.ndarray, # Target with shape (n_samples)
                        X_names:np.ndarray, # Wavenumbers name with shape (n_wavenumbers)
                        mask:np.ndarray=None, # Mask (e.g for specific Soil Taxonomy Orders) with shape (n_samples)
                        pls_components:list=range(5, 120, 5), # List of PLSR components to try
                        seeds:list=range(20), # List of random seeds to use for multiple train/test splits
                        test_size:float=0.1 # Fraction of the (masked) samples held out as validation set
                       ):
    """Train the PLSR model on training & evaluate it on the valid set as # pls components increases.

    For every seed, the (optionally masked) data is split into train/valid sets, then a
    SNV -> TakeDerivative -> DropSpectralRegions(CO2_REGION) -> PLSRegression pipeline is
    fitted once per number of components.

    Returns an `OrderedDict` with keys:
      - 'pls_components': the component counts tried (as a list),
      - 'train_score' / 'valid_score': one list of mean squared errors per seed, each
        aligned element-wise with 'pls_components'.
    """
    # Materialize once: a one-shot iterator (e.g. a generator) would otherwise be exhausted
    # after the first seed and silently leave every later seed with empty score lists.
    pls_components = list(pls_components)
    history = OrderedDict({'pls_components': pls_components,
                           'train_score': [],
                           'valid_score': []})

    # The selected subset does not depend on the seed: compute it once, outside the loop.
    if mask is None:
        mask = np.full(len(X), True)
    X_sel, y_sel = X[mask, :], y[mask]

    for seed in tqdm(seeds):
        X_train, X_valid, y_train, y_valid = train_test_split(X_sel, y_sel,
                                                              test_size=test_size,
                                                              random_state=seed)
        train_score = []
        valid_score = []
        for cpts in tqdm(pls_components):
            pipe = Pipeline([('snv', SNV()),
                             ('derivative', TakeDerivative(window_length=11, polyorder=1)),
                             ('dropper', DropSpectralRegions(X_names, regions=CO2_REGION)),
                             ('model', PLSRegression(n_components=cpts))])
            pipe.fit(X_train, y_train)

            # sklearn metrics take (y_true, y_pred)
            train_score.append(mean_squared_error(y_train, pipe.predict(X_train)))
            valid_score.append(mean_squared_error(y_valid, pipe.predict(X_valid)))

        history['train_score'].append(train_score)
        history['valid_score'].append(valid_score)

    return history

# Cell
class PLS_model():
    """Partial Least Squares model runner.

    Wraps a scikit-learn `Pipeline` (SNV -> TakeDerivative -> DropSpectralRegions -> PLSRegression)
    that is built from `pipeline_kwargs` when `fit` is called.
    """
    # Keys `fit` reads from `pipeline_kwargs`; each maps to the kwargs of one pipeline step.
    _REQUIRED_KEYS = ('derivative', 'dropper', 'model')

    def __init__(self, X_names, pipeline_kwargs=None):
        """Store wavenumber names and pipeline settings; no model exists until `fit`.

        X_names: wavenumber names, forwarded to `DropSpectralRegions`.
        pipeline_kwargs: dict with keys 'derivative', 'dropper' and 'model', whose values are
            keyword-argument dicts for `TakeDerivative`, `DropSpectralRegions` and
            `PLSRegression` respectively. Defaults to a fresh empty dict (but `fit` needs all
            three keys).
        """
        self.X_names = X_names
        # `None` sentinel instead of a `{}` default so instances never share one dict object.
        self.pipeline_kwargs = {} if pipeline_kwargs is None else pipeline_kwargs
        self.model = None

    def fit(self, data):
        """Build the pipeline and fit it on `data = (X, y)`; returns `self`.

        Raises KeyError if `pipeline_kwargs` lacks any of 'derivative', 'dropper', 'model'.
        """
        X, y = data
        missing = [k for k in self._REQUIRED_KEYS if k not in self.pipeline_kwargs]
        if missing:
            raise KeyError(f"pipeline_kwargs is missing required key(s): {missing}")
        self.model = Pipeline([
            ('snv', SNV()),
            ('derivative', TakeDerivative(**self.pipeline_kwargs['derivative'])),
            ('dropper', DropSpectralRegions(self.X_names, **self.pipeline_kwargs['dropper'])),
            ('model', PLSRegression(**self.pipeline_kwargs['model']))])
        self.model.fit(X, y)
        return self

    def predict(self, data):
        "Return `(predictions, y)` for `data = (X, y)`; `y` is passed through untouched."
        X, y = data
        return (self.model.predict(X), y)

    def eval(self, data, is_log=True):
        """Return `eval_reg(y, predictions)` for `data = (X, y)`.

        `is_log` is currently unused; it is kept only for backward compatibility.
        """
        X, y = data
        return eval_reg(y, self.model.predict(X))

# Cell
class Evaluator():
    """Train and evaluate one `PLS_model` per random train/test split (one split per seed).

    data: `(X, y)` tuple; depth_order: array split alongside X/y (its column 1 is used by
    `eval_on_test` for per-order filtering); split_ratio: test fraction for each split.
    """
    def __init__(self, data, depth_order, X_names,
                 seeds=range(20), pipeline_kwargs=None, split_ratio=0.1):
        self.seeds = seeds
        self.X, self.y = data
        self.X_names = X_names
        self.depth_order = depth_order
        self.split_ratio = split_ratio
        # `None` sentinel so instances (and the models they build) never share one default dict.
        self.pipeline_kwargs = {} if pipeline_kwargs is None else pipeline_kwargs
        self.models = []
        self.perfs = OrderedDict({'train': [], 'test': []})

    def train_multiple(self):
        """Fit one model per seed on that seed's training split.

        Previously trained models are discarded, so `self.models[i]` always matches `self.seeds[i]`.
        """
        self.models = []
        for seed in tqdm(self.seeds):
            X_train, _, y_train, _, _, _ = self._splitter(seed)
            model = PLS_model(self.X_names, self.pipeline_kwargs)
            model.fit((X_train, y_train))
            self.models.append(model)

    def eval_on_train(self, reducer=None):
        """Evaluate each model on its own seed's training split.

        Returns one metrics dict per seed (each with an extra 'n' = number of samples), or a
        single aggregated dict when `reducer` is given (see `reduce`).
        Requires `train_multiple` to have been called.
        """
        perfs = []
        for i, seed in enumerate(self.seeds):
            X_train, _, y_train, _, _, _ = self._splitter(seed)
            perf = self.models[i].eval((X_train, y_train))
            perf['n'] = len(X_train)
            perfs.append(perf)
        if reducer:
            perfs = self.reduce(perfs, reducer)
        return perfs

    def eval_on_test(self, order=-1, reducer=None):
        """Evaluate each model on its own seed's test split.

        order: if not -1, keep only test rows whose `depth_order[:, 1]` equals `order`
            (presumably a Soil Taxonomy order id — confirm against the data loader).
        reducer: optional aggregation function (see `reduce`); when given, a single dict of
            aggregated metrics is returned instead of one dict per seed.
        Each metrics dict gets an extra 'n' entry: the number of evaluated samples.
        Requires `train_multiple` to have been called.
        """
        perfs = []
        # enumerate outside tqdm so the progress bar knows the total number of seeds
        for i, seed in enumerate(tqdm(self.seeds)):
            _, X_test, _, y_test, _, depth_order_test = self._splitter(seed)
            if order != -1:
                mask = depth_order_test[:, 1] == order
                X_test, y_test = X_test[mask, :], y_test[mask]
            perf = self.models[i].eval((X_test, y_test))
            perf['n'] = len(X_test)
            perfs.append(perf)
        if reducer:
            perfs = self.reduce(perfs, reducer)
        return perfs

    def _splitter(self, seed):
        "Deterministic split of (X, y, depth_order) into train/test parts for `seed`."
        return train_test_split(self.X, self.y, self.depth_order,
                                test_size=self.split_ratio,
                                random_state=seed)

    def reduce(self, perfs, fn=np.mean):
        """Aggregate a list of metrics dicts key-by-key with `fn` (default: mean).

        Keys are taken from the first dict; an empty list yields an empty dict.
        """
        if not perfs:
            return {}
        return {metric: fn(np.array([perf[metric] for perf in perfs]))
                for metric in perfs[0].keys()}