# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/014_data.unwindowed.ipynb.

# %% auto 0
__all__ = ['TSUnwindowedDataset', 'TSUnwindowedDatasets']

# %% ../../nbs/014_data.unwindowed.ipynb 3
from ..imports import *
from ..utils import *
from .validation import *
from .core import *

# %% ../../nbs/014_data.unwindowed.ipynb 4
class TSUnwindowedDataset():
    """Dataset that slices fixed-length windows on the fly out of one unwindowed time series.

    `X` is a 1D or 2D array:
    - With `seq_first=True`, the first axis is time, i.e. `(seq_len, n_vars)`.
    - With `seq_first=False`, the last axis is time, i.e. `(n_vars, seq_len)`.
    - A 1D `X` is treated as a single variable.

    Window start offsets are `range(drop_start, seq_len - window_size + 1 - drop_end, stride)`.
    Each item is `(xb,)`, or `(xb, yb)` when `y` is given. `xb` has shape `(n, n_vars, window_size)`.
    `yb` is `y` indexed with the same time indices as `xb`, passed through `y_func` if given.
    This assumes `y` is aligned with the time axis of `X` (TODO confirm with callers).

    Only `split` is read from `**kwargs`. It holds optional indices into the start offsets and
    restricts the dataset to those windows.
    """
    _types = TSTensor, TSLabelTensor
    def __init__(self, X=None, y=None, y_func=None, window_size=1, stride=1, drop_start=0, drop_end=0, seq_first=True, **kwargs):
        store_attr()
        # `X=None` builds an empty placeholder (see `new_empty`).
        # No index arrays are created and `len()` is 0.
        if X is None: return
        if X.ndim == 1:
            # Treat a 1D series as a single variable: add the variable axis opposite the time axis.
            # Store it back on `self` so `__getitem__` indexes the same 2D layout `seq_len` is
            # computed from.
            X = np.expand_dims(X, 1 if seq_first else 0)
            self.X = X
        assert X.ndim == 2, 'X must be a 1D or 2D array'
        seq_len = X.shape[0] if seq_first else X.shape[-1]
        # Exclusive upper bound for window start offsets; the last `drop_end` valid starts are skipped.
        max_time = seq_len - window_size + 1 - drop_end
        assert max_time > 0, 'you need to modify either window_size or drop_end as they are larger than seq_len'
        # Start offsets have shape (n_windows, 1) and in-window offsets have shape (1, window_size).
        # Their broadcast sum gives the time indices of every window.
        self.all_idxs = np.arange(drop_start, max_time, step=stride)[:, None]
        self.window_idxs = np.arange(window_size)[None]
        self.split = kwargs.get('split')
        self.n_inp = 1
        self.loss_func = CrossEntropyLossFlat() if self._is_int_target(y) else MSELossFlat()

    @staticmethod
    def _is_int_target(y):
        "Whether `y`'s first item is an int, or a listy whose first item is an int (i.e. class labels)."
        if y is None: return False
        first = y[0]
        return (is_listy(first) and isinstance(first[0], Integral)) or isinstance(first, Integral)

    def __len__(self):
        "Number of windows: `len(split)` when a split is set, else the number of start offsets."
        # Placeholder datasets (`X=None`) never get a `split` attribute and are empty.
        if not hasattr(self, "split"): return 0
        return len(self.all_idxs if self.split is None else self.split)

    def __getitem__(self, idxs):
        "Return `(xb,)` or `(xb, yb)` for index/indices `idxs` (positions in `split` when it is set)."
        if self.split is not None:
            idxs = self.split[idxs]
        # (n, 1) start offsets + (1, window_size) in-window offsets -> (n, window_size) time indices.
        # A scalar `idxs` still yields n == 1.
        widxs = self.all_idxs[idxs] + self.window_idxs
        if self.seq_first:
            xb = self.X[widxs]  # (n, window_size, n_vars)
            if xb.ndim == 3: xb = xb.transpose(0,2,1)
            # Defensive: a 1D `self.X` gives (n, window_size); add the n_vars axis.
            else: xb = np.expand_dims(xb, 1)
        else:
            # (n_vars, n, window_size) -> (n, n_vars, window_size)
            xb = self.X[:, widxs].transpose(1,0,2)
        if self.y is None:
            return (self._types[0](xb),)
        yb = self.y[widxs]
        if self.y_func is not None:
            yb = self.y_func(yb)
        return (self._types[0](xb), self._types[1](yb))

    def new_empty(self):
        "Return an empty placeholder dataset of the same type."
        return type(self)(X=None, y=None)

    def _first_x(self):
        "The `xb` tensor of item 0, of shape `(1, n_vars, window_size)`; used for shape metadata."
        s = self[0][0]
        return s[0] if isinstance(s, tuple) else s

    @property
    def vars(self):
        "Number of variables (channels) per window."
        return self._first_x().shape[-2]
    @property
    def len(self):
        "Number of time steps per window."
        return self._first_x().shape[-1]


class TSUnwindowedDatasets(FilteredBase):
    "Per-split views of one `TSUnwindowedDataset`: subset `i` restricts the dataset to `splits[i]`."
    def __init__(self, dataset, splits):
        store_attr()

    def subset(self, i):
        "Rebuild `self.dataset` with the same settings, restricted to the window indices in `self.splits[i]`."
        ds = self.dataset
        return type(ds)(
            ds.X,
            y=ds.y,
            y_func=ds.y_func,
            window_size=ds.window_size,
            stride=ds.stride,
            drop_start=ds.drop_start,
            drop_end=ds.drop_end,
            seq_first=ds.seq_first,
            split=self.splits[i],
        )

    @property
    def train(self):
        "Subset built from the first split."
        return self.subset(0)

    @property
    def valid(self):
        "Subset built from the second split."
        return self.subset(1)

    def __getitem__(self, i):
        return self.subset(i)
