# -*- coding: utf-8 -*-
"""
Implementation of ``wandb.init()``: set up and start a wandb Run.
"""

from __future__ import print_function

import datetime
import logging
import os
import sys
import time

from six import raise_from, reraise
import wandb
from wandb.backend.backend import Backend
from wandb.lib import console as lib_console
from wandb.lib import filesystem, module, reporting
from wandb.old import io_wrap
from wandb.util import sentry_exc

from .wandb_config import parse_config
from .wandb_run import Run, RunDummy, RunManaged
from .wandb_settings import Settings

if wandb.TYPE_CHECKING:  # type: ignore
    from typing import Optional, Union, List, Dict, Any  # noqa: F401

logger = None  # logger configured during wandb.init()


def _set_logger(log_object):
    """Configure module logger."""
    global logger
    logger = log_object


def online_status(*args, **kwargs):
    """No-op: accept any positional/keyword arguments and ignore them."""
    del args, kwargs  # intentionally unused


class _WandbInit(object):
    def __init__(self):
        self.kwargs = None
        self.settings = None
        self.config = None
        self.run = None
        self.backend = None

        self._wl = None
        self._reporter = None

    def setup(self, kwargs):
        """Complete setup for wandb.init().

        This includes parsing all arguments, applying them with settings and enabling
        logging.

        """
        self.kwargs = kwargs

        # Some settings should be persisted across multiple runs the first
        # time setup is called.
        # TODO: Is this the best way to do this?
        session_settings_keys = ["anonymous"]
        session_settings = {k: kwargs[k] for k in session_settings_keys}
        self._wl = wandb.setup(settings=session_settings)
        # Make sure we have a logger setup (might be an early logger)
        _set_logger(self._wl._get_logger())

        settings = self._wl.settings(
            dict(kwargs.pop("settings", None) or tuple())
        )

        self._reporter = reporting.setup_reporter(
            settings=settings.duplicate().freeze()
        )

        # Remove parameters that are not part of settings
        init_config = kwargs.pop("config", None) or dict()
        if not isinstance(init_config, dict):
            init_config = parse_config(init_config)

        # merge config with sweep (or config file)
        self.config = self._wl._config or dict()
        for k, v in init_config.items():
            self.config.setdefault(k, v)

        # Temporarily unsupported parameters
        unsupported = (
            "magic",
            "config_exclude_keys",
            "config_include_keys",
            "allow_val_change",
            "resume",
            "force",
            "tensorboard",
            "sync_tensorboard",
            "monitor_gym",
        )
        for key in unsupported:
            val = kwargs.pop(key, None)
            if val:
                self._reporter.warning(
                    "currently unsupported wandb.init() arg: %s", key
                )

        # prevent setting project, entity if in sweep
        # TODO(jhr): these should be locked elements in the future or at least
        #            moved to apply_init()
        if settings.sweep_id:
            for key in ("project", "entity"):
                val = kwargs.pop(key, None)
                if val:
                    print("Ignored wandb.init() arg %s when running a sweep" % key)
        settings.apply_init(kwargs)

        # TODO(jhr): should this be moved? probably.
        d = dict(_start_time=time.time(), _start_datetime=datetime.datetime.now(),)
        settings.update(d)

        if settings.jupyter:
            self._jupyter_setup()

        self._log_setup(settings)

        self.settings = settings.freeze()

    def _enable_logging(self, log_fname, run_id=None):
        """Enable logging to the global debug log.  This adds a run_id to the log,
        in case of muliple processes on the same machine.

        Currently no way to disable logging after it's enabled.
        """
        handler = logging.FileHandler(log_fname)
        handler.setLevel(logging.INFO)

        class WBFilter(logging.Filter):
            def filter(self, record):
                record.run_id = run_id
                return True

        if run_id:
            formatter = logging.Formatter(
                "%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d "
                "[%(run_id)s:%(filename)s:%(funcName)s():%(lineno)s] %(message)s"
            )
        else:
            formatter = logging.Formatter(
                "%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d "
                "[%(filename)s:%(funcName)s():%(lineno)s] %(message)s"
            )

        handler.setFormatter(formatter)
        if run_id:
            handler.addFilter(WBFilter())
        logger.propagate = False
        logger.setLevel(logging.DEBUG)
        logger.addHandler(handler)

    def _safe_symlink(self, base, target, name, delete=False):
        # TODO(jhr): do this with relpaths, but i cant figure it out on no sleep
        if not hasattr(os, "symlink"):
            return

        pid = os.getpid()
        tmp_name = "%s.%d" % (name, pid)
        owd = os.getcwd()
        os.chdir(base)
        if delete:
            try:
                os.remove(name)
            except OSError:
                pass
        target = os.path.relpath(target, base)
        os.symlink(target, tmp_name)
        os.rename(tmp_name, name)
        os.chdir(owd)

    def _jupyter_setup(self):
        self.notebook = wandb.jupyter.Notebook()
        ipython = self.notebook.shell
        ipython.register_magics(wandb.jupyter.WandBMagics)

        # Monkey patch ipython publish to capture displayed outputs
        if not hasattr(ipython.display_pub, "_orig_publish"):
            ipython.display_pub._orig_publish = ipython.display_pub.publish

        def publish(data, metadata=None, **kwargs):
            ipython.display_pub._orig_publish(data, metadata=metadata, **kwargs)
            self.notebook.save_display(
                ipython.execution_count, {"data": data, "metadata": metadata}
            )

        ipython.display_pub.publish = publish
        # TODO: should we reset start or any other fancy pre or post run cell magic?
        # ipython.events.register("pre_run_cell", reset_start)

    def _log_setup(self, settings):
        """Setup logging from settings."""

        settings.log_user = settings._path_convert(
            settings.log_dir_spec, settings.log_user_spec
        )
        settings.log_internal = settings._path_convert(
            settings.log_dir_spec, settings.log_internal_spec
        )
        settings.sync_file = settings._path_convert(
            settings.sync_dir_spec, settings.sync_file_spec
        )
        settings.files_dir = settings._path_convert(settings.files_dir_spec)
        filesystem._safe_makedirs(os.path.dirname(settings.log_user))
        filesystem._safe_makedirs(os.path.dirname(settings.log_internal))
        filesystem._safe_makedirs(os.path.dirname(settings.sync_file))
        filesystem._safe_makedirs(settings.files_dir)

        log_symlink_user = settings._path_convert(settings.log_symlink_user_spec)
        log_symlink_internal = settings._path_convert(
            settings.log_symlink_internal_spec
        )
        sync_symlink_latest = settings._path_convert(settings.sync_symlink_latest_spec)

        if settings.symlink:
            self._safe_symlink(
                os.path.dirname(sync_symlink_latest),
                os.path.dirname(settings.sync_file),
                os.path.basename(sync_symlink_latest),
                delete=True,
            )
            self._safe_symlink(
                os.path.dirname(log_symlink_user),
                settings.log_user,
                os.path.basename(log_symlink_user),
                delete=True,
            )
            self._safe_symlink(
                os.path.dirname(log_symlink_internal),
                settings.log_internal,
                os.path.basename(log_symlink_internal),
                delete=True,
            )

        _set_logger(logging.getLogger("wandb"))
        self._enable_logging(settings.log_user)

        logger.info("Logging user logs to {}".format(settings.log_user))
        logger.info("Logging internal logs to {}".format(settings.log_internal))

        self._wl._early_logger_flush(logger)

    def init(self):
        s = self.settings
        config = self.config

        if s.reinit:
            if len(self._wl._global_run_stack) > 0:
                if len(self._wl._global_run_stack) > 1:
                    wandb.termwarn(
                        "If you want to track multiple runs concurrently in wandb you should use multi-processing not threads"  # noqa: E501
                    )
                wandb.join()

        if s.mode == "noop":
            # TODO(jhr): return dummy object
            return None

        # Make sure we are logged in
        wandb.login()

        console = s.console
        use_redirect = True
        stdout_master_fd, stderr_master_fd = None, None
        stdout_slave_fd, stderr_slave_fd = None, None
        if console == "iowrap":
            stdout_master_fd, stdout_slave_fd = io_wrap.wandb_pty(resize=False)
            stderr_master_fd, stderr_slave_fd = io_wrap.wandb_pty(resize=False)
        elif console == "_win32":
            # Not used right now
            stdout_master_fd, stdout_slave_fd = lib_console.win32_create_pipe()
            stderr_master_fd, stderr_slave_fd = lib_console.win32_create_pipe()

        backend = Backend(mode=s.mode)
        backend.ensure_launched(
            settings=s,
            stdout_fd=stdout_master_fd,
            stderr_fd=stderr_master_fd,
            use_redirect=use_redirect,
        )
        backend.server_connect()

        # resuming needs access to the server, check server_status()?

        run = RunManaged(config=config, settings=s)
        run._set_console(
            use_redirect=use_redirect,
            stdout_slave_fd=stdout_slave_fd,
            stderr_slave_fd=stderr_slave_fd,
        )
        run._set_library(self._wl)
        run._set_backend(backend)
        run._set_reporter(self._reporter)
        # TODO: pass mode to backend
        # run_synced = None

        self._wl._global_run_stack.append(run)

        backend._hack_set_run(run)

        if s.mode == "online":
            ret = backend.interface.send_run_sync(run, timeout=30)
            # TODO: fail on error, check return type
            run._set_run_obj(ret.run)
        elif s.mode in ("offline", "dryrun"):
            backend.interface.send_run(run)
        elif s.mode in ("async", "run"):
            ret = backend.interface.send_run_sync(run, timeout=10)
            # TODO: on network error, do async run save
            backend.interface.send_run(run)

        self.run = run
        self.backend = backend
        module.set_global(
            run=run,
            config=run.config,
            log=run.log,
            join=run.join,
            summary=run.summary,
            save=run.save,
            restore=run.restore,
            use_artifact=run.use_artifact,
            log_artifact=run.log_artifact,
        )
        self._reporter.set_context(run=run)
        run.on_start()

        return run


def getcaller():
    """Print the file, line number and function name of this helper's caller.

    Uses ``sys._getframe(1)``, which is available on both py2 and py3,
    instead of ``logger.findCaller()``. findCaller skips only logging's own
    frames, so it always reported this function itself. It also failed with
    AttributeError while the module ``logger`` was still None.
    """
    frame = sys._getframe(1)
    code = frame.f_code
    print("Problem at:", code.co_filename, frame.f_lineno, code.co_name)


def init(
    job_type=None,
    dir=None,
    config=None,  # TODO(jhr): type is a union for argparse/absl
    project=None,
    entity=None,
    reinit=None,
    tags=None,
    team=None,
    group=None,
    name=None,
    notes=None,
    magic=None,  # TODO(jhr): type is union
    config_exclude_keys=None,
    config_include_keys=None,
    anonymous=None,
    disable=None,
    offline=None,
    allow_val_change=None,
    resume=None,
    force=None,
    tensorboard=None,  # alias for sync_tensorboard
    sync_tensorboard=None,
    monitor_gym=None,
    id=None,
    settings=None,
):
    """Initialize a wandb Run.

    Args:
        entity: alias for team.
        team: personal user or team to use for Run.
        project: project name for the Run.
        settings: extra settings (anything ``dict()`` accepts) applied on top
            of the library settings.

    Raises:
        Exception: if setup fails, or if creating the run fails and
            ``settings.problem`` is "fatal". A KeyboardInterrupt that reaches
            the outer handler is re-raised as ``Exception("interrupted")``.

    Returns:
        wandb Run object. This is a ``RunDummy`` when creating the run failed
        and ``settings.problem`` is not "fatal", and None in "noop" mode.

    """
    assert not wandb._IS_INTERNAL_PROCESS
    # Snapshot the call arguments; must run before any other local is bound.
    kwargs = locals()
    try:
        wi = _WandbInit()
        wi.setup(kwargs)
        try:
            run = wi.init()
        except (KeyboardInterrupt, Exception) as e:
            getcaller()
            assert logger
            logger.exception("we got issues")
            if wi.settings.problem == "fatal":
                # The outer handler reports this to sentry (with delay);
                # reporting here as well would send it twice.
                raise
            if not isinstance(e, KeyboardInterrupt):
                sentry_exc(e)
            # Any other "problem" policy (e.g. "warn") degrades to a dummy run.
            run = RunDummy()
    except KeyboardInterrupt as e:
        # ``logger`` is still None if setup failed before configuring it;
        # asserting here would mask the real error with an AssertionError.
        log = logger if logger is not None else logging.getLogger("wandb")
        log.warning("interupted", exc_info=e)
        raise_from(Exception("interrupted"), e)
    except Exception as e:
        log = logger if logger is not None else logging.getLogger("wandb")
        log.error("error", exc_info=e)
        # Need to build delay into this sentry capture because our exit hooks
        # mess with sentry's ability to send out errors before the program ends.
        sentry_exc(e, delay=True)
        reraise(*sys.exc_info())
    return run
