#!/usr/bin/python

# Copyright (C) 2017  DESY, Notkestr. 85, D-22607 Hamburg
#
# lavue is an image viewing program for photon science imaging detectors.
# Its usual application is as a live viewer using hidra as data source.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation in  version 2
# of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor,
# Boston, MA  02110-1301, USA.
#
# Authors:
#     Jan Kotanski <jan.kotanski@desy.de>
#     Christoph Rosemann <christoph.rosemann@desy.de>
#
# ZMQ JSON test server: publishes random images on a zmq PUB socket

import sys
import zmq
import time
import argparse
import signal
import json
import numpy as np
import random

# --- runtime settings (overwritten from the command line in __main__) ---

# options that are unset by default
port = topicfilter = prefix = None
# boolean switches, off by default
debug = nodict = False
# pause between two published images, in seconds
maxtimegap = 0.1
# not read anywhere in this file
attribute = ""
hostname = "localhost"

# --- process-wide state ---

# zmq context; created in main() and destroyed by the SIGINT handler
context = None

# SIGINT handler that was active at import time; restored on exit
original_sigint = signal.getsignal(signal.SIGINT)


def _onexit(signum, frame):
    """SIGINT handler: destroy the zmq context and exit with status 1.

    The module-level ``context`` is destroyed (if one exists) and reset
    to ``None``.  A failure during the teardown is reported on stderr
    but never prevents the exit.  The SIGINT handler that was active at
    import time is restored before the process exits.

    :param signum: number of the received signal (unused)
    :param frame: stack frame that was interrupted (unused)
    """
    global context
    if context:
        try:
            context.destroy()
            context = None
            print("disconnect")
        except Exception as e:
            # best effort only: report the problem and still exit below
            sys.stderr.write(
                "Error: cannot destroy zmq context: %s\n" % str(e))
            sys.stderr.flush()
    signal.signal(signal.SIGINT, original_sigint)
    sys.exit(1)


def _ascii(text):
    """Return *text* as ASCII bytes usable as a zmq message frame.

    Byte strings are passed through unchanged; text strings are encoded
    as ASCII and characters outside of ASCII are dropped.

    :param text: text or byte string to convert
    :returns: the corresponding byte string
    """
    if isinstance(text, bytes):
        return text
    return text.encode('ascii', 'ignore')


def main():
    """Publish random test images on a zmq PUB socket until interrupted.

    Binds to ``tcp://*:<port>`` and then loops forever: each iteration
    generates a 512 x 256 array of random integers, sends it as a
    multipart message and sleeps for ``maxtimegap`` seconds.  The
    message layout depends on the module-level settings:

    * ``nodict`` set: ``topic, image, shape, dtype[, name], "JSON"``,
      where ``name`` is only present when ``prefix`` is set,
    * otherwise: ``topic, image, metadata, "JSON"``, preceded by a
      separate ``"datasources", metadata, "JSON"`` message.

    Errors raised while building or sending a message are printed and
    the loop continues with the next image.
    """
    global context
    signal.signal(signal.SIGINT, _onexit)
    context = zmq.Context()
    socket = context.socket(zmq.PUB)
    conn = "tcp://*:%s" % (port)
    print("Connecting to: %s" % conn)
    socket.bind(conn)
    counter = 0
    jsontag = _ascii("JSON")

    while True:
        try:
            value = np.transpose(
                [
                    [random.randint(0, 1000) for _ in range(512)]
                    for _ in range(256)
                ])
            shape = value.shape
            dtype = value.dtype.name

            # the first datasource name changes every 50 images
            datasources = ["%s" % (10010 + counter // 50), "10001", "10002"]
            if topicfilter is not None:
                tfilter = topicfilter
            else:
                # no fixed topic: alternate between the first two sources
                tfilter = datasources[counter % 2]
            # every frame must be a byte string, otherwise
            # send_multipart() raises TypeError on Python 3
            topic = _ascii(tfilter)

            if nodict:
                message = [
                    topic,
                    value,
                    _ascii(json.dumps(shape)),
                    _ascii(json.dumps(dtype)),
                ]
                if prefix:
                    message.append(_ascii("%s_%s" % (prefix, counter)))
                message.append(jsontag)
            else:
                metadata = {"shape": shape, "dtype": dtype,
                            "datasources": datasources}
                if prefix:
                    metadata["name"] = "%s_%s" % (prefix, counter)
                if not counter % 20:
                    # axis labels are attached to every 20th image only
                    metadata["axislabels"] = [
                        "x%s" % counter, "y%s" % counter,
                        "m" * ((counter // 20) % 3),
                        "A" * ((counter // 20) % 3)]
                if not counter % 10:
                    # axis scales are attached to every 10th image only
                    metadata["axisscales"] = [
                        float(counter), float(counter),
                        (counter % 5) + 1.0, (counter % 5) + 2.0]

                message = [
                    topic,
                    value,
                    _ascii(json.dumps(metadata)),
                    jsontag,
                ]
                # the "datasources" message carries shape and dtype
                # only for every third image
                metadata2 = dict(metadata)
                if counter % 3:
                    metadata2.pop("shape")
                    metadata2.pop("dtype")
                message2 = [
                    _ascii("datasources"),
                    _ascii(json.dumps(metadata2)),
                    jsontag,
                ]
                print("Send2: tcp://*:%s/%s %s %s"
                      % (port, "datasources", metadata2, "JSON"))
                socket.send_multipart(message2)
            counter += 1
            socket.send_multipart(message)
            print("Send: tcp://*:%s/%s %s %s %s_%s %s"
                  % (port, tfilter, shape, dtype,
                     prefix or "", counter, "JSON"))
        except Exception as e:
            print("Error: %s" % str(e))
        time.sleep(maxtimegap)


if __name__ == "__main__":
    # Command-line entry point: parse the options, store them in the
    # module-level settings read by main() and start the server.
    parser = argparse.ArgumentParser(
        description='ZMQ JSON test server')
    parser.add_argument(
        "-g", "--time-gap",
        help="maximal time gap in seconds (default: 0.1)",
        dest="timegap", default="0.1")
    parser.add_argument(
        "-p", "--port",
        help="zmq port (default: automatic)",
        dest="port", default=None)
    parser.add_argument(
        "-t", "--topic",
        help="zmq topic (default: first one from datasources)",
        dest="topic", default=None)
    parser.add_argument(
        "-n", "--name-prefix",
        help="image name prefix",
        dest="prefix", default=None)
    parser.add_argument(
        "--no-dict", action="store_true",
        default=False, dest="nodict",
        help="create zmq stream without dictionary")
    parser.add_argument(
        "--debug", action="store_true",
        default=False, dest="debug",
        help="debug mode")
    options = parser.parse_args()

    try:
        port = int(options.port)
    except (TypeError, ValueError):
        # TypeError: --port was not given (int(None));
        # ValueError: the given value is not an integer
        sys.stderr.write("lavuemonitor: Invalid --port parameter\n")
        sys.stderr.flush()
        parser.print_help()
        sys.exit(255)

    if options.topic is not None:
        topicfilter = str(options.topic)

    try:
        maxtimegap = float(options.timegap)
    except (TypeError, ValueError):
        sys.stderr.write("lavuemonitor: Invalid --time-gap parameter\n")
        sys.stderr.flush()
        parser.print_help()
        sys.exit(255)

    debug = options.debug
    prefix = options.prefix
    nodict = options.nodict

    main()
