#!/usr/bin/env python

# releasemon - Supervisor plugin to watch a version file and restart
#              program on update
#
#              Based on an idea and script by Tim Schumacher (superfsmon)
#
# Copyright (C) 2019 James Cundle
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function

import os
import sys
import time
import argparse
import re
import signal
import threading
import datetime
import random
import re

try:
    import xmlrpc.client as xmlrpclib
except ImportError:
    import xmlrpclib

from supervisor import childutils
from watchdog.observers import Observer
from watchdog.events import (PatternMatchingEventHandler)


# Command-line interface. The parser is module-level because usage_error()
# prints its usage line.
parser = argparse.ArgumentParser(
        description='Supervisor plugin to watch a directory and restart '
                    'programs on changes')

parser.add_argument('-e', '--enable', metavar='FLAG', type=int,
        help='disable functionality if flag is not set')
parser.add_argument('--disable', metavar='FLAG', nargs='?', type=int, const=1,
        help='disable functionality if flag is set ')
parser.add_argument('-rd', '--random-delay', type=int, default=0,
        dest='random_delay',
        help='random delay time in second before running the restart')
parser.add_argument('-fd', '--fixed-delay', type=int, default=0,
        dest='fixed_delay',
        help='fixed delay time in second before running the restart')
parser.add_argument('--running-only',
        dest='running_only', action='store_true',
        help="Restart RUNNING processes only")

# Optional shell commands run around the restart. Both options take the
# command as a string value; an empty string means "no script".
task_group = parser.add_argument_group('additional tasks')
task_group.add_argument('-pre', '--pre-restart',
        dest='pre_restart', type=str, default='',
        help='script to run pre-restarting')
# This used to be declared with action='store_false', which consumes no
# value, so a post-restart script could never be supplied.
task_group.add_argument('-post', '--post-restart',
        dest='post_restart', type=str, default='',
        help='script to run post-restarting')

monitor_group = parser.add_argument_group('directory monitoring')
monitor_group.add_argument('version_file',
        help='version file to watch for changes')

# Restart targets: explicit program names, whole groups, or everything.
program_group = parser.add_argument_group('programs')
program_group.add_argument('program', metavar='PROG', nargs='*',
        help="supervisor program name to restart")
program_group.add_argument('-g', '--group', action='append', default=[],
        help='supervisor group name to restart')
program_group.add_argument('-a', '--any', action='store_true',
        help='restart any child of this supervisor')

# pre_restarting_lock is taken non-blocking by commence_restart() so that at
# most one worker queues up behind a restart already in progress;
# restarting_lock is held for the duration of the restart itself.
pre_restarting_lock = threading.Lock()
restarting_lock = threading.Lock()

# Contents of the version file as last recorded by this plugin.
current_version = ''

def info(msg, file=sys.stdout):
    """Write a timestamped ``releasemon:`` log line to *file* and flush it.

    msg:  text to log.
    file: writable stream; defaults to standard output.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    line = '%s - releasemon: %s' % (stamp, msg)
    print(line, file=file)
    # Flush right away so the line is not left sitting in the stream buffer.
    file.flush()


def error(msg, status=1):
    """Log an error line to standard error and terminate the process.

    msg:    description of the failure.
    status: process exit status passed to sys.exit() (default 1).

    Raises SystemExit; this function does not return.
    """
    # info() already prefixes every line with a timestamp, so none is added
    # here (previously the timestamp appeared twice on error lines).
    info('error: %s' % (msg,), file=sys.stderr)
    sys.exit(status)


def usage_error(msg):
    """Print the parser's usage line to stderr, then abort with exit status 2.

    msg: explanation of what was wrong with the command line.
    """
    parser.print_usage(sys.stderr)
    error(msg, 2)

def validate_args(args):
    """Reject command lines that supply both --enable and --disable.

    args: parsed argument namespace; only ``enable`` and ``disable`` are read.
    """
    # At most one of the two flags was given: nothing to complain about.
    if args.enable is None or args.disable is None:
        return
    usage_error('argument --enable not allowed with --disable')

def handle_term(signum, frame):
    """Signal handler: log the shutdown and stop the file observer.

    signum, frame: standard signal-handler arguments; neither is used.
    """
    info('terminating')
    try:
        observer.stop()
    except NameError:
        # No global ``observer`` is bound (the script installs this handler
        # in disabled mode before creating one), so just exit the process.
        raise SystemExit()

def get_version():
    """Return the stripped contents of the version file.

    The path is taken from the global ``args.version_file``. If the file
    does not exist, a message is logged and an empty string is returned.
    Any other I/O failure on an existing path is re-raised.
    """
    filename = args.version_file
    try:
        # The context manager closes the handle even if read() fails.
        with open(filename, "r") as fp:
            return fp.read().strip()
    except IOError:
        # Opening first and checking existence only on failure avoids the
        # window in which the file could vanish between an exists() test
        # and the open() call.
        if os.path.exists(filename):
            raise
        info('Build file not found at: ' + filename)
        return ''

def requires_restart(proc):
    """Decide whether a supervisor process entry should be restarted.

    proc: mapping with at least the keys 'name', 'group', 'statename' and
          'pid' (the entries returned by getAllProcessInfo are passed in).

    Returns True when the process is in an acceptable state and is selected
    by --any, by program name, or by group; False otherwise. The process
    whose pid equals this plugin's own pid is never selected.
    """
    name, group = proc['name'], proc['group']
    statename, pid = proc['statename'], proc['pid']

    # Never restart ourselves.
    if pid == os.getpid():
        return False

    if args.running_only:
        state_ok = statename == 'RUNNING'
        if not state_ok:
            info('Process state is not RUNNING for ' + name + '. Skipping.')
    else:
        state_ok = statename in ('STARTING', 'RUNNING')

    # Strip a trailing "_<digits>" job identifier off the name so it can be
    # compared against the program names given on the command line.
    subname = re.sub('_[0-9]+$', '', name)

    if not state_ok:
        return False
    return bool(args.any or subname in args.program or group in args.group)

def has_version_changed():
    """Re-read the version file and compare it with ``current_version``.

    Both the recorded and the freshly read version are logged. Returns True
    when they differ, False when they are equal.
    """
    latest = get_version()

    info('Current version: ' + current_version)
    info('New version: ' + latest)

    return current_version != latest

def run_script(prefix, cmd):
    """Run a shell command and log its standard output.

    prefix: text logged on the line before the command's output.
    cmd:    command line handed to os.popen(); it comes from this plugin's
            own command-line options.
    """
    # Closing the pipe (done by the context manager) also waits for the
    # command to finish, so the caller only continues once it has exited.
    with os.popen(cmd) as pipe:
        output = pipe.read()
    info(prefix + "\n" + output.strip())

def restart_programs():
    """Restart the selected supervisor processes if the version changed.

    Steps: compare the version file with ``current_version``; collect the
    processes accepted by requires_restart(); sleep for the configured
    fixed/random delay; stop each process; run the pre-restart script; start
    each process once supervisor reports it STOPPED; run the post-restart
    script; record the new version.

    Returns False when the version is unchanged, otherwise None.
    """
    global current_version

    if not has_version_changed():
        info("Version unchanged. Refusing to restart.")
        return False
    info("Version change detected")

    procs = rpc.supervisor.getAllProcessInfo()
    restart_names = [proc['group'] + ':' + proc['name']
                     for proc in procs if requires_restart(proc)]

    if not restart_names:
        info('Nothing to restart.')
        if args.running_only:
            info('*WARNING* As you specified RUNNING only, your running application may now be different to your version file.')
        # Every later step needs at least one target, so leave now rather
        # than sleeping through the fixed delay for nothing. The recorded
        # version is intentionally left untouched.
        return

    random_time = 0
    if args.random_delay > 0:
        random_time = random.randrange(0, args.random_delay)
        info('Random delay of ' + str(random_time) + ' allocated')

    total_delay = args.fixed_delay + random_time
    if total_delay > 0:
        info('Total delay of ' + str(total_delay) + ' specified. Sleeping ' + str(total_delay) + ' seconds.')
        time.sleep(total_delay)

    # Request a stop for every target without waiting; a process that cannot
    # be stopped is dropped so that it is not started again below.
    for name in list(restart_names):
        info("Stopping " + name)
        try:
            rpc.supervisor.stopProcess(name, False)
        except xmlrpclib.Fault as exc:
            info('warning: failed to stop process: ' + exc.faultString)
            restart_names.remove(name)

    if args.pre_restart:
        info('Executing pre-restart script')
        run_script('Output: ', args.pre_restart)

    # Poll until each remaining process reports STOPPED, then start it.
    while restart_names:
        for name in list(restart_names):
            try:
                proc = rpc.supervisor.getProcessInfo(name)
            except xmlrpclib.Fault as exc:
                # Treat a failed status query like a failed stop/start: warn
                # and give up on this process instead of letting the fault
                # abort the whole restart.
                info('warning: failed to query process: ' + exc.faultString)
                restart_names.remove(name)
                continue
            if proc['statename'] != 'STOPPED':
                continue
            info("Starting " + name)
            try:
                rpc.supervisor.startProcess(name, False)
            except xmlrpclib.Fault as exc:
                info('warning: failed to start process: ' + exc.faultString)
            # Started or failed, this process is done either way.
            restart_names.remove(name)
        time.sleep(0.1)

    if args.post_restart:
        info('Executing post-restart script')
        run_script('Output: ', args.post_restart)

    current_version = get_version()
    info("Updated current version to " + current_version)


def commence_restart():
    """Thread entry point: serialise and run one restart attempt.

    If another worker already holds ``pre_restarting_lock`` this call returns
    immediately. Otherwise it waits for ``restarting_lock`` and runs
    restart_programs() while holding it.
    """
    # Non-blocking acquire: another worker is already waiting its turn.
    if not pre_restarting_lock.acquire(False):
        return
    try:
        info('Detected file change.')
        time.sleep(0.1)
        restarting_lock.acquire()
    finally:
        pre_restarting_lock.release()

    # Release the lock even if the restart raises; otherwise every later
    # worker would block forever on restarting_lock.
    try:
        restart_programs()
    finally:
        restarting_lock.release()


class RestartEventHandler(object):
    """Event-handler mixin that starts a restart attempt for every event."""

    def on_any_event(self, event):
        """Run commence_restart() on a new thread; *event* itself is unused."""
        threading.Thread(target=commence_restart).start()


class RestartPatternMatchingEventHandler(RestartEventHandler,
                                         PatternMatchingEventHandler):
    """PatternMatchingEventHandler whose on_any_event comes from RestartEventHandler."""



if __name__ == '__main__':
    # args, current_version, observer and rpc are module globals read by the
    # functions above.
    args = parser.parse_args()
    validate_args(args)

    current_version = get_version()

    if args.enable == 0 or args.disable:
        # Disabled: stay alive doing nothing until asked to terminate.
        info('functionality disabled, waiting for termination signal')
        signal.signal(signal.SIGINT,  handle_term)
        signal.signal(signal.SIGTERM, handle_term)
        signal.pause()
        # handle_term() exits because no observer exists yet; this explicit
        # exit makes sure disabled mode can never fall through to monitoring.
        sys.exit()

    # Resolve the version file to an absolute path. A bare filename has an
    # empty dirname, which is not a directory that can be watched; deriving
    # both the watched directory and the match pattern from the same absolute
    # path keeps them consistent.
    # NOTE(review): assumes watchdog reports event paths rooted at the
    # scheduled directory - confirm against the watchdog documentation.
    version_path = os.path.abspath(args.version_file)

    event_handler = RestartPatternMatchingEventHandler(
                patterns=[version_path],
                case_sensitive=False)

    path = os.path.dirname(version_path)

    observer = Observer()
    observer.schedule(event_handler, path, recursive=False)

    try:
        rpc = childutils.getRPCInterface(os.environ)
    except KeyError as exc:
        error('missing environment variable ' + str(exc))

    info('Watching ' + args.version_file)

    info('Current Version: ' + current_version)

    if args.fixed_delay:
        info('Fixed delay set to ' + str(args.fixed_delay))

    if args.random_delay:
        info('Random delay set between 0 and ' + str(args.random_delay))

    if args.pre_restart:
        info('Pre-restart script set to "' + args.pre_restart + '"')

    if args.post_restart:
        info('Post-restart script set to "' + args.post_restart + '"')

    try:
        observer.start()
    except OSError as exc:
        error(str(exc))

    signal.signal(signal.SIGINT,  handle_term)
    signal.signal(signal.SIGTERM, handle_term)

    # Join with a timeout in a loop rather than blocking indefinitely;
    # the loop ends once handle_term() has stopped the observer.
    while observer.is_alive():
        observer.join(1)
