#!/usr/bin/env python3

import argparse
import logging
import os
import shlex
import signal
import subprocess
import sys

import cicada.bind

# Command-line interface.  --host, --host-port, --link and --link-port only
# take effect when a single player is started with --player; without it,
# every player runs on the local machine with automatically chosen ports.
parser = argparse.ArgumentParser(description="Execute a Cicada MPC program.")
parser.add_argument("--bind-public", action="store_true", help="Bind sockets to public IP addresses instead of the loopback adapter.")
parser.add_argument("--dry-run", "-y", action="store_true", help="Log startup commands without running them.")
parser.add_argument("--host", default=None, help="Explicit host address for the player started with --player.  Default: the loopback address, or the public address with --bind-public.")
parser.add_argument("--host-port", type=int, default=None, help="Explicit host port for the player started with --player.  Default: choose a port at random.")
parser.add_argument("--interactive", "-i", action="store_true", help="Enter interactive mode after the program completes.")
parser.add_argument("--player", type=int, default=None, help="Start only the player with this rank.  Default: start every player on the local machine.")
parser.add_argument("--players", "-n", type=int, default=3, help="Number of players.  Default: %(default)s")
parser.add_argument("--link", default=None, help="Link address (player 0's host address), used with --player.  Default: the loopback address, or the public address with --bind-public.")
parser.add_argument("--link-port", type=int, default=None, help="Link port (player 0's host port), used with --player.  Default: choose a port at random.")
parser.add_argument("--xterm", "-X", action="store_true", help="Open processes in separate xterm windows.")
parser.add_argument("program", help="Python script file to execute.")
# Everything after the program path is passed through to the program verbatim.
parser.add_argument("args", nargs=argparse.REMAINDER)

if __name__ == "__main__":
    arguments = parser.parse_args()

    # Reject world sizes / ranks that cannot describe a valid set of players,
    # instead of handing nonsensical WORLD_SIZE / RANK values to the program.
    if arguments.players < 1:
        parser.error("--players must be at least 1.")
    if arguments.player is not None and not 0 <= arguments.player < arguments.players:
        parser.error(f"--player must be between 0 and {arguments.players - 1}.")

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger()
    log.name = os.path.basename(sys.argv[0])

    log.warning("'cicada-exec' is deprecated and will be removed in a future release. Use the 'cicada' command instead.\n")

    def get_addr(*, addr=None, port=None, public=False):
        """Return a ``tcp://<addr>:<port>`` address string.

        A missing ``addr`` is replaced by this machine's public IP when
        ``public`` is true, or by its loopback IP otherwise.  A missing
        ``port`` is replaced by ``cicada.bind.random_port(addr)``.
        """
        if addr is None:
            addr = cicada.bind.public_ip() if public else cicada.bind.loopback_ip()
        if port is None:
            port = cicada.bind.random_port(addr)
        return f"tcp://{addr}:{port}"

    # One (world_size, rank, link_addr, host_addr) record per player to start.
    # Player 0 links to its own host address; every other player links to
    # player 0's host address.
    players = []

    if arguments.player is None: # Run every player on the local machine.
        # Honor --bind-public here too, as its help text promises.
        addresses = [get_addr(public=arguments.bind_public) for _ in range(arguments.players)]
        for rank, host_addr in enumerate(addresses):
            players.append((arguments.players, rank, addresses[0], host_addr))

    else: # Run one player.
        if arguments.player == 0:
            host_addr = get_addr(addr=arguments.host, port=arguments.host_port, public=arguments.bind_public)
            players.append((arguments.players, 0, host_addr, host_addr))
        else:
            if arguments.link_port is None:
                # A randomly chosen link port will not match player 0's port.
                log.warning(f"No --link-port given for player {arguments.player}; a random port will almost certainly not reach player 0.")
            link_addr = get_addr(addr=arguments.link, port=arguments.link_port, public=arguments.bind_public)
            host_addr = get_addr(addr=arguments.host, port=arguments.host_port, public=arguments.bind_public)
            players.append((arguments.players, arguments.player, link_addr, host_addr))

    # The command line is identical for every player; only the environment differs.
    command = []
    if arguments.xterm:
        command += ["xterm", "-e"]
    command += [sys.executable]
    if arguments.interactive:
        command += ["-i"]
    command += [arguments.program]
    command += arguments.args
    # Quote each word so the logged command can be pasted back into a shell.
    command_text = " ".join(shlex.quote(word) for word in command)

    # Start player processes.
    processes = []  # (rank, Popen) pairs, in start order.
    for world_size, rank, link_addr, host_addr in players:
        env = os.environ.copy()
        env["WORLD_SIZE"] = str(world_size)
        env["RANK"] = str(rank)
        env["LINK_ADDR"] = link_addr
        env["HOST_ADDR"] = host_addr

        log.info(f"Command: {command_text}")
        log.info("  Environment:")
        for key in ("WORLD_SIZE", "RANK", "LINK_ADDR", "HOST_ADDR"):
            log.info(f"    {key}={env[key]}")

        if not arguments.dry_run:
            processes.append((rank, subprocess.Popen(command, env=env)))

    # Wait for player processes to finish.
    if not arguments.dry_run:
        # Ignore Ctrl-C from here on so the launcher keeps waiting for its
        # players.  The players were started before this call, so they do not
        # inherit the ignored disposition.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        failures = 0
        for rank, process in processes:
            if process.wait() != 0:
                log.error(f"Player {rank} exited with status {process.returncode}.")
                failures += 1
        # Report failure to the caller instead of always exiting with status 0.
        if failures:
            sys.exit(1)
