# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/13_data.torch.ipynb.

# %% auto 0
__all__ = ['SpectralDataset', 'DataLoaders', 'SNV_transform', 'Noop']

# %% ../nbs/13_data.torch.ipynb 3
#nbdev_comment from __future__ import annotations
import numpy as np

from fastcore.test import *

from .loading import load_kssl
from .selection import (select_y, select_tax_order, select_X)
from .transform import (log_transform_y, SNV)

from sklearn.model_selection import train_test_split

from fastcore.transform import compose

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

# %% ../nbs/13_data.torch.ipynb 5
class SpectralDataset(Dataset):
    """Map-style PyTorch dataset over in-memory spectra, targets and taxonomic orders.

    Each item is an ``(X, y, tax_order)`` triple for a single sample. Every
    element keeps a leading axis of length 1 and is cast to a fixed dtype:
    ``float32`` for the spectrum and target, ``intc`` for the taxonomic order.

    Args:
        X: array indexed as ``X[idx, :]`` (at least 2-D; presumably
            ``(n_samples, n_features)`` -- confirm against callers).
        y: per-sample targets; ``len(y)`` is the dataset length.
        tax_order: per-sample taxonomic-order labels.
        transform: optional callable applied to the spectrum slice before the
            dtype cast. It is skipped whenever it is falsy (e.g. ``None``).
    """

    def __init__(self, X, y, tax_order, transform=None):
        self.X, self.y = X, y
        self.tax_order = tax_order
        self.transform = transform

    def __len__(self):
        # The sample count is taken from the targets, not from X.
        return len(self.y)

    def __getitem__(self, idx):
        # Indexing with `None` first inserts a new leading axis of size 1.
        spectrum = self.X[None, idx, :]
        target = self.y[None, idx]
        order = self.tax_order[None, idx]
        spectrum = self.transform(spectrum) if self.transform else spectrum
        return (
            spectrum.astype(np.float32),
            target.astype(np.float32),
            order.astype(np.intc),
        )

# %% ../nbs/13_data.torch.ipynb 6
class DataLoaders():
    def __init__(self, *args, transform=None, batch_size=32):
        """
        Wrap one or more numpy data splits as PyTorch data loaders.

        Args:
            *args: one or more ``(X, y, tax_order)`` tuples, one per split, e.g.
                ``(X_train, y_train, tax_order_train), (X_test, y_test, tax_order_test)``.
            transform: callable (an instance, e.g. ``SNV_transform()``) applied to
                each spectrum by the underlying ``SpectralDataset``. When ``None``,
                an identity ``Noop()`` is used.
            batch_size: number of samples per batch, shared by every loader.
        """
        self.data = args
        self.batch_size = batch_size
        # Substitute the identity transform only when none was supplied.
        self.transform = transform if transform is not None else Noop()

    def loaders(self):
        """
        Build one ``DataLoader`` per split passed to the constructor.

        Returns:
            A lazy generator yielding ``DataLoader`` objects in the same order as
            ``*args``. Unpack it directly, e.g. ``train_dl, valid_dl = dls.loaders()``.
            Call ``loaders()`` again to get a fresh generator.
        """
        return (DataLoader(SpectralDataset(X, y, tax_order, transform=self.transform),
                           batch_size=self.batch_size)
                for X, y, tax_order in self.data)

# %% ../nbs/13_data.torch.ipynb 8
class SNV_transform():
    """Callable transform that delegates to the project's ``SNV`` estimator.

    On every call, a new ``SNV`` instance is fitted on the given spectrum and
    used to transform it (``fit_transform``). No state is kept between calls.
    SNV is presumably standard normal variate scaling; confirm in ``.transform``.
    """

    def __call__(self, spectrum):
        scaler = SNV()
        return scaler.fit_transform(spectrum)

# %% ../nbs/13_data.torch.ipynb 9
class Noop():
    """Identity transform: calling it returns the argument itself, unmodified."""

    def __call__(self, X):
        return X
