##############################################################################
# Institute for the Design of Advanced Energy Systems Process Systems
# Engineering Framework (IDAES PSE Framework) Copyright (c) 2018-2020, by the
# software owners: The Regents of the University of California, through
# Lawrence Berkeley National Laboratory,  National Technology & Engineering
# Solutions of Sandia, LLC, Carnegie Mellon University, West Virginia
# University Research Corporation, et al. All rights reserved.
#
# Please see the files COPYRIGHT.txt and LICENSE.txt for full copyright and
# license information, respectively. Both files are also available online
# at the URL "https://github.com/IDAES/idaes-pse".
##############################################################################
"""
Visualization server back-end.

The main class is `FlowsheetServer`, which is instantiated from the `visualize()` function.
"""

# stdlib
import http.server
import json
from pathlib import Path
import re
import socket
import threading
from typing import Dict, Union
from urllib.parse import urlparse

# package
from idaes import logger
from ..flowsheet import FlowsheetDiff, FlowsheetSerializer
from . import persist

# Module logger, obtained through the idaes logging helpers
_log = logger.getLogger(__name__)

# Directories
# Absolute path of the directory containing this module; the asset directories are resolved from it
_this_dir = Path(__file__).parent.absolute()
# Files served by FlowsheetServerHandler.do_GET for any route other than /app and /fs
_static_dir = _this_dir / "static"
# HTML templates; "index.html" here is rendered for the /app route
_template_dir = _this_dir / "templates"


class FlowsheetNotFound(Exception):
    """A flowsheet with a known ID could not be retrieved from some location.

    Attributes:
        location: Human-readable description of where the lookup failed.
    """

    def __init__(self, id_, location):
        # Keep the location so handlers can distinguish which source was missing it
        self.location = location
        message = "Flowsheet {} not found in {}".format(id_, location)
        super().__init__(message)


class FlowsheetNotFoundInDatastore(FlowsheetNotFound):
    """Raised when a flowsheet cannot be loaded from its datastore."""

    def __init__(self, id_):
        super().__init__(id_, location="datastore")


class FlowsheetNotFoundInMemory(FlowsheetNotFound):
    """Raised when no flowsheet object is registered in this Python process."""

    def __init__(self, id_):
        super().__init__(id_, location="Python process memory")


class FlowsheetUnknown(Exception):
    """Raised when a request names a flowsheet ID that was never registered."""

    def __init__(self, id_):
        text = "Unrecognized flowsheet '{}'".format(id_)
        super().__init__(text)


class ProcessingError(Exception):
    """Raised when input (e.g. a flowsheet received from the client) cannot be processed."""


class FlowsheetServer(http.server.HTTPServer):
    """A simple HTTP server that runs in its own thread.

    This server is used for *all* models for a given process, so every request needs to contain
    the ID of the model that should be used in that transaction.

    The only methods that the visualization function needs to call are the constructor, `start()` to
    start running the server, and `add_flowsheet()`, to add a new flowsheet.
    """

    def __init__(self, port=None):
        """Create HTTP server listening on 127.0.0.1.

        Args:
            port: TCP port to listen on. If not given (or 0), the operating system assigns a
                  free port when the socket is bound; the chosen value is available as
                  :attr:`port`.
        """
        # Binding to port 0 lets the OS pick a free port atomically. Probing for a free port
        # first and binding afterwards is racy (another process can grab it in between).
        super().__init__(("127.0.0.1", port or 0), FlowsheetServerHandler)
        self._port = self.server_address[1]
        _log.info(f"Starting HTTP server on localhost, port {self._port}")
        self._dsm = persist.DataStoreManager()
        self._flowsheets = {}
        self._thr = None

    @property
    def port(self):
        """Port number the server is bound to."""
        return self._port

    def start(self):
        """Start the server, which will spawn a thread.

        The thread is a daemon thread, so it does not keep the process alive on exit.
        """
        # `daemon=` keyword replaces Thread.setDaemon(), deprecated since Python 3.10
        self._thr = threading.Thread(target=self._run, daemon=True)
        self._thr.start()

    def add_flowsheet(self, id_, flowsheet, save_as) -> str:
        """Add a flowsheet, and also the method of saving it.

        Args:
            id_: Name of flowsheet
            flowsheet: Flowsheet object
            save_as: File, path, etc. passed to :meth:`persist.DataStore.create()` to save the
                     changes made from the UI.

        Returns:
            Name of flowsheet, modified as necessary to be URL friendly
        """
        # replace all but 'unreserved' (RFC 3986) chars with a dash; remove duplicate dashes
        id_ = re.sub(r"-+", "-", re.sub(r"[^a-zA-Z0-9-._~]", "-", id_))
        self._flowsheets[id_] = flowsheet
        store = persist.DataStore.create(save_as)
        _log.debug(f"Flowsheet '{id_}' storage is {store}")
        self._dsm.add(id_, store)
        # First try to update, so as not to overwrite saved value
        try:
            self.update_flowsheet(id_)
        except FlowsheetNotFoundInDatastore:
            _log.debug(f"No existing flowsheet found in {store}: saving new value")
            # If not found in datastore, save new value
            fs_dict = FlowsheetSerializer(flowsheet, id_).as_dict()
            store.save(fs_dict)
        else:
            _log.debug(f"Existing flowsheet found in {store}: saving merged value")
        return id_

    # === Public methods called only by HTTP handler ===

    def save_flowsheet(self, id_, flowsheet: Union[Dict, str]):
        """Save the flowsheet to the appropriate store.

        Raises:
            ProcessingError, if parsing of JSON failed (see :meth:`DataStoreManager.save()`)
        """
        try:
            self._dsm.save(id_, flowsheet)
        except ValueError as err:
            raise ProcessingError(str(err)) from err

    def update_flowsheet(self, id_: str) -> Dict:
        """Update flowsheet.

        The returned flowsheet is also saved to the datastore.

        Args:
            id_: Identifier of flowsheet to update.

        Returns:
            Merged value of flowsheets in datastore and current value in memory

        Raises:
            FlowsheetUnknown if the flowsheet id is not known
            FlowsheetNotFound (subclass) if the flowsheet id is known, but it can't be retrieved
            ProcessingError for internal errors
        """
        # Get saved flowsheet from datastore
        try:
            saved = self._load_flowsheet(id_)
        except KeyError as err:
            raise FlowsheetUnknown(id_) from err
        except ValueError as err:
            raise FlowsheetNotFoundInDatastore(id_) from err
        # Get current value from memory
        try:
            obj = self._get_flowsheet_obj(id_)
        except KeyError as err:
            raise FlowsheetNotFoundInMemory(id_) from err
        try:
            obj_dict = self._serialize_flowsheet(id_, obj)
        except ValueError as err:
            raise ProcessingError(f"Cannot serialize flowsheet: {err}") from err
        # Compare saved and current value
        diff = FlowsheetDiff(saved, obj_dict)
        _log.debug(f"diff: {diff}")
        if not diff:
            # If no difference, there is nothing to save
            _log.debug("Stored flowsheet is the same as the flowsheet in memory")
        else:
            # Otherwise, save the merged value before returning it
            num = len(diff)
            pl = "s" if num > 1 else ""
            _log.debug(
                f"Stored flowsheet and model in memory differ by {num} item{pl}"
            )
            self.save_flowsheet(id_, diff.merged())
        # Return [a copy of the] merged value
        return diff.merged(do_copy=True)

    # === Internal methods ===

    def _load_flowsheet(self, id_) -> Union[Dict, str]:
        """Load the stored flowsheet for `id_` from the datastore manager."""
        return self._dsm.load(id_)

    def _get_flowsheet_obj(self, id_):
        """Get a flowsheet with the given ID.

        Raises:
            KeyError: if no flowsheet was registered under `id_`.
        """
        return self._flowsheets[id_]

    @staticmethod
    def _serialize_flowsheet(id_, flowsheet):
        """Serialize `flowsheet` to a dict.

        Raises:
            ValueError: if the serializer fails with AttributeError or KeyError.
        """
        try:
            result = FlowsheetSerializer(flowsheet, id_).as_dict()
        except (AttributeError, KeyError) as err:
            raise ValueError(f"Error serializing flowsheet: {err}") from err
        return result

    def _run(self):
        """Run in a separate thread: serve requests until an error escapes the loop."""
        _log.debug(f"Serve forever on localhost:{self._port}")
        try:
            self.serve_forever()
        except Exception as err:
            _log.info(f"Shutting down server due to error: {err}")
            # serve_forever() has already returned here, so there is no running loop for
            # shutdown() to stop; close the listening socket to release the port instead.
            self.server_close()


class FlowsheetServerHandler(http.server.SimpleHTTPRequestHandler):
    """Handle requests from the IDAES flowsheet visualization (IFV) web page.
    """

    def __init__(self, *args, **kwargs):
        self.directory = None  # silence warning about initialization outside constructor
        super().__init__(*args, **kwargs)

    # === GET ===

    def do_GET(self):
        """Process a request to receive data.

        Routes:
          * `/app`: Return the web page
          * `/fs`: Retrieve an updated flowsheet.
          * `/path/to/file`: Retrieve file stored static directory
        """
        u, id_ = self._parse_flowsheet_url(self.path)
        _log.debug(f"do_GET: path={self.path} id={id_}")
        if u.path in ("/app", "/fs") and id_ is None:
            self.send_error(
                400, message=f"Query parameter 'id' is required for '{u.path}'"
            )
            return
        if u.path == "/app":
            self._get_app(id_)
        elif u.path == "/fs":
            self._get_fs(id_)
        else:
            # Try to serve a file
            self.directory = _static_dir  # keep here: overwritten if set earlier
            super().do_GET()

    def _get_app(self, id_):
        """Read index file, process to insert flowsheet identifier, and return it.

        The identifier comes straight from the request URL, so it is HTML-escaped before
        being inserted into the page (a no-op for IDs produced by `add_flowsheet`).
        """
        import html

        p = _template_dir / "index.html"
        # Explicit encoding: the page is sent back as UTF-8, so read it as UTF-8 too
        # (assumes the template file is stored as UTF-8 -- TODO confirm)
        with open(p, "r", encoding="utf-8") as fp:
            s = fp.read()
        page = s.format(flowsheet_id=html.escape(id_))
        self._write_html(200, page)

    def _get_fs(self, id_: str):
        """Get updated flowsheet.

        Args:
            id_: Flowsheet identifier

        Returns:
            None
        """
        try:
            merged = self.server.update_flowsheet(id_)
        except FlowsheetUnknown as err:
            # User error: user asked for a flowsheet by an unknown ID
            self.send_error(404, message=str(err))
            return
        except (FlowsheetNotFound, ProcessingError) as err:
            # Internal error: flowsheet ID is found, but other things are missing
            self.send_error(500, message=str(err))
            return
        # Return merged flowsheet
        self._write_json(200, merged)

    # === PUT ===

    def do_PUT(self):
        """Process a request to store data.

        Routes:
          * `/fs`: Save the flowsheet in the request body (requires query parameter `id`).
        """
        u, id_ = self._parse_flowsheet_url(self.path)
        _log.debug(f"do_PUT: route={u} id={id_}")
        if u.path != "/fs":
            # Always answer; otherwise the client sees the connection closed with no reply
            self.send_error(404, message=f"Unknown route '{u.path}'")
            return
        if id_ is None:
            self.send_error(
                400, message=f"Query parameter 'id' is required for '{u.path}'"
            )
            return
        self._put_fs(id_)

    def _put_fs(self, id_):
        """Save the flowsheet sent in the request body and reply with an empty 200."""
        # Read exactly Content-Length bytes: reading to EOF would block forever, since the
        # client keeps the connection open while waiting for the response.
        length_header = self.headers.get("Content-Length")
        if length_header is None:
            self.send_error(411, message="Content-Length header is required")
            return
        try:
            read_len = int(length_header)
        except ValueError:
            read_len = -1
        if read_len < 0:
            self.send_error(400, message="Invalid Content-Length header")
            return
        try:
            data = utf8_decode(self.rfile.read(read_len))
        except UnicodeDecodeError as err:
            self.send_error(400, message="Invalid flowsheet", explain=str(err))
            return
        # save flowsheet
        try:
            self.server.save_flowsheet(id_, data)
        except ProcessingError as err:
            self.send_error(400, message="Invalid flowsheet", explain=str(err))
            return
        self.send_response(200, message="success")
        self.send_header("Content-length", "0")
        # end_headers() flushes the buffered status line and headers; without it the
        # response is never written to the client.
        self.end_headers()

    # === Internal methods ===

    def _write_json(self, code, data):
        """Send `data` serialized as JSON with the given status code."""
        str_json = json.dumps(data)
        value = utf8_encode(str_json)
        self.send_response(code)
        self.send_header("Content-type", "application/json")
        self.send_header("Content-length", str(len(value)))
        self.end_headers()
        self.wfile.write(value)

    def _write_html(self, code, page):
        """Send the HTML string `page` (UTF-8 encoded) with the given status code."""
        value = utf8_encode(page)
        self.send_response(code)
        self.send_header("Content-type", "text/html")
        self.send_header("Content-length", str(len(value)))
        self.end_headers()
        self.wfile.write(value)

    def _parse_flowsheet_url(self, path):
        """Split a request path into its parsed URL and the `id` query value.

        Args:
            path: Request path, possibly including a query string.

        Returns:
            Tuple of (parsed URL, value of the `id` query parameter or None if absent).
            Query values are percent-decoded; items without `=` are treated as blank
            values rather than causing an error.
        """
        from urllib.parse import parse_qsl

        u = urlparse(path)
        queries = dict(parse_qsl(u.query, keep_blank_values=True))
        return u, queries.get("id", None)

    # === Logging ===

    def log_message(self, fmt, *args):
        """Override to send messages to our module logger instead of stderr
        """
        msg = "%s - - [%s] %s" % (
            self.address_string(),
            self.log_date_time_string(),
            fmt % args,
        )
        _log.debug(msg)


def utf8_encode(s: str) -> bytes:
    """Return the UTF-8 byte encoding of text `s` (strict error handling)."""
    encoded = s.encode("utf-8")
    return encoded


def utf8_decode(b: bytes) -> str:
    """Return the text decoded from UTF-8 bytes `b` (strict error handling)."""
    text = b.decode("utf-8")
    return text


def find_free_port(wait: float = 1.0):
    """Return a TCP port on 127.0.0.1 that was free at the time of the call.

    A temporary socket is bound to port 0 so the operating system assigns a port, the
    assigned number is read back, and the socket is closed again. Another process may
    still take the port before the caller binds to it.

    Args:
        wait: Seconds to sleep after closing the probe socket. Defaults to 1 second, the
              original behavior; pass 0 to skip the delay.

    Returns:
        The port number (int).
    """
    import time

    # The context manager closes the probe socket even if bind() raises
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    time.sleep(wait)  # wait for socket cleanup
    return port
