import io
import copy
import pathlib
import torch
from typing import Union, Sequence, Optional, Callable, Iterable, List
from ._conversions import floats_to_tensor
from pygmalion._model_base import ModelBase
from pygmalion.datasets import download_bytes


class NeuralNetwork(torch.nn.Module, ModelBase):
    """
    Abstract class for neural networks
    Implemented as a simple wrapper around torch.nn.Module
    with 'fit' and 'predict' methods

    Subclasses are expected to implement 'forward', 'loss',
    '_x_to_tensor', '_y_to_tensor' and '_tensor_to_y'.
    """

    def __init__(self):
        torch.nn.Module.__init__(self)
        ModelBase.__init__(self)

    @classmethod
    def load(cls, file_path: Union[str, pathlib.Path, io.IOBase]) -> "NeuralNetwork":
        """
        Loads a model previously saved with 'save'

        Parameters
        ----------
        file_path : str or pathlib.Path or file like
            Path to a '.pth' file, an 'https://' url to download it from,
            or an opened binary file-like object

        Returns
        -------
        NeuralNetwork
            The loaded model, with its tensors mapped to the cpu

        Raises
        ------
        TypeError
            If the loaded object is not an instance of 'cls'
        """
        import inspect

        file = download_bytes(file_path) if str(file_path).startswith("https://") else file_path
        # The whole module object is pickled by 'save', so full unpickling is
        # required: torch >= 2.6 defaults to 'weights_only=True' which would
        # refuse it, while torch < 1.13 does not know the argument at all.
        # SECURITY: unpickling can execute arbitrary code, only load files
        # (and urls) from trusted sources.
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        model = torch.load(file, map_location="cpu", **load_kwargs)
        # explicit check rather than 'assert', which is stripped under 'python -O'
        if not isinstance(model, cls):
            raise TypeError(
                f"Expected the loaded object to be a '{cls.__name__}', "
                f"but got '{type(model).__name__}'")
        return model

    def save(self, file_path: Union[str, pathlib.Path, io.IOBase],
             overwrite: bool = False, create_dir: bool = False):
        """
        Saves the model to the disk as '.pth' file

        Parameters
        ----------
        file_path : str or pathlib.Path or file like
            The path where the file must be created,
            or an opened binary file-like object to write to
        overwrite : bool
            If True, the file is overwritten
        create_dir : bool
            If True, the directory to the file's path is created
            (including any missing parent directory)
            if it does not exist already

        Raises
        ------
        ValueError
            If the path does not end with '.pth', or if its directory
            does not exist and 'create_dir' is False
        FileExistsError
            If the file already exists and 'overwrite' is False
        """
        if isinstance(file_path, io.IOBase):
            torch.save(self, file_path)
            return
        file_path = pathlib.Path(file_path)
        path = file_path.parent
        suffix = file_path.suffix.lower()
        if suffix != ".pth":
            raise ValueError(
                f"The model must be saved as a '.pth' file, but got '{suffix}'")
        if not path.is_dir():
            if not create_dir:
                raise ValueError(f"The directory '{path}' does not exist")
            # 'parents=True' so that nested missing directories are created too
            path.mkdir(parents=True, exist_ok=True)
        if not overwrite and file_path.exists():
            raise FileExistsError(
                f"The file '{file_path}' already exists, set 'overwrite=True' to overwrite.")
        torch.save(self, file_path)

    def fit(self, training_data: Iterable,
            validation_data: Optional[Iterable] = None,
            optimizer: Optional[torch.optim.Optimizer] = None,
            n_steps: int = 1000,
            learning_rate: Union[float, Callable[[int], float]] = 1.0E-3,
            patience: Optional[int] = None,
            keep_best: bool = True,
            loss: Optional[Callable] = None,
            L1: Optional[float] = None,
            L2: Optional[float] = None,
            metric: Optional[Callable] = None,
            verbose: bool = True) -> tuple:
        """
        Trains a neural network model.

        Each step evaluates the loss and the gradient over all the batches
        of 'training_data'. The (batch averaged) gradient is applied by the
        optimizer at the beginning of the next step, so 'n_steps' parameter
        updates are performed and n_steps+1 losses are recorded, the first
        one being the loss of the initial model.

        Parameters
        ----------
        training_data : Iterable of (x, y) or (x, y, weights) data
            The data used to fit the model on.
            An iterable of batches, each batch being a tuple of arguments
            unpacked into 'self.loss'. It is iterated over once per step,
            so it must be re-iterable (a list, a DataLoader, ...) rather
            than a single-use generator.
            The type of each element depends on the model kind.
        validation_data : None or Iterable of (x, y) or (x, y, weights) data
            The data used to test for early stopping.
            Similar to training_data or None
        optimizer : torch.optim.Optimizer or None
            optimizer to use for training.
            If None, an Adam optimizer over the model's parameters is created.
        n_steps : int
            The maximum number of optimization steps
        learning_rate : float or Callable
            The learning rate used to update the parameters,
            or a learning rate function of 'step' the number
            of optimization steps performed.
            A float is only used when 'optimizer' is None, whereas a callable
            also overrides the learning rate of a provided optimizer.
        patience : int or None
            The number of steps before early stopping
            (if no improvement for 'patience' steps, stops training early)
            If None, no early stopping is performed
        keep_best : bool
            If True, the model is checkpointed at each step if there was
            improvement,
            and the best model is loaded back at the end of training
        loss : Callable or None
            Currently unused: the 'self.loss' method is always used
        L1 : float or None
            If not None, weight of a penalty added to each batch loss,
            equal to the mean absolute value of the model's parameters
        L2 : float or None
            If not None, weight of a penalty added to each batch loss,
            equal to the root mean square of the model's parameters
        metric : Callable or None
            Currently unused: the validation loss (or the training loss
            if there is no validation data) is used for checkpointing
        verbose : bool
            If True the loss are displayed at each step

        Returns
        -------
        tuple :
            (train_losses, val_losses, grad_norms, best_step) with the lists
            of training losses, validation losses (None values if there is
            no validation data) and gradient L1 norms at each step, and the
            step of the best model (None if 'keep_best' is False)
        """
        best_step = 0
        best_metric = None
        # the initial state is kept so that a state can always be loaded back,
        # even if training is interrupted during the first step
        best_state = copy.deepcopy(self.state_dict()) if keep_best else None
        train_losses = []
        val_losses = []
        grad_norms = []
        if optimizer is None:
            lr = learning_rate(0) if callable(learning_rate) else learning_rate
            optimizer = torch.optim.Adam(self.parameters(), lr)
        try:
            for step in range(n_steps + 1):
                # apply the gradient computed at the previous step
                # (no-op at step 0 as no gradient was computed yet)
                optimizer.step()
                if callable(learning_rate):
                    for group in optimizer.param_groups:
                        group["lr"] = learning_rate(step)
                optimizer.zero_grad()
                # training loss, gradients are accumulated over the batches
                self.train()
                batch_losses = []
                for batch in training_data:
                    batch_loss = self.loss(*batch)
                    if L1 is not None:
                        batch_loss = batch_loss + L1 * self._norm(self.parameters(), 1)
                    if L2 is not None:
                        batch_loss = batch_loss + L2 * self._norm(self.parameters(), 2)
                    batch_loss.backward()
                    batch_losses.append(batch_loss.item())
                n_batches = len(batch_losses)
                train_loss = sum(batch_losses) / max(1, n_batches)
                train_losses.append(train_loss)
                # averaging gradient over batches
                if n_batches > 1:
                    for p in self.parameters():
                        if p.grad is not None:
                            p.grad /= n_batches
                # gradient norm, skipping parameters that have no gradient
                # (frozen, unused by the loss, or no training batch at all)
                grads = (p.grad for p in self.parameters() if p.grad is not None)
                grad_norms.append(float(self._norm(grads, 1, average=False)))
                # validation loss
                self.eval()
                if validation_data is not None:
                    with torch.no_grad():
                        batch_losses = [self.loss(*batch).item()
                                        for batch in validation_data]
                    val_loss = sum(batch_losses) / max(1, len(batch_losses))
                else:
                    val_loss = None
                val_losses.append(val_loss)
                # model checkpointing
                current_metric = val_loss if val_loss is not None else train_loss
                if best_metric is None or current_metric < best_metric:
                    best_step = step
                    best_metric = current_metric
                    if keep_best:
                        best_state = copy.deepcopy(self.state_dict())
                # early stopping
                if patience is not None and (step - best_step) > patience:
                    break
                # message printing
                if verbose:
                    if val_loss is not None:
                        print(f"Step {step}: train loss = {train_loss:.3g}, val loss = {val_loss:.3g}, grad = {grad_norms[-1]:.3g}")
                    else:
                        print(f"Step {step}: train loss = {train_loss:.3g}, grad = {grad_norms[-1]:.3g}")
        except KeyboardInterrupt:
            if verbose:
                print("Training interrupted by the user")
        finally:
            # load the best state
            if keep_best:
                self.load_state_dict(best_state)
        return train_losses, val_losses, grad_norms, best_step if keep_best else None

    def data_to_tensor(self, x: object, y: object,
                       weights: Optional[Sequence[float]] = None,
                       device: Optional[torch.device] = None,
                       **kwargs) -> tuple:
        """
        Converts data to a tuple of tensors usable as a batch in 'fit'

        Parameters
        ----------
        x : object
            input data, converted with '_x_to_tensor'
        y : object
            target data, converted with '_y_to_tensor'
        weights : Sequence of float or None
            optional observation weights, rescaled to a mean of 1
        device : torch.device or None
            device passed to the tensor conversion functions
        **kwargs :
            additional keyword arguments passed to '_x_to_tensor'
            and '_y_to_tensor'

        Returns
        -------
        tuple :
            (x, y) or (x, y, weights) tensors
        """
        x = self._x_to_tensor(x, device, **kwargs)
        y = self._y_to_tensor(y, device, **kwargs)
        if weights is not None:
            w = floats_to_tensor(weights, device)
            data = (x, y, w/w.mean())
        else:
            data = (x, y)
        return data

    def predict(self, *args) -> object:
        """
        Returns the model's predictions for the given inputs

        The arguments are converted with '_x_to_tensor', the model is run
        in evaluation mode without gradient tracking, and its output is
        converted back with '_tensor_to_y'.
        """
        self.eval()
        x = self._x_to_tensor(*args)
        with torch.no_grad():
            y_pred = self(x)
        return self._tensor_to_y(y_pred)

    def loss(self, x: torch.Tensor, y: torch.Tensor,
             weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Returns the scalar loss for a batch (to be implemented by subclasses)
        """
        raise NotImplementedError()

    def _x_to_tensor(self, x: object, device: Optional[torch.device] = None,
                     **kwargs) -> torch.Tensor:
        """Converts input data to a tensor (to be implemented by subclasses)"""
        raise NotImplementedError()

    def _y_to_tensor(self, y: object, device: Optional[torch.device] = None,
                     **kwargs) -> torch.Tensor:
        """Converts target data to a tensor (to be implemented by subclasses)"""
        raise NotImplementedError()

    def _tensor_to_y(self, T: torch.Tensor) -> object:
        """Converts the model's output back to predictions (to be implemented by subclasses)"""
        raise NotImplementedError()

    @staticmethod
    def _norm(tensors: Iterable[torch.Tensor], order: int,
              average: bool = True):
        """
        Returns the norm of order 'order' of all the tensors taken together:
        (sum(|t|**order))**(1/order)

        If 'average' is True, the sum is divided by the total number of
        elements before taking the root (normalized by number of elements).
        Returns 0. if there are no tensors or no elements.
        """
        total, n_elements = 0., 0
        for t in tensors:
            total = total + torch.sum(torch.abs(t) ** order)
            n_elements += t.numel()
        # guard against division by zero for empty inputs
        if average and n_elements > 0:
            total = total / n_elements
        return total ** (1 / order)


class NeuralNetworkClassifier(NeuralNetwork):
    """
    Abstract base class for classifier neural networks

    On top of the 'NeuralNetwork' interface, it stores the tuple of class
    names and provides a 'probabilities' method.
    """

    def __init__(self, classes: Iterable[str]):
        super().__init__()
        self.classes = tuple(classes)

    def probabilities(self, *args):
        """
        Returns the predicted class probabilities for the given inputs

        The arguments are converted with '_x_to_tensor', the model is run in
        evaluation mode without gradient tracking, and its raw output is
        converted with '_tensor_to_proba'.
        """
        self.eval()
        inputs = self._x_to_tensor(*args)
        with torch.no_grad():
            raw_output = self(inputs)
        return self._tensor_to_proba(raw_output)

    def _tensor_to_proba(self, T: torch.Tensor) -> object:
        """Converts the model's raw output to probabilities (subclass hook)"""
        raise NotImplementedError()

    def data_to_tensor(self, x: object, y: object,
                       weights: Optional[Sequence[float]] = None,
                       class_weights: Optional[Sequence[float]] = None,
                       device: Optional[torch.device] = None,
                       **kwargs) -> tuple:
        """
        Converts data to a tuple of tensors usable as a batch in 'fit'

        Parameters
        ----------
        x : object
            input data, converted with '_x_to_tensor'
        y : object
            target data, converted with '_y_to_tensor'
        weights : Sequence of float or None
            optional observation weights, rescaled to a mean of 1
        class_weights : Sequence of float or None
            optional per-class weights, rescaled to a mean of 1
        device : torch.device or None
            device passed to the tensor conversion functions
        **kwargs :
            additional keyword arguments passed to '_x_to_tensor'
            and '_y_to_tensor'

        Returns
        -------
        tuple :
            (x, y), (x, y, weights) or (x, y, weights, class_weights).
            When class weights are given, 'weights' is always present in the
            tuple (possibly None) so that positions stay consistent.
        """
        def rescaled(values):
            # None passes through, anything else becomes a mean-1 tensor
            if values is None:
                return None
            as_tensor = floats_to_tensor(values, device)
            return as_tensor / as_tensor.mean()

        batch = [self._x_to_tensor(x, device, **kwargs),
                 self._y_to_tensor(y, device, **kwargs)]
        observation_w = rescaled(weights)
        class_w = rescaled(class_weights)
        if class_w is not None:
            batch.extend((observation_w, class_w))
        elif observation_w is not None:
            batch.append(observation_w)
        return tuple(batch)
