import os
import sys
import time
import json

import abc

import pickle

from shutil import copyfile

from mlpug.utils import get_value_at
from mlpug.trainers.callbacks.callback import Callback

import basics.base_utils as _


class CheckpointManager(Callback, metaclass=abc.ABCMeta):
    """
    Callback that monitors a model quality metric during training and writes checkpoints to disk.

    Files written to `checkpoints_path` (all prefixed with the base checkpoint filename):

    * best model checkpoint: overwritten each time the monitored metric improves (and, when a
      threshold is set, is good enough);
    * latest model checkpoint and training checkpoint: overwritten every `create_checkpoint_every`
      batches/epochs, and at the end of training;
    * archived model checkpoints: copies of the latest model checkpoint, made every
      `archive_last_model_checkpoint_every` batches/epochs.

    Checkpoints are written with `pickle` by default; subclasses can override `_save_model_checkpoint`
    and `_save_training_checkpoint` to use another format.

    NOTE(review): `self._log`, `self.trainer`, `self.training_manager` and `self._get_logs_base` are not
    defined in this class; presumably they are provided by the `Callback` base class.
    """

    def __init__(self,
                 model_hyper_parameters=None,
                 checkpoints_path="../trained-models/",
                 batch_level=True,
                 metric_to_monitor="validation.window_average.perplexity",
                 metric_opt_mode='min',
                 metric_monitor_period=200,  # batches or epochs
                 metric_checkpoint_threshold=None,
                 create_checkpoint_every=200,  # batches or epochs
                 archive_last_model_checkpoint_every=2000,  # batches or epochs
                 force_monitoring_on_epoch=True,
                 base_checkpoint_filename=None,
                 model_checkpoint_filename_ext="m-ckp",
                 training_checkpoint_filename_ext="t-ckp",
                 backup_before_override=True,
                 name="CheckpointManager",
                 **kwargs):
        """

        :param model_hyper_parameters: Dict with hyper parameters of the model, this will be saved as part of
                                       the model checkpoint
        :param checkpoints_path: Directory to write checkpoints to; created when it does not exist
        :param batch_level: True if monitoring happens on the scale of batches. If False, monitoring is done
                            on the epoch level
        :param metric_to_monitor: key path to metric value in the log object, e.g. `validation.batch.perplexity`
        :param metric_opt_mode: 'max', 'min'
        :param metric_monitor_period: The period between checks for model quality improvement.
                                      This is in number of batches if `batch_level = True`, else it is a
                                      number of epochs. A value <= 0 disables periodic monitoring.
        :param metric_checkpoint_threshold: when given, the model quality must be below (or above, depending on
                                            metric_opt_mode) this threshold before a new best model checkpoint can be
                                            saved

        :param create_checkpoint_every: period before saving next training/model checkpoint
                                        (will be overridden after each period)
                                        A 0 value disables this feature.

        :param archive_last_model_checkpoint_every: period before the last available model checkpoint is archived.
                                                    Period must be multiple of create_checkpoint_every period

                                                    A 0 value disables this feature.

        :param force_monitoring_on_epoch: When True, the given metric will also be monitored on every epoch
                                          in the case that monitoring level is batch level
        :param base_checkpoint_filename: Prefix of all checkpoint file names. When None (default), the date and
                                         time at construction of this manager is used (format "%d-%m-%Y_%H-%M-%S")
        :param model_checkpoint_filename_ext: File extension of model checkpoints
        :param training_checkpoint_filename_ext: File extension of training checkpoints
        :param backup_before_override: When True the training and model checkpoints are backed up before they are
                                       overridden by a new version. If backing up fails, no new version will be written
                                       to disk. This gives the user a chance to fix a problem, e.g. disk full, without
                                       interruption of the training process.
        """
        super().__init__(name=name, **kwargs)

        if base_checkpoint_filename is None:
            # Resolved here instead of in the signature: a signature default is evaluated only once, at import
            # time, which made every manager created in the same process write to (and overwrite) the same files.
            base_checkpoint_filename = time.strftime("%d-%m-%Y_%H-%M-%S")

        self._model_hyper_parameters = model_hyper_parameters

        self._checkpoints_path = checkpoints_path
        self._model_checkpoint_filename_ext = model_checkpoint_filename_ext
        self._training_checkpoint_filename_ext = training_checkpoint_filename_ext
        self._base_filename = base_checkpoint_filename

        self._batch_level = batch_level

        self._metric_to_monitor = metric_to_monitor
        self._metric_opt_mode = metric_opt_mode

        self._metric_monitor_period = metric_monitor_period
        self._metric_checkpoint_threshold = metric_checkpoint_threshold
        self._create_checkpoint_every = create_checkpoint_every
        self._archive_last_model_checkpoint_every = archive_last_model_checkpoint_every

        self._force_monitoring_on_epoch = force_monitoring_on_epoch

        self._backup_before_override = backup_before_override

        # Start from the worst possible value so the first valid measurement counts as an improvement
        self._best_model_quality = float('Inf') if self._metric_opt_mode == 'min' else -float('Inf')
        self._best_model_iter = None

        self._check_settings()

        self._describe_setup()

    def get_state(self):
        """
        :return: (state, success) tuple; state holds the best model quality and the iteration it was reached at
        """
        return {
            "best_model_quality": self._best_model_quality,
            "best_model_iter": self._best_model_iter
        }, True

    def set_state(self, state):
        """
        Restore the best-model tracking state, e.g. when continuing training from a checkpoint.

        :param state: Dict with saved checkpoint state (as returned by `get_state`) to continue tracking the
                      best model during training. When not a dict, tracking starts with a clean slate.
        :return: True on success, False when the state could not be applied (the current state is then unchanged)
        """
        if not _.is_dict(state):
            self._log.debug("No initial checkpoint state given, starting with clean slate ...")
            return True

        # default=str: this dump is for logging only, it must not fail on values JSON can't represent
        self._log.debug("Using given initial checkpoint state: \n\n%s" % json.dumps(state, indent=4, default=str))

        try:
            best_model_quality = state['best_model_quality']
            best_model_iter = state['best_model_iter']
        except KeyError as e:
            _.log_exception(self._log, "Unable to set checkpoint manager state", e)
            return False

        # Only assign once both values are available, so a partial state is never applied
        self._best_model_quality = best_model_quality
        self._best_model_iter = best_model_iter

        return True

    def on_batch_training_completed(self, training_batch, logs):
        """
        Monitor and checkpoint after each training batch, only when monitoring on batch level.

        :return: True on success, else False
        """
        if not self._batch_level:
            return True

        return self._monitor('global_iter', logs)

    def on_epoch_completed(self, logs):
        """
        Monitor and checkpoint after each epoch.

        When monitoring on batch level, this only happens when `force_monitoring_on_epoch` is True; the global
        iteration is then used as iteration counter and monitoring is done regardless of the monitor period.

        :return: True on success, else False
        """
        if not self._batch_level:
            return self._monitor('epoch', logs, force_monitoring=False)

        if not self._force_monitoring_on_epoch:
            return True

        self._log.debug("Epoch completed : checking if model improved and creating checkpoints ... ")
        return self._monitor('global_iter', logs, force_monitoring=True)

    def on_training_ended(self, stopped_early, stopped_on_error, interrupted, callback_calls_success):
        """
        Store the latest model (and, when training did not end normally, a training checkpoint) and report
        the best model quality reached.

        :return: True when all required checkpoints were saved, else False
        """
        status = 'ended'
        if stopped_early:
            status = 'stopped early'

        if stopped_on_error:
            status = 'stopped on error'

        if interrupted:
            status = 'interrupted'

        sys.stdout.write("\n")
        sys.stdout.flush()
        self._log.info(f"Training {status}")
        success = True
        if self._create_checkpoint_every > 0:
            self._log.info("... storing latest model ...")
            latest_model_fname = self._create_model_checkpoint()
            success = (latest_model_fname is not None)

        if stopped_early or interrupted:
            self._log.info("... storing training checkpoint ...")
            training_checkpoint_fname = self._create_training_checkpoint()
            success &= (training_checkpoint_fname is not None)

        if abs(self._best_model_quality) != float('Inf') and self._best_model_iter is not None:
            iter_name = "global iteration" if self._batch_level else "epoch"

            self._log.info(f"Best model quality reached {self._metric_to_monitor}={self._best_model_quality} "
                           f"at {iter_name} {self._best_model_iter}")
        else:
            self._log.warning("No good enough model found, no best model checkpoint saved.")

        return success

    def training_checkpoint_file_name(self):
        """:return: Path of the (overridden) training checkpoint file"""
        b_fn = self._base_filename
        ext = self._training_checkpoint_filename_ext

        return os.path.join(self._checkpoints_path, f'{b_fn}-training-checkpoint.{ext}')

    def current_model_file_name(self, training_iter):
        """:return: Path of the archived model checkpoint file for the given training iteration"""
        b_fn = self._base_filename
        ext = self._model_checkpoint_filename_ext

        return os.path.join(self._checkpoints_path, f'{b_fn}-model-checkpoint-{training_iter}.{ext}')

    def latest_model_file_name(self):
        """:return: Path of the (overridden) latest model checkpoint file"""
        b_fn = self._base_filename
        ext = self._model_checkpoint_filename_ext

        return os.path.join(self._checkpoints_path, f'{b_fn}-latest-model-checkpoint.{ext}')

    def best_model_file_name(self):
        """:return: Path of the (overridden) best model checkpoint file"""
        b_fn = self._base_filename
        ext = self._model_checkpoint_filename_ext

        return os.path.join(self._checkpoints_path, f'{b_fn}-best-model-checkpoint.{ext}')

    def _describe_setup(self):
        self._log.info(f"Metric to monitor : {self._metric_to_monitor}")
        self._log.info(f"Metric monitor period : {self._metric_monitor_period}")

        time_scale = 'batches' if self._batch_level else 'epochs'
        self._log.info(f"Create last training & model checkpoints every "
                       f"{self._create_checkpoint_every} {time_scale}")
        self._log.info(f"Archive last model checkpoint every "
                       f"{self._archive_last_model_checkpoint_every} {time_scale}")

    def _get_model_quality(self, current_logs):
        """
        :return: Value of the monitored metric in the current logs, or None when not available
        """
        model_quality = get_value_at(self._metric_to_monitor, current_logs)

        if isinstance(model_quality, tuple):
            # use the first value as metric value, the other values are auxiliary results meant for other purposes
            model_quality = model_quality[0]

        return model_quality

    @staticmethod
    def _period_reached(training_iter, period):
        """Return True when `period` is enabled (> 0) and `training_iter` is an exact multiple of it."""
        return (period > 0) and (training_iter % period == 0)

    def _monitor(self, iter_name, logs, force_monitoring=False):
        """
        Check for model improvement and create the periodic checkpoints due at the current training iteration.

        :param iter_name: Key of the iteration counter in the current logs: 'global_iter' or 'epoch'
        :param logs: Training logs object
        :param force_monitoring: When True, model quality is checked regardless of `metric_monitor_period`
        :return: True when all checks and saves that were due succeeded, else False
        """
        current = self._get_logs_base(logs)
        training_iter = current[iter_name]

        success = True
        data_saved = False

        best_model_fname = None
        if force_monitoring or self._period_reached(training_iter, self._metric_monitor_period):
            best_model_fname, monitor_success = self._check_model_improvement(iter_name,
                                                                              training_iter,
                                                                              logs,
                                                                              current)
            success &= monitor_success
            data_saved |= bool(best_model_fname)

        if self._period_reached(training_iter, self._create_checkpoint_every):
            checkpoint_success, checkpoint_data_saved = self._create_periodic_checkpoints(training_iter,
                                                                                           best_model_fname)
            success &= checkpoint_success
            data_saved |= checkpoint_data_saved

        if data_saved:
            self._log.debug("Saving of data done.\n\n")

        return success

    def _is_good_enough(self, model_quality):
        """Return True when no threshold is set, or `model_quality` is on the right side of it."""
        threshold = self._metric_checkpoint_threshold
        if threshold is None:
            return True

        if self._metric_opt_mode == 'min':
            return model_quality <= threshold

        if self._metric_opt_mode == 'max':
            return model_quality >= threshold

        return True

    def _is_improvement(self, model_quality):
        """Return True when `model_quality` is strictly better than the best model quality so far."""
        if self._metric_opt_mode == 'min':
            return model_quality < self._best_model_quality

        if self._metric_opt_mode == 'max':
            return model_quality > self._best_model_quality

        return False

    def _check_model_improvement(self, iter_name, training_iter, logs, current):
        """
        Compare the monitored metric with the best value so far; save a best model checkpoint when it improved.

        :return: (best_model_fname, success). best_model_fname is None when no best model checkpoint was saved
        """
        pretty_iter_name = self.pretty_iter_name(iter_name)

        model_quality = self._get_model_quality(current)
        if model_quality is None:
            self._log.error(f"No model quality available, unable to check if we need to save a checkpoint, "
                            f"skipping ...")
            return None, False

        if not self._is_good_enough(model_quality):
            self._log.debug("%s : %s : model quality not good enough : %s : %.3e " % (pretty_iter_name,
                                                                                      training_iter,
                                                                                      self._metric_to_monitor,
                                                                                      model_quality))
            return None, True

        success = True
        best_model_fname = None

        model_improved = self._is_improvement(model_quality)
        if model_improved:
            self._log.debug("%s : %s : Model improved : %s : %.3e " % (pretty_iter_name,
                                                                       training_iter,
                                                                       self._metric_to_monitor,
                                                                       model_quality))

            if (self._best_model_iter is not None) and \
                    (iter_name == 'epoch') and \
                    (training_iter < self._best_model_iter):
                self._log.warning(f"Inconsistency: according to the current training {pretty_iter_name} "
                                  f"({training_iter}), current best model training iter ({self._best_model_iter}) "
                                  f"is in the future. Was the right training checkpoint loaded?")

            best_model_fname = self._save_current_model_as_best()
            if best_model_fname:
                # Only track the new best model when it was actually saved
                self._best_model_quality = model_quality
                self._best_model_iter = training_iter
            else:
                self._log.error("Unable to save improved model checkpoint")
                success = False

        self._update_logs(model_improved, logs, current)

        return best_model_fname, success

    def _create_periodic_checkpoints(self, training_iter, best_model_fname):
        """
        Save the latest model checkpoint, archive it when due, and save the training checkpoint.

        :param training_iter: Current training iteration (global iteration or epoch)
        :param best_model_fname: Best model checkpoint saved at this iteration, or None
        :return: (success, data_saved)
        """
        success = True
        data_saved = False

        # When a best model checkpoint was just saved it contains the current model: copy it, instead of
        # gathering and serializing the model state again
        latest_model_fname = self._create_model_checkpoint(file_to_copy=best_model_fname)
        latest_model_saved = (latest_model_fname is not None)
        success &= latest_model_saved
        data_saved |= latest_model_saved

        if self._period_reached(training_iter, self._archive_last_model_checkpoint_every):
            if latest_model_saved:
                copy_success = self._copy(latest_model_fname, self.current_model_file_name(training_iter))
                success &= copy_success
                data_saved |= copy_success
            else:
                self._log.error("Unable to create checkpoint for latest model, unable to archive latest model")
                success = False

        training_checkpoint_saved = (self._create_training_checkpoint() is not None)
        success &= training_checkpoint_saved
        data_saved |= training_checkpoint_saved

        return success, data_saved

    def _backup_if_required(self, filename, checkpoint_description):
        """
        Back up `filename` before it gets overridden, when `backup_before_override` is enabled.

        :param filename: Checkpoint file that is about to be overridden
        :param checkpoint_description: Human readable checkpoint name, used in log messages
        :return: True when no backup was required or the backup succeeded, else False
        """
        if not self._backup_before_override:
            return True

        try:
            backup_success = self._backup_checkpoint(filename)
        except Exception as e:
            _.log_exception(self._log, f"A problem occurred backing up the last {checkpoint_description}, "
                                       f"will not override {checkpoint_description} with new one", e)
            return False

        if not backup_success:
            self._log.error(f"Unable to backup last {checkpoint_description}, "
                            f"will not override {checkpoint_description} with new one")

        return bool(backup_success)

    def _save_current_model_as_best(self):
        """
        :return: File name of the saved best model checkpoint, or None when it was not saved
        """
        model_fn = self.best_model_file_name()

        # Respect the backup_before_override contract: never override the best model when its backup failed
        if not self._backup_if_required(model_fn, "best model checkpoint"):
            return None

        return self._create_model_checkpoint(model_fn=model_fn)

    def _update_logs(self, model_improved, logs, current):
        if model_improved:
            logs["best"] = current.copy()

        current["is_best"] = model_improved

    def _create_model_checkpoint(self, model_fn=None, file_to_copy=None):
        """
        Save a model checkpoint, either by copying an existing checkpoint file or by gathering the model state.

        :param model_fn: Destination file name; defaults to the latest model checkpoint file name
        :param file_to_copy: When given, this file is copied instead of gathering the model state
        :return: The destination file name on success, else None
        """
        if model_fn is None:
            model_fn = self.latest_model_file_name()

        if file_to_copy:
            if not self._copy(file_to_copy, model_fn):
                self._log.error(f"Unable to create model checkpoint based on file : {file_to_copy}")
                return None

            return model_fn

        try:
            state, success = self._gather_model_checkpoint_data()
            if state is None:
                return None

            if not success:
                self._log.warning("Gathering the model checkpoint data was not completely successful, "
                                  "will save available checkpoint data anyway ...")

            self._log.debug(f"Saving model checkpoint : {model_fn}")
            self._save_model_checkpoint(model_fn, state)
        except Exception as e:
            _.log_exception(self._log, f"A problem occurred saving the latest model as checkpoint", e)
            return None

        return model_fn

    def _create_training_checkpoint(self):
        """
        Save the training checkpoint, after backing up the previous one when required.

        :return: The training checkpoint file name on success, else None
        """
        checkpoint_fname = self.training_checkpoint_file_name()

        if not self._backup_if_required(checkpoint_fname, "training checkpoint"):
            return None

        try:
            state, success = self._gather_training_checkpoint_data()
            if state is None:
                return None

            if not success:
                self._log.warning("Gathering the training checkpoint data was not completely successful, "
                                  "will save available checkpoint data anyway ...")

            self._log.debug(f"Saving training checkpoint : {checkpoint_fname}")
            self._save_training_checkpoint(checkpoint_fname, state)
        except Exception as e:
            _.log_exception(self._log, f"Unable to save training checkpoint", e)
            return None

        return checkpoint_fname

    def _gather_model_checkpoint_data(self):
        """
        Gather the model components state, the model hyper parameters (when given) and the training manager
        state for the model checkpoint.

        :return: state, success
        """
        state, success = self.trainer.get_model_components_state()

        if state is None:
            # In any case when the state is None, gathering the model checkpoint data is not successful
            return state, False

        if not success:
            self._log.warning("Getting the model components state was not completely successful, "
                              "continuing anyway ...")

        if self._model_hyper_parameters is not None:
            state['hyper_parameters'] = self._model_hyper_parameters

        try:
            manager_state, manager_state_success = self.training_manager.get_state_for_model_checkpoint()
            state['manager_state'] = manager_state

            if not manager_state_success:
                self._log.warning("Getting the manager state for the model checkpoint was not successful, "
                                  "will continue anyway ...")
                success = False

        except Exception as e:
            _.log_exception(self._log, "Unable to add manager state to model checkpoint, "
                                       "continuing anyway ...", e)
            success = False

        return state, success

    def _gather_training_checkpoint_data(self):
        """

        :return: state, success
        """
        return self.training_manager.get_state()

    def _save_model_checkpoint(self, filename, state):
        # NOTE: pickle files execute code when loaded; only load checkpoints from trusted sources
        with open(filename, 'wb') as f:
            pickle.dump(state, f)

    def _save_training_checkpoint(self, filename, state):
        # NOTE: pickle files execute code when loaded; only load checkpoints from trusted sources
        with open(filename, 'wb') as f:
            pickle.dump(state, f)

    def _backup_checkpoint(self, filename):
        """
        Copy `filename` to `<filename>.backup` when it exists.

        :return: True when there was nothing to back up or the copy succeeded, else False
        """
        if os.path.isfile(filename):
            return self._copy(filename, f"{filename}.backup")
        else:
            return True

    def _copy(self, source_fname, dest_fname):
        """
        :return: True on success, False when the copy failed (the error is logged)
        """
        try:
            self._log.debug("Copying model: [%s] ==> [%s]" % (source_fname, dest_fname))
            copyfile(source_fname, dest_fname)
        except Exception as e:
            _.log_exception(self._log, "Unable to copy [%s]" % source_fname, e)
            return False

        return True

    def _check_settings(self):
        """Create the checkpoint directory when needed and correct inconsistent checkpoint period settings."""
        if not os.path.exists(self._checkpoints_path):
            self._log.info(f"Creating checkpoint directory : {self._checkpoints_path}")
            # exist_ok: the directory may be created concurrently, e.g. by another process
            os.makedirs(self._checkpoints_path, exist_ok=True)

        if self._metric_opt_mode not in ('min', 'max'):
            self._log.error(f"metric_opt_mode must be 'min' or 'max', got [{self._metric_opt_mode}], "
                            f"no best model checkpoints will be saved")

        # A create_checkpoint_every of 0 disables checkpointing, and archiving copies the latest checkpoint,
        # so archiving can't work in that case either
        if (self._create_checkpoint_every <= 0) and (self._archive_last_model_checkpoint_every > 0):
            self._log.error("archive_last_model_checkpoint_every can't be > 0 while _create_checkpoint_every <= 0, "
                            "disabling archiving ... ")
            self._archive_last_model_checkpoint_every = -1

        if (self._create_checkpoint_every > 0) and \
           (self._archive_last_model_checkpoint_every > 0) and \
           (self._archive_last_model_checkpoint_every % self._create_checkpoint_every != 0):

            self._archive_last_model_checkpoint_every = 10 * self._create_checkpoint_every
            self._log.warning(f"archive_last_model_checkpoint_every must be "
                              f"exact multiple of _create_checkpoint_every, "
                              f"changing archive_last_model_checkpoint_every to "
                              f"[{self._archive_last_model_checkpoint_every}]")

    @staticmethod
    def pretty_iter_name(iter_name):
        """Return a human readable name for the iteration counter key `iter_name`."""
        if iter_name == "epoch":
            return "Epoch"
        elif iter_name == "global_iter":
            return "Global iter"
        else:
            return iter_name
