#! /usr/bin/env python
#
# This is open-source software licensed under a BSD license.
# Please see the file LICENSE.txt for details.
#
import pickle

from autobahn.twisted.component import Component, run
from autobahn.twisted.wamp import ApplicationSession
from influxdb import InfluxDBClient
from twisted.internet import threads
from twisted.internet.defer import inlineCallbacks
from twisted.logger import Logger
from twisted.internet.task import LoopingCall

# Interval, in seconds, between "hwlogger is alive" heartbeat log messages
# (used as the LoopingCall interval in TelemetryLogger.onJoin).
HEARTBEAT_DELAY = 15


class TelemetryLogger(ApplicationSession):
    """
    WAMP session that subscribes to HiPERCAM telemetry topics and stores the
    CCD telemetry in the 'hipercam' InfluxDB database.

    Incoming points are queued and written in batches once ``_bulk_limit``
    points are waiting. Each write runs in a worker thread so the blocking
    InfluxDB client does not stall the Twisted reactor. A batch that fails to
    write is put back on the queue and retried on later telemetry messages.
    After ``_max_retries`` consecutive failures the batch is discarded.
    COMPO and slide telemetry topics are subscribed to but not logged.
    """

    log = Logger()

    # consecutive failed write attempts tolerated before a batch is dropped
    _max_retries = 10

    def onJoin(self, details):
        """
        Set up the session once it has joined the realm.

        Connects the InfluxDB client, subscribes to the telemetry topics and
        starts the heartbeat loop.
        """
        print("session ready")
        self.topic_callbacks = {
            'hipercam.ccd1.telemetry': self.log_ccd1_telemetry,
            'hipercam.ccd2.telemetry': self.log_ccd2_telemetry,
            'hipercam.ccd3.telemetry': self.log_ccd3_telemetry,
            'hipercam.ccd4.telemetry': self.log_ccd4_telemetry,
            'hipercam.ccd5.telemetry': self.log_ccd5_telemetry,
            'hipercam.slide.telemetry': self.log_slide_telemetry,
            'hipercam.compo.telemetry': self.log_compo_telemetry,
        }
        self._client = InfluxDBClient(host='pulsar.shef.ac.uk', port=8086)
        self._client.switch_database('hipercam')
        self._points = []  # queued point dicts not yet written to the DB
        self._bulk_limit = 10  # a write is triggered once this many points are queued
        self._retries_remaining = self._max_retries  # failed writes left before giving up
        self._write_in_progress = False  # True while a batch is being written in a thread
        # subscribe to telemetry topics
        for topic, callback in self.topic_callbacks.items():
            self.subscribe(callback, topic)

        self._tick_no = 0
        self._tick_loop = LoopingCall(self._tick)
        self._tick_loop.start(HEARTBEAT_DELAY)

    def _tick(self):
        """Log a periodic heartbeat so it is visible that the logger is alive."""
        self._tick_no += 1
        self.log.info('hwlogger is alive [tick {tick}]', tick=self._tick_no)

    def _write_points(self, points):
        """
        Write ``points`` to InfluxDB (blocking).

        Does not read or modify the point queue or the retry counter, so the
        caller can run it in a worker thread.

        Returns
        -------
        bool
            True if the client reported success. False if it reported
            failure or raised.
        """
        try:
            result = self._client.write_points(points, time_precision='ms')
        except Exception as err:
            # pass err as a kwarg: twisted.logger treats the first argument as
            # a format string, so braces inside the message would break it
            self.log.warn('failed writing batch of data (will retry): {err}', err=err)
            return False
        return bool(result)

    def _record_write_result(self, batch, ok):
        """
        Update the queue and retry counter after an attempt to write ``batch``.

        Mutates ``self._points``, so it must run in the reactor thread (the
        same thread that appends new points).
        """
        if ok:
            self.log.debug('updated DB')
            self._retries_remaining = self._max_retries
        elif self._retries_remaining > 0:
            self._retries_remaining -= 1
            # re-queue the failed batch ahead of any points that arrived while
            # it was being written; it is retried on the next trigger
            self._points = batch + self._points
        else:
            self._retries_remaining = self._max_retries
            self.log.error('could not write data to DB, giving up. this batch of data is lost!')

    def update_db(self):
        """
        Synchronously write all queued points to InfluxDB.

        This blocks the calling thread for the duration of the write. It also
        mutates the point queue, so call it only from the reactor thread.
        Normal batching goes through ``log_telemetry``, which performs the
        write in a worker thread instead.
        """
        if not self._points:
            return
        batch, self._points = self._points, []
        self._record_write_result(batch, self._write_points(batch))

    @inlineCallbacks
    def log_telemetry(self, point):
        """
        Queue ``point`` for writing.

        Once ``_bulk_limit`` points are queued, the queue is written to
        InfluxDB in a worker thread. Only one write is in flight at a time.
        Points that arrive during a write stay queued for the next one.
        """
        self._points.append(point)
        if len(self._points) < self._bulk_limit or self._write_in_progress:
            return

        # hand the current queue to the worker thread and start a fresh one,
        # so the thread never touches a list the reactor is appending to
        batch, self._points = self._points, []
        self._write_in_progress = True
        try:
            ok = yield threads.deferToThread(self._write_points, batch)
        except Exception as err:
            self.log.error('{err}', err=err)
            ok = False
        finally:
            self._write_in_progress = False
        # back in the reactor thread here, so queue bookkeeping is safe
        self._record_write_result(batch, ok)

    def preprocess_telemetry(self, data):
        """
        Unpickle a telemetry message and split off its timestamp.

        Returns ``(ts, telemetry)``. ``ts`` is the ``.iso`` attribute of the
        message's 'timestamp' entry (presumably an astropy ``Time`` — TODO
        confirm). ``telemetry`` is the remaining mapping.
        """
        # SECURITY(review): pickle.loads can execute arbitrary code. This is
        # only acceptable if every publisher on the WAMP realm is trusted.
        telemetry = pickle.loads(data)
        ts = telemetry.pop('timestamp').iso
        return ts, telemetry

    def process_field(self, field):
        """
        Return ``field.value`` if ``field`` has a ``value`` attribute, else
        return ``field`` unchanged.

        Presumably this strips units from quantity-like objects — TODO confirm.
        """
        if hasattr(field, 'value'):
            return field.value
        return field

    def make_point(self, data, measurement):
        """
        Build an InfluxDB point dict from a pickled telemetry message.

        The dict has 'measurement', 'time' and 'fields' keys. Every
        non-timestamp entry of the message becomes a field.
        """
        point = {}
        point['measurement'] = measurement
        ts, telemetry = self.preprocess_telemetry(data)
        point['time'] = ts
        # set fields to telemetry
        point['fields'] = {k: self.process_field(v) for (k, v) in telemetry.items()}
        return point

    def log_ccd_telemetry(self, ccd, data):
        """
        Convert a telemetry message for ``ccd`` into a point and queue it.

        Messages that cannot be converted are logged and skipped.
        """
        try:
            point = self.make_point(data, ccd)
        except Exception as err:
            self.log.error('cannot process telemetry for {ccd}: {err}', ccd=ccd, err=err)
        else:
            self.log_telemetry(point)

    def log_ccd1_telemetry(self, data):
        """Subscription handler for 'hipercam.ccd1.telemetry'."""
        self.log_ccd_telemetry('ccd1', data)

    def log_ccd2_telemetry(self, data):
        """Subscription handler for 'hipercam.ccd2.telemetry'."""
        self.log_ccd_telemetry('ccd2', data)

    def log_ccd3_telemetry(self, data):
        """Subscription handler for 'hipercam.ccd3.telemetry'."""
        self.log_ccd_telemetry('ccd3', data)

    def log_ccd4_telemetry(self, data):
        """Subscription handler for 'hipercam.ccd4.telemetry'."""
        self.log_ccd_telemetry('ccd4', data)

    def log_ccd5_telemetry(self, data):
        """Subscription handler for 'hipercam.ccd5.telemetry'."""
        self.log_ccd_telemetry('ccd5', data)

    def log_compo_telemetry(self, data):
        """
        We don't currently log COMPO telemetry
        """
        pass

    def log_slide_telemetry(self, data):
        """
        We don't currently log slide telemetry
        """
        pass


if __name__ == "__main__":
    # Connect to the WAMP router and run a TelemetryLogger session on it;
    # run() blocks until the component finishes.
    telemetry_component = Component(
        transports='ws://192.168.1.2:8080/ws',
        realm='realm1',
        session_factory=TelemetryLogger,
    )
    run([telemetry_component])
