import atexit
import importlib
import io
import itertools
import os
import threading
from typing import Optional

from PIL import Image

import vessl
from vessl.util import logger
from vessl.util.image import Image as VesslImage

# Collector modes, passed to `TensorboardCollector.set_logdir` by the writer hooks.
# MODE_TENSORFLOW: reported by the TF2 `create_summary_file_writer` hook; events are
#   read with tensorflow's `summary_iterator`.
# MODE_OTHERS: reported by the `EventFileWriter` hooks (TF1, PyTorch, tensorboardX);
#   events are read with tensorboard's `EventAccumulator`.
MODE_TENSORFLOW = "mode-tf"
MODE_OTHERS = "mode-others"

# Tensorboard event types. These strings are compared against tensorboard plugin
# names (MODE_TENSORFLOW) and used as `EventAccumulator.Tags()` keys (MODE_OTHERS).
EVENT_TYPE_COMPRESSED_HISTOGRAMS = "distributions"
EVENT_TYPE_HISTOGRAMS = "histograms"
EVENT_TYPE_IMAGES = "images"
EVENT_TYPE_AUDIO = "audio"
EVENT_TYPE_SCALARS = "scalars"
EVENT_TYPE_TENSORS = "tensors"

# EventAccumulator `size_guidance` argument: how many events of each type it retains
# per tag (0 = keep all, N = keep N). Only scalars and images are forwarded to vessl
# by TensorboardCollector.
# NOTE(review): audio is also kept in full (0) although it is never forwarded —
# confirm that is intended.
SIZE_GUIDANCE = {
    EVENT_TYPE_COMPRESSED_HISTOGRAMS: 1,
    EVENT_TYPE_IMAGES: 0,
    EVENT_TYPE_AUDIO: 0,
    EVENT_TYPE_SCALARS: 0,
    EVENT_TYPE_HISTOGRAMS: 1,
    EVENT_TYPE_TENSORS: 1,
}

# Seconds the collector thread waits between polls of the event file(s).
READ_INTERVAL_IN_SEC = 1


class TensorboardCollector:
    """TensorboardCollector checks for new tensorboard events written to logdir and
    logs them to experiment metrics.

    There are two modes: `MODE_TENSORFLOW` (TF2) and `MODE_OTHERS` (TF1, PyTorch, etc).
    `MODE_OTHERS` uses tensorboard's `EventAccumulator` to read from the event file.
    `MODE_TENSORFLOW` uses tensorflow's `summary_iterator` because TF2 records events
    in a different way.

    Constructing the collector starts a daemon thread that polls every
    `READ_INTERVAL_IN_SEC` seconds. Polls are no-ops until `set_logdir` has
    initialized one of the modes. An `atexit` hook stops the thread after one final
    read.
    """

    __slots__ = [
        "_exit",
        "_thread",
        "_tf_dict",
        "_event_accumulator",
    ]

    def __init__(self):
        self._start_thread()

    @property
    def _mode(self):
        """The active mode, or "" while no logdir has been initialized yet."""
        # Unassigned __slots__ raise AttributeError, so `hasattr` tells whether the
        # corresponding `_initialize_mode_*` call has published its state.
        if hasattr(self, "_tf_dict"):
            return MODE_TENSORFLOW
        if hasattr(self, "_event_accumulator"):
            return MODE_OTHERS
        return ""

    def _is_initialized(self):
        return self._mode in (MODE_TENSORFLOW, MODE_OTHERS)

    def _start_thread(self):
        """Start the polling thread and register an atexit hook that stops it."""

        def exit_fn(*args):
            self._exit.set()
            self._thread.join()
            if not self._is_initialized():
                logger.warn("No tensorboard writer was detected during the run.")

        atexit.register(exit_fn)

        self._exit = threading.Event()
        self._thread = threading.Thread(target=self._thread_body, daemon=True)
        self._thread.start()

    def _thread_body(self):
        while not self._exit.is_set():
            self._main()
            self._exit.wait(timeout=READ_INTERVAL_IN_SEC)
        # One last read so events written right before exit are not lost.
        self._main()

    def _main(self):
        """One poll: forward newly written events according to the active mode."""
        if not self._is_initialized():
            # Wait for initialization
            return

        if self._mode == MODE_TENSORFLOW:
            self._main_tensorflow()

        elif self._mode == MODE_OTHERS:
            self._main_others()

    def _main_tensorflow(self):
        """Main thread body for MODE_TENSORFLOW.

        The summary iterator is created once in `_initialize_mode_tensorflow` and
        reused. Each poll drains it and forwards everything it yields.

        NOTE: TensorFlow's record iterator raises StopIteration at the current end of
        file without closing itself. The next poll therefore resumes right after the
        last record already read. Skipping "already seen" events here would instead
        discard that many *new* events on every poll.
        """
        events = list(self._tf_dict["summary_iterator"])
        for event in events:
            self._handle_event(event)

    def _handle_event(self, event):
        """Forward the scalar and image values of one TF2 `Event` to `vessl.log`.

        Values belonging to any other plugin are ignored.
        """
        for val_struct in event.summary.value:
            event_type = val_struct.metadata.plugin_data.plugin_name

            if event_type == EVENT_TYPE_SCALARS:
                event_value = self._tf_dict["make_ndarray_fn"](val_struct.tensor).item()

            elif event_type == EVENT_TYPE_IMAGES:
                # First two tensors are height and width, image bytestrings start
                # from third tensor.
                event_value = [
                    VesslImage(data=Image.open(io.BytesIO(bs)), caption=val_struct.tag)
                    for bs in val_struct.tensor.string_val[2:]
                ]

            else:
                continue

            payload = {val_struct.tag: event_value}
            vessl.log(payload=payload, step=event.step, ts=event.wall_time)

    def _main_others(self):
        """Main thread body for MODE_OTHERS"""
        try:
            self._event_accumulator.Reload()  # Loads each event at most once
        except Exception:
            # No metrics have been saved yet
            return

        event_type_to_tags = self._event_accumulator.Tags()
        # ex. {"scalars": ["loss", "accuracy"], "images": ["caption"]}

        self._handle_scalars(event_type_to_tags[EVENT_TYPE_SCALARS])
        self._handle_images(event_type_to_tags[EVENT_TYPE_IMAGES])

    def _handle_scalars(self, tags):
        for tag in tags:
            scalars = self._event_accumulator.Scalars(tag)
            self._flush_scalars(tag)  # Flush right away so they are logged once
            for s in scalars:
                self._log_scalar(tag, s)

    @staticmethod
    def _clear_reservoir_bucket(reservoir, tag):
        """Drop every item `reservoir` currently holds for `tag`.

        Relies on private attributes of tensorboard's `Reservoir` (`_mutex`,
        `_buckets`) and of its buckets (`_mutex`, `items`). The reservoir lock guards
        the bucket lookup, because `_buckets` may create a bucket on access. The
        bucket's own lock guards clearing its items.
        """
        with reservoir._mutex:
            bucket = reservoir._buckets[tag]
        with bucket._mutex:
            bucket.items = []

    def _flush_scalars(self, tag):
        """Discard the scalars the accumulator holds for `tag`."""
        self._clear_reservoir_bucket(self._event_accumulator.scalars, tag)

    def _log_scalar(self, tag, scalar):
        payload = {tag: scalar.value}
        vessl.log(payload=payload, step=scalar.step, ts=scalar.wall_time)

    def _handle_images(self, tags):
        for tag in tags:
            images = self._event_accumulator.Images(tag)
            self._flush_images(tag)  # Flush right away so they are logged once
            for i in images:
                self._log_image(tag, i)

    def _flush_images(self, tag):
        """Discard the images the accumulator holds for `tag`."""
        # Lock the *images* reservoir (not the scalars one) while clearing.
        self._clear_reservoir_bucket(self._event_accumulator.images, tag)

    def _log_image(self, tag, image):
        vessl_image = VesslImage(
            data=Image.open(io.BytesIO(image.encoded_image_string)), caption=tag
        )

        payload = {tag: vessl_image}
        vessl.log(payload=payload, step=image.step, ts=image.wall_time)

    def set_logdir(self, logdir, mode):
        """Called when tensorboard writer is detected. Only the first logdir will be used."""
        if not self._is_initialized():
            logger.info(f"Tensorboard logdir detected: {logdir}.")
            if mode == MODE_TENSORFLOW:
                self._initialize_mode_tensorflow(logdir)
            elif mode == MODE_OTHERS:
                self._initialize_mode_others(logdir)
        else:
            logger.info(
                f"Cannot use multiple tensorboard logdirs. {logdir} will be ignored."
            )

    def _initialize_mode_tensorflow(self, logdir):
        """Attach to the most recently modified tfevents file under `logdir`.

        Leaves the collector uninitialized if tensorflow cannot be imported or no
        event file is found.
        """
        try:
            from tensorflow import make_ndarray
            from tensorflow.compat.v1.train import summary_iterator
        except Exception:
            logger.error(
                "Could not import tensorflow. Failed to integrate tensorboard."
            )
            return

        # Get most recent event file. The event file for this run will have been just
        # created (in `tf.python.ops.gen_summary_ops.create_summary_file_writer`).
        path = mtime = None
        for dirpath, _, file_names in os.walk(logdir):
            for file_name in file_names:
                if "tfevents" not in file_name:
                    continue
                file_path = os.path.join(dirpath, file_name)
                try:
                    file_mtime = os.path.getmtime(file_path)
                except OSError:
                    # The file disappeared between listing and stat; ignore it.
                    continue
                if mtime is None or mtime < file_mtime:
                    path, mtime = file_path, file_mtime

        if path is None:
            logger.info(f"No event file found in logdir: {logdir}.")
            return

        # Dict for MODE_TENSORFLOW objects. Assigned in a single statement because
        # its presence is what switches `_mode` to MODE_TENSORFLOW.
        self._tf_dict = {
            "make_ndarray_fn": make_ndarray,
            "summary_iterator": summary_iterator(path),
        }

    def _initialize_mode_others(self, logdir):
        """Create the EventAccumulator for `logdir`, discarding events already there.

        The accumulator is assigned to `self._event_accumulator` only after the
        preexisting scalars/images are discarded. That assignment switches `_mode` to
        MODE_OTHERS, so the polling thread never sees those preexisting events.
        """
        try:
            from tensorboard.backend.event_processing.event_accumulator import (
                EventAccumulator,
            )
        except Exception:
            logger.error(
                "Could not import tensorboard. (Please install using `pip install tensorboard` first.) "
                "Failed to integrate tensorboard."
            )
            return

        accumulator = EventAccumulator(logdir, size_guidance=SIZE_GUIDANCE)

        try:
            accumulator.Reload()
        except Exception:
            # Logdir doesn't exist yet, so there is nothing preexisting to discard.
            pass
        else:
            # Disregard preexisting tensorboard logs
            event_type_to_tags = accumulator.Tags()
            for tag in event_type_to_tags.get(EVENT_TYPE_SCALARS, []):
                self._clear_reservoir_bucket(accumulator.scalars, tag)
            for tag in event_type_to_tags.get(EVENT_TYPE_IMAGES, []):
                self._clear_reservoir_bucket(accumulator.images, tag)

        self._event_accumulator = accumulator


def integrate_tensorboard():
    """Start a TensorboardCollector and hook the supported summary writers into it.

    Intended to be called from `vessl.init`, at most once: each call starts another
    collector thread and wraps the writer hooks again.

    Returns:
        The newly created `TensorboardCollector`.
    """
    collector = TensorboardCollector()
    for install_hook in (_patch_create_summary_file_writer, _patch_event_file_writers):
        install_hook(collector)
    return collector


def _patch_create_summary_file_writer(tc):
    """Hook TF2 summary file writer creation so that `tc` learns the logdir.

    Replaces `tensorflow.python.ops.gen_summary_ops.create_summary_file_writer` with
    a wrapper. The wrapper calls the original op, then reports the logdir through
    `tc.set_logdir(logdir, MODE_TENSORFLOW)`. If the module cannot be imported, the
    patch is skipped with a debug log.

    Args:
        tc: The collector to notify (called as `tc.set_logdir(logdir, mode)`).
    """
    module_name = "tensorflow.python.ops.gen_summary_ops"
    try:
        module = importlib.import_module(module_name)
    except Exception:
        logger.debug(f"Module {module_name} not found. Skipping patch...")
        return

    _create_summary_file_writer = module.create_summary_file_writer

    def _logdir_from_call(args, kwargs):
        # The generated op signature is `create_summary_file_writer(writer, logdir,
        # ...)`, so logdir may arrive by keyword or as the second positional argument.
        logdir = kwargs["logdir"] if "logdir" in kwargs else args[1]
        if hasattr(logdir, "numpy"):  # eager string tensor -> bytes
            logdir = logdir.numpy()
        if isinstance(logdir, bytes):
            logdir = logdir.decode("utf8")
        return logdir

    def custom_create_summary_file_writer(*args, **kwargs):
        ret = _create_summary_file_writer(*args, **kwargs)
        try:
            logdir = _logdir_from_call(args, kwargs)
        except Exception as e:
            # The writer itself was created successfully; failing to read its logdir
            # must not break the user's program.
            logger.debug(f"Could not determine tensorboard logdir: {e}")
            return ret
        # Calling `set_logdir` after `_create_summary_file_writer` ensures that the
        # event file has already been created. This is important because we use the
        # most recent file in `logdir` as our event file.
        tc.set_logdir(logdir, MODE_TENSORFLOW)
        return ret

    module.create_summary_file_writer = custom_create_summary_file_writer
    logger.debug(f"Module {module_name} was successfully patched.")


def _patch_event_file_writers(tc):
    """Hook the `EventFileWriter` class of every importable writer library.

    Each hooked writer reports its logdir to `tc` when it is constructed. Libraries
    that cannot be imported are skipped with a debug log.

    Args:
        tc: The collector to notify.
    """
    candidate_modules = (
        "tensorflow.python.summary.writer.writer",  # TF1
        "tensorboard.summary.writer.event_file_writer",
        "torch.utils.tensorboard.writer",
        "tensorboardX.writer",
    )

    for name in candidate_modules:
        try:
            target = importlib.import_module(name)
        except Exception:
            logger.debug(f"Module {name} not found. Skipping patch...")
        else:
            _patch_event_file_writer(target, tc)
            logger.debug(f"Module {name} was successfully patched.")


def _patch_event_file_writer(module, tc):
    _EventFileWriter = module.EventFileWriter

    class CustomEventFileWriter(_EventFileWriter):
        def __init__(self, *args, **kwargs):
            # self.log_dir is set in super().__init__
            super(CustomEventFileWriter, self).__init__(*args, **kwargs)

            logdir = self.get_logdir()
            tc.set_logdir(logdir, MODE_OTHERS)

    module.EventFileWriter = CustomEventFileWriter
