# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/13_data.torch.ipynb (unless otherwise specified).

__all__ = ['SpectralDataset', 'DataLoaders', 'SNV_transform', 'Noop']

# Cell
#nbdev_comment from __future__ import annotations
import numpy as np

from fastcore.test import *

from .loading import load_kssl
from .selection import (select_y, select_tax_order, select_X)
from .transform import (log_transform_y, SNV)

from sklearn.model_selection import train_test_split

from fastcore.transform import compose

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

# Cell
class SpectralDataset(Dataset):
    """Map-style PyTorch dataset serving one sample and its target per index.

    Args:
        X: array indexed as ``X[idx, :]``; each row along the first axis is
            one sample.
        y: array of targets, one entry per sample; its length defines the
            dataset length.
        transform: optional callable applied to each sample before the
            float32 cast; skipped when falsy.
    """

    def __init__(self, X, y, transform=None):
        self.X, self.y, self.transform = X, y, transform

    def __len__(self):
        # Length is taken from the targets, not from X.
        return len(self.y)

    def __getitem__(self, idx):
        # The leading `None` inserts a length-1 axis, e.g. a 2-D X yields a
        # (1, n_features) sample and a 1-D y yields a (1,) target.
        sample = self.X[None, idx, :]
        target = self.y[None, idx]
        sample = self.transform(sample) if self.transform else sample
        return sample.astype(np.float32), target.astype(np.float32)

# Cell
class DataLoaders:
    def __init__(self, data, transform=None, batch_size=32):
        """
        Wrap numpy arrays so they can be served as PyTorch data loaders.

        Args:
            data: iterable of ``(X, y)`` pairs, e.g.
                ``((X_train, y_train), (X_test, y_test))``; one loader is
                built per pair, in the same order.
            transform: callable applied to every sample (it is handed to
                ``SpectralDataset``); when falsy, ``Noop()`` is used, which
                returns its input unchanged.
            batch_size: number of samples per batch, shared by every loader.
        """
        self.data = data
        self.batch_size = batch_size
        # Fall back to an identity transform so the dataset always gets a
        # callable.
        self.transform = transform if transform else Noop()

    def loaders(self):
        """Build one ``DataLoader`` per ``(X, y)`` pair in ``self.data``.

        Returns:
            A generator of ``torch.utils.data.DataLoader`` objects in the
            order of ``self.data``, e.g.
            ``train_dl, valid_dl = dls.loaders()``. The generator is
            single-use; call ``loaders()`` again to rebuild. Only
            ``batch_size`` is passed to ``DataLoader``, so every loader
            uses its default (non-shuffled) ordering.
        """
        return (DataLoader(SpectralDataset(X, y, transform=self.transform),
                           batch_size=self.batch_size)
                for X, y in self.data)

# Cell
class SNV_transform:
    """Callable transform that applies ``SNV().fit_transform`` to a spectrum.

    A fresh ``SNV`` instance is created and fitted on every call, so no
    fitted state is carried over from one sample to the next.
    """

    def __call__(self, spectrum):
        # NOTE(review): presumably SNV is Standard Normal Variate scaling
        # from ``.transform``; confirm it accepts the (1, n_features) slices
        # produced by SpectralDataset.
        return SNV().fit_transform(spectrum)

# Cell
class Noop:
    """Identity transform: returns its input unchanged.

    ``DataLoaders`` uses it as the default when no transform is supplied.
    """

    def __call__(self, X):
        return X