#!/usr/bin/python

# Copyright (C) 2017  DESY, Notkestr. 85, D-22607 Hamburg
#
# lavue is an image viewing program for photon science imaging detectors.
# Its usual application is as a live viewer using hidra as data source.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation in  version 2
# of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor,
# Boston, MA  02110-1301, USA.
#
# Authors:
#     Jan Kotanski <jan.kotanski@desy.de>
#     Christoph Rosemann <christoph.rosemann@desy.de>
#
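# ZMQ client which monitors laVue status messages and runs a stop command
# when the image intensity or the time gap between messages is too large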

import sys
import zmq
import time
import json
import math
import argparse
import os
import signal

# --- limits checked by the monitor ------------------------------------
# largest tolerated pixel value; alive() triggers a stop above it
maximumval = 1000.0
# largest tolerated age of a status message in seconds, see checktime()
maxtimegap = 1.0

# --- zmq connection ----------------------------------------------------
# host and port of the status publisher; the port is set at start-up
hostname = "localhost"
port = None
# subscription topic passed to zmq.SUBSCRIBE
topicfilter = "10001"
# zmq context created by main() and destroyed by the SIGINT handler
context = None

# --- behaviour switches ------------------------------------------------
# shell command executed by stophardware(); empty means no command
command = ""
# when True alive() checks the raw image maximum instead of the scaled one
rawimage = False
# when True additional diagnostic lines are printed
debug = False

# time stamp of the most recent 'alive' message, initially the load time
lasttime = time.time()

# SIGINT handler which was active before this script installs its own
original_sigint = signal.getsignal(signal.SIGINT)


def _onexit(signum, frame):
    """ SIGINT handler: destroys the zmq context and terminates the script

    The previously installed SIGINT handler is restored and the process
    exits with status 1.

    :param signum: number of the received signal (not used)
    :param frame: stack frame interrupted by the signal (not used)
    """
    global context
    if context:
        try:
            context.destroy()
            context = None
            print("disconnect")
        except Exception as e:
            # best-effort cleanup: report the problem but still exit below;
            # SystemExit/KeyboardInterrupt are deliberately not swallowed
            sys.stderr.write("lavuemonitor: disconnect failed: %s\n" % e)
    signal.signal(signal.SIGINT, original_sigint)
    sys.exit(1)


def stophardware():
    """ Executes the configured stop command, if there is one

    The module-level ``command`` string is passed to the system shell.
    A closing separator line is printed in every case.
    """
    stopcmd = command
    if stopcmd:
        print("-----------------------------------------------------------")
        announcement = "execute: '%s'" % str(stopcmd)
        print(announcement)
        os.system(stopcmd)
    print("===========================================================")


def stop(pid, calctime, **args):
    """ Handles a 'stop' status message by stopping the hardware

    :param pid: taken from the status message (not used)
    :param calctime: time stamp of the status message, printed in the notice
    :param args: remaining fields of the status message (not used)
    """
    stophardware()
    notice = "STOP !!! %s LiveView stopped" % calctime
    print(notice)


def start(pid, calctime, **args):
    """ Handles a 'start' status message; it is only logged in debug mode

    :param pid: taken from the status message (not used)
    :param calctime: time stamp of the status message
    :param args: remaining fields of the status message (not used)
    """
    if not debug:
        return
    print("  START: %s" % calctime)


def alive(pid, calctime, maxrawval, maxval, minval, scaling, meanval, **args):
    """ Handles an 'alive' status message and checks the image maximum

    The hardware is stopped when the checked maximum exceeds the
    module-level ``maximumval``.  Afterwards ``lasttime`` is set to
    the time stamp of the message.

    :param pid: taken from the status message (not used)
    :param calctime: time stamp of the status message
    :param maxrawval: maximum of the raw image, checked in raw mode
    :param maxval: maximum of the displayed image, checked otherwise
    :param minval: taken from the status message (not used)
    :param scaling: 'linear', 'sqrt' or 'log'; only used outside raw mode
    :param meanval: taken from the status message (not used)
    :param args: remaining fields of the status message (not used)
    :raises ValueError: if scaling is not one of the known names
    """
    global lasttime
    if rawimage:
        val = float(maxrawval)
    else:
        val = float(maxval)
        # invert the scaling so that val is comparable with maximumval
        # (presumably maxval was computed on the scaled image -- confirm)
        if scaling == "sqrt":
            val = val * val
        elif scaling == "log":
            val = math.pow(10, val)
        elif scaling != "linear":
            raise ValueError("Unknown scaling: %s" % scaling)
    if debug:
        print("  ALIVE !!! %s %s %s %s %s" % (
            calctime, scaling, maxval, maxrawval, val))
    if val > maximumval:
        print("STOP !!! %s maxval = %s > %s" % (calctime, val, maximumval))
        stophardware()
    lasttime = float(calctime)


def checktime(calctime):
    """ Stops the hardware when a time stamp is older than allowed

    :param calctime: time stamp in seconds, compared with time.time()
    """
    now = time.time()
    if not (calctime + maxtimegap < now):
        # the time stamp is still within the tolerated gap
        return
    gap = now - calctime
    print("STOP !!! %s Time GAP = %s > %s" % (calctime, gap, maxtimegap))
    stophardware()


def main():
    """ Subscribes to the status publisher and supervises its messages

    A zmq SUB socket is connected to ``tcp://<hostname>:<port>`` and
    subscribed to ``topicfilter``.  Each message is expected to have
    the form ``<topic> <json>`` where the JSON object has at least
    the ``command`` and ``calctime`` keys.  The loop never returns;
    the script is terminated by the SIGINT handler.
    """
    global context
    signal.signal(signal.SIGINT, _onexit)
    context = zmq.Context()
    socket = context.socket(zmq.SUB)
    conn = "tcp://%s:%s" % (hostname, port)
    print("Connecting to: %s" % conn)
    socket.connect(conn)

    # pyzmq accepts only bytes for SUBSCRIBE; a python3 str must be encoded
    topic = topicfilter
    if not isinstance(topic, bytes):
        topic = topic.encode("utf-8")
    socket.setsockopt(zmq.SUBSCRIBE, topic)

    while True:
        try:
            # drain all pending messages; zmq.Again ends the inner loop
            while True:
                # check for a message, this will not block
                raw = socket.recv(flags=zmq.NOBLOCK)
                # recv() returns bytes, so split on a bytes separator and
                # decode the payload before it is parsed
                _, message = raw.split(b" ", 1)
                md = json.loads(message.decode("utf-8"))
                cmd = md["command"].lower()
                if cmd == 'stop':
                    stop(**md)
                elif cmd == 'alive':
                    alive(**md)
                elif cmd == 'start':
                    start(**md)
                else:
                    raise Exception("Wrong command")

                checktime(float(md["calctime"]))
                time.sleep(maxtimegap * 0.001)

        except zmq.Again:
            # nothing to read at the moment
            pass
        except Exception as e:
            # any malformed or unexpected message is treated as a failure
            print("STOP !!! Error: %s" % str(e))
            stophardware()

        # also stop when no 'alive' message arrived for too long
        checktime(float(lasttime))
        time.sleep(maxtimegap * 0.1)


if __name__ == "__main__":
    # Command-line entry point: parses the options, stores them in the
    # module-level settings read by main() and starts the monitor loop.
    parser = argparse.ArgumentParser(
        description='ZMQ Client for laVue status')
    parser.add_argument(
        "-i", "--max-intensity",
        help="maximal pixel value (default: 1000.)",
        dest="maxval", default="1000.")
    parser.add_argument(
        "-g", "--time-gap",
        help="maximal time gap in seconds (default: 1.)",
        dest="timegap", default="1.")
    parser.add_argument(
        "-c", "--stop-command",
        help="stop command",
        dest="command", default="")
    parser.add_argument(
        "-r", "--raw", action="store_true",
        default=False, dest="raw",
        help="check raw image")
    parser.add_argument(
        "-p", "--port",
        help="zmq port (default: automatic)",
        dest="port", default=None)
    parser.add_argument(
        "-z", "--host",
        help="zmq host (default: localhost)",
        dest="host", default="localhost")
    parser.add_argument(
        "-t", "--topic",
        help="zmq topic (default: 10001)",
        dest="topic", default="10001")
    parser.add_argument(
        "--debug", action="store_true",
        default=False, dest="debug",
        help="debug mode")
    options = parser.parse_args()

    def _usage_error(message):
        """ Reports a command-line problem and exits with status 255

        :param message: text written to stderr after the program name
        """
        sys.stderr.write("lavuemonitor: %s\n" % message)
        sys.stderr.flush()
        parser.print_help()
        sys.exit(255)

    try:
        port = int(options.port)
    except (TypeError, ValueError):
        # no usable --port value: look for the listening port of a
        # running 'lavue' process instead
        try:
            import psutil
            ports = []
            for process in psutil.process_iter():
                if process.name() != 'lavue':
                    continue
                for cn in process.connections():
                    if cn.status == 'LISTEN':
                        # laddr is an (address, port) pair
                        ports.append(str(cn.laddr[1]))
            if len(ports) > 1:
                _usage_error(
                    "Select one of the ports: %s" % ", ".join(ports))
            if not ports:
                raise Exception("Cannot find the lavue port")
            port = ports[0]
        except Exception as e:
            # SystemExit raised by _usage_error is not caught here
            _usage_error("Invalid --port parameter (%s)" % e)
    try:
        maximumval = float(options.maxval)
    except (TypeError, ValueError):
        _usage_error("Invalid --max-intensity parameter")
    try:
        topicfilter = str(options.topic)
    except Exception:
        _usage_error("Invalid --topic parameter")
    try:
        command = str(options.command)
    except Exception:
        _usage_error("Invalid --stop-command parameter")
    try:
        maxtimegap = float(options.timegap)
    except (TypeError, ValueError):
        _usage_error("Invalid --time-gap parameter")
    debug = options.debug
    rawimage = options.raw
    hostname = options.host

    main()
