#!/usr/bin/env python

import argparse
import logging
import os
import signal
import subprocess
import sys

import cicada


def get_environment(world_size, rank, host_addr, link_addr):
    """Return a copy of the current process environment with the Cicada player variables set.

    The caller's environment is copied, never mutated. WORLD_SIZE, RANK, LINK_ADDR and
    HOST_ADDR are added (or overwritten), each converted with ``str()``.
    """
    player_vars = {
        "WORLD_SIZE": world_size,
        "RANK": rank,
        "LINK_ADDR": link_addr,
        "HOST_ADDR": host_addr,
    }
    environment = dict(os.environ)
    environment.update((name, str(value)) for name, value in player_vars.items())
    return environment


def log_command(command, env, log):
    """Log *command* (space-joined) followed by the four Cicada variables taken from *env*.

    Only WORLD_SIZE, RANK, HOST_ADDR and LINK_ADDR are reported, in that order; every other
    entry in *env* is ignored. A missing key raises KeyError.
    """
    log.info(f"Command: {' '.join(command)}")
    log.info("  Environment:")
    for variable in ("WORLD_SIZE", "RANK", "HOST_ADDR", "LINK_ADDR"):
        log.info(f"    {variable}={env[variable]}")


def basic_frontend(arguments, players, log):
    """Start every player as a direct child process of this launcher and wait for them.

    Each entry in *players* is a ``(world_size, rank, host_addr, link_addr)`` tuple. A player
    runs ``sys.executable [-i] arguments.program *arguments.args`` with its own environment
    from ``get_environment``. Every command is logged. When ``arguments.dry_run`` is set,
    nothing is started.
    """
    interpreter_flags = ["-i"] if arguments.inspect else []
    children = []

    for world_size, rank, host_addr, link_addr in players:
        player_env = get_environment(world_size, rank, host_addr, link_addr)
        player_command = [sys.executable, *interpreter_flags, arguments.program, *arguments.args]
        log_command(player_command, player_env, log)
        if arguments.dry_run:
            continue
        children.append(subprocess.Popen(player_command, env=player_env))

    if arguments.dry_run:
        return

    # Ignore Ctrl-C in the launcher itself so it keeps waiting for its children; presumably
    # the interrupt is meant to be handled by the player processes instead.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    for child in children:
        child.wait()


def tmux_panes_frontend(arguments, players, log):
    """Run every player in its own pane of one new tmux session, then wait for tmux.

    A single ``tmux`` invocation is built:
    - the first player creates the session (``new-session``);
    - each further player gets a new pane (``split-window -v -d``), chained with tmux's
      ``;`` command separator;
    - the requested layout is applied (``select-layout``).

    Each player's WORLD_SIZE, RANK, HOST_ADDR and LINK_ADDR are passed with ``-e``.

    Parameters
    ----------
    arguments:
        Parsed CLI namespace; reads ``program``, ``args``, ``inspect``, ``dry_run`` and
        ``tmux_layout``. ``tmux_layout`` is optional and defaults to ``"even-vertical"``.
    players:
        Sequence of ``(world_size, rank, host_addr, link_addr)`` tuples.
    log:
        Logger used to report the generated command.

    Raises
    ------
    ValueError
        If *players* is empty (there would be no session to create).
    """
    if not players:
        raise ValueError("tmux frontend requires at least one player.")

    command = []
    for index, (world_size, rank, host_addr, link_addr) in enumerate(players):
        # Decide by position, not by rank: `start --rank N` passes a single player whose
        # rank is non-zero, and that player must still create the session.
        if index == 0:
            command += ["tmux", "new-session"]
        else:
            command += [";", "split-window", "-v", "-d"]

        command += ["-e", f"WORLD_SIZE={world_size}"]
        command += ["-e", f"RANK={rank}"]
        command += ["-e", f"HOST_ADDR={host_addr}"]
        command += ["-e", f"LINK_ADDR={link_addr}"]
        command += [sys.executable]
        if arguments.inspect:
            command += ["-i"]
        command += [arguments.program]
        command += arguments.args

    # Not every sub-command defines --tmux-layout, so fall back to the `run` default.
    layout = getattr(arguments, "tmux_layout", "even-vertical")
    command += [";", "select-layout", layout]

    log.info(f"Command: {' '.join(command)}")

    if arguments.dry_run:
        return

    process = subprocess.Popen(command)
    # Keep waiting for tmux when Ctrl-C is pressed in the launcher's terminal.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    process.wait()


def tmux_windows_frontend(arguments, players, log):
    """Run every player in its own window of one new tmux session, then wait for tmux.

    A single ``tmux`` invocation is built:
    - the first player creates the session (``new-session``);
    - each further player gets a background window (``new-window -d``), chained with tmux's
      ``;`` command separator.

    Every window is named ``rank-<rank>``. Each player's WORLD_SIZE, RANK, HOST_ADDR and
    LINK_ADDR are passed with ``-e``.

    Parameters
    ----------
    arguments:
        Parsed CLI namespace; reads ``program``, ``args``, ``inspect`` and ``dry_run``.
    players:
        Sequence of ``(world_size, rank, host_addr, link_addr)`` tuples.
    log:
        Logger used to report the generated command.

    Raises
    ------
    ValueError
        If *players* is empty (there would be no session to create).
    """
    if not players:
        raise ValueError("tmux-windows frontend requires at least one player.")

    command = []
    for index, (world_size, rank, host_addr, link_addr) in enumerate(players):
        # Decide by position, not by rank: `start --rank N` passes a single player whose
        # rank is non-zero, and that player must still create the session.
        if index == 0:
            command += ["tmux", "new-session"]
        else:
            command += [";", "new-window", "-d"]

        command += ["-n", f"rank-{rank}"]
        command += ["-e", f"WORLD_SIZE={world_size}"]
        command += ["-e", f"RANK={rank}"]
        command += ["-e", f"HOST_ADDR={host_addr}"]
        command += ["-e", f"LINK_ADDR={link_addr}"]
        command += [sys.executable]
        if arguments.inspect:
            command += ["-i"]
        command += [arguments.program]
        command += arguments.args

    log.info(f"Command: {' '.join(command)}")

    if arguments.dry_run:
        return

    process = subprocess.Popen(command)
    # Keep waiting for tmux when Ctrl-C is pressed in the launcher's terminal.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    process.wait()


def xterm_frontend(arguments, players, log):
    """Open one xterm per player, run the player's program inside it, and wait for all xterms.

    Each entry in *players* is a ``(world_size, rank, host_addr, link_addr)`` tuple. Its xterm
    is started with the environment from ``get_environment`` and runs
    ``xterm -e sys.executable [-i] arguments.program *arguments.args``. Every command is
    logged. When ``arguments.dry_run`` is set, nothing is started.
    """
    inspect_flag = ["-i"] if arguments.inspect else []
    windows = []

    for world_size, rank, host_addr, link_addr in players:
        window_env = get_environment(world_size, rank, host_addr, link_addr)
        window_command = [
            "xterm",
            "-e",
            sys.executable,
            *inspect_flag,
            arguments.program,
            *arguments.args,
        ]
        log_command(window_command, window_env, log)
        if not arguments.dry_run:
            windows.append(subprocess.Popen(window_command, env=window_env))

    if arguments.dry_run:
        return

    # Ignore Ctrl-C in the launcher so it keeps waiting until every xterm has exited.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    for window in windows:
        window.wait()


# Maps each --frontend choice to the function that launches the players with it.
# Insertion order matters: the parser passes frontends.keys() as `choices`, so this is
# also the order in which the choices are listed in --help.
frontends = dict(
    [
        ("basic", basic_frontend),
        ("tmux", tmux_panes_frontend),
        ("tmux-windows", tmux_windows_frontend),
        ("xterm", xterm_frontend),
    ]
)


# Command-line interface: one sub-parser per sub-command; the chosen name lands in
# `arguments.command`.
parser = argparse.ArgumentParser(description="Cicada MPC tools.")
subparsers = parser.add_subparsers(title="commands (choose one)", dest="command")

# run: launch every player locally. Rank 0 listens on --link-port; all players link to it.
subparser = subparsers.add_parser("run", help="Run all Cicada processes on the local machine.")
subparser.add_argument("--bind-public", action="store_true", help="Use a public host address.")
subparser.add_argument("--dry-run", "-y", action="store_true", help="Don't start actual processes.")
subparser.add_argument("--frontend", "-f", choices=frontends.keys(), default="basic", help="Frontend to execute processes.")
subparser.add_argument("--host-addr", default="127.0.0.1", help="Host address. Default: %(default)s")
subparser.add_argument("--inspect", "-i", action="store_true", help="Start a Python prompt after running program.")
subparser.add_argument("--link-port", type=int, default=25252, help="Link port. Default: %(default)s")
subparser.add_argument("--tmux-layout", default="even-vertical", choices=["even-horizontal", "even-vertical", "tiled"], help="Pane layout for the tmux frontend. Default: %(default)s")
subparser.add_argument("--world-size", "-n", type=int, default=3, help="Number of players. Default: %(default)s")
subparser.add_argument("program", help="Program to execute.")
subparser.add_argument("args", nargs=argparse.REMAINDER, help="Program arguments.")

# start: launch a single player with an explicit rank.
# --tmux-layout is accepted here too, because --frontend tmux reads it.
subparser = subparsers.add_parser("start", help="Start one Cicada process.")
subparser.add_argument("--bind-public", action="store_true", help="Use a public host address.")
subparser.add_argument("--dry-run", "-y", action="store_true", help="Don't start actual processes.")
subparser.add_argument("--frontend", "-f", choices=frontends.keys(), default="basic", help="Frontend to execute processes.")
subparser.add_argument("--host-addr", default="127.0.0.1", help="Host address. Default: %(default)s")
subparser.add_argument("--host-port", type=int, default=None, help="Host port. Default: randomly chosen port")
subparser.add_argument("--inspect", "-i", action="store_true", help="Start a Python prompt after running program.")
subparser.add_argument("--link-addr", default=None, help="Link address. Default: 127.0.0.1")
subparser.add_argument("--link-port", type=int, default=None, help="Link port.")
subparser.add_argument("--rank", type=int, required=True, help="Player rank.")
subparser.add_argument("--tmux-layout", default="even-vertical", choices=["even-horizontal", "even-vertical", "tiled"], help="Pane layout for the tmux frontend. Default: %(default)s")
subparser.add_argument("--world-size", "-n", type=int, default=3, help="Number of players. Default: %(default)s")
subparser.add_argument("program", help="Program to execute.")
subparser.add_argument("args", nargs=argparse.REMAINDER, help="Program arguments.")

# version: print the installed Cicada version.
subparser = subparsers.add_parser("version", help="Print the Cicada version.")


if __name__ == "__main__":
    arguments = parser.parse_args()

    # No sub-command: show usage and report failure rather than silently exiting 0.
    if arguments.command is None:
        parser.print_help()
        sys.exit(1)

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger()
    log.name = os.path.basename(sys.argv[0])

    if arguments.command == "run":
        # Launch every player on this machine.
        world_size = arguments.world_size
        if world_size < 1:
            raise RuntimeError("--world-size must be greater than zero.")

        host = cicada.bind.public_ip() if arguments.bind_public else arguments.host_addr

        # Rank 0 listens on --link-port, and every player (rank 0 included) links to it.
        # The other ranks get a port-less address; presumably Cicada then picks a port
        # (compare `start --host-port`, "Default: randomly chosen port").
        link_addr = f"tcp://{host}:{arguments.link_port}"
        players = [
            (world_size, rank, link_addr if rank == 0 else f"tcp://{host}", link_addr)
            for rank in range(world_size)
        ]

        frontend = frontends[arguments.frontend]
        frontend(arguments, players, log)

    elif arguments.command == "start":
        # Launch exactly one player with an explicitly chosen rank.
        world_size = arguments.world_size
        if world_size < 1:
            raise RuntimeError("--world-size must be greater than zero.")

        rank = arguments.rank
        if rank < 0:
            raise RuntimeError("--rank must be non-negative.")
        if rank >= world_size:
            raise RuntimeError("--rank must be less than --world-size.")

        addr = cicada.bind.public_ip() if arguments.bind_public else arguments.host_addr

        if rank == 0:
            # Rank 0 is the link target itself, so it needs a fixed port and no link options.
            if arguments.host_port is None:
                raise RuntimeError("--host-port must be specified when --rank is 0.")
            if arguments.link_addr is not None:
                raise RuntimeError("--link-addr cannot be specified when --rank is 0.")
            if arguments.link_port is not None:
                raise RuntimeError("--link-port cannot be specified when --rank is 0.")

            host_addr = f"tcp://{addr}:{arguments.host_port}"
            link_addr = host_addr
        else:
            # Other ranks must be told where rank 0 listens; their own port is optional.
            if arguments.link_port is None:
                raise RuntimeError("--link-port must be specified when --rank is not 0.")

            if arguments.host_port is None:
                host_addr = f"tcp://{addr}"
            else:
                host_addr = f"tcp://{addr}:{arguments.host_port}"

            link_host = "127.0.0.1" if arguments.link_addr is None else arguments.link_addr
            link_addr = f"tcp://{link_host}:{arguments.link_port}"

        players = [(world_size, rank, host_addr, link_addr)]
        frontend = frontends[arguments.frontend]
        frontend(arguments, players, log)

    elif arguments.command == "version":
        print(cicada.__version__)


