#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK

from __future__ import absolute_import, division, print_function, unicode_literals

import os, sys, subprocess, argparse, base64, logging
from botocore.exceptions import ClientError
from aegea import logger
from aegea.util.aws import (ensure_vpc, ensure_security_group, ensure_ingress_rule, resources, clients,
                            expect_error_codes)

# Emit INFO-level messages (e.g. the "Revoking stale rule" notices below) to the console.
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser(description="Update a security group to allow instances to access public-facing ELBs")

parser.add_argument("--vpc",
                    help="ID of the VPC to operate on. If omitted, the VPC returned by ensure_vpc() is used")
parser.add_argument("--security-group-name", default="aegea_pub_elb_loop")
# type=int is required: without it, ports supplied on the command line arrive as strings, while the defaults
# are ints. The rest of the script compares these values against rule FromPort/ToPort and passes them on as
# FromPort/ToPort, so both sources must yield the same type.
parser.add_argument("--elb-tcp-ports", nargs="*", type=int, default=[80, 443],
                    help="TCP ports on the ELBs that instances should be allowed to reach")
parser.add_argument("--elb-name", nargs="+", default=[],
                    help="If given, assign the resulting security group to these ELBs")
args = parser.parse_args()

# Pick the VPC named on the command line, or fall back to whatever ensure_vpc() provides.
vpc = ensure_vpc() if not args.vpc else resources.ec2.Vpc(args.vpc)

# Gather the public IP address of every running instance in that VPC; instances without one are skipped.
running_only = [{"Name": "instance-state-name", "Values": ["running"]}]
public_ips = {
    instance.public_ip_address
    for instance in vpc.instances.filter(Filters=running_only)
    if instance.public_ip_address
}

sg = ensure_security_group(args.security_group_name, vpc)

# Prune ingress rules that no longer correspond to a running instance's public IP.
# Only TCP rules whose FromPort and ToPort are both among the managed ELB ports are touched; anything else
# in the group is reported and left alone.
for rule in sg.ip_permissions:
    if rule.get("IpProtocol") != "tcp":
        # Logger.warn is a deprecated alias; use warning(). Assumes aegea's logger is a standard logging.Logger.
        logger.warning("The following rule in %s has an unrecognized protocol, skipping: %s", sg, rule)
        continue
    if rule.get("FromPort") not in args.elb_tcp_ports or rule.get("ToPort") not in args.elb_tcp_ports:
        logger.warning("The following rule in %s has an unrecognized port, skipping: %s", sg, rule)
        continue
    for ip_range in rule.get("IpRanges", []):
        cidr = ip_range["CidrIp"]
        # A range is current only if it is a single-host /32 whose address is one of the collected public IPs.
        if cidr.endswith("/32") and cidr[:-3] in public_ips:
            continue
        logger.info("Revoking stale rule for %s", cidr)
        sg.revoke_ingress(IpProtocol="tcp", FromPort=rule["FromPort"], ToPort=rule["ToPort"], CidrIp=cidr)

# Make sure each collected public address (as a single-host /32) may reach every managed ELB port.
for address in public_ips:
    host_cidr = address + "/32"
    for tcp_port in args.elb_tcp_ports:
        ensure_ingress_rule(sg, IpProtocol="tcp", FromPort=tcp_port, ToPort=tcp_port, CidrIp=host_cidr)

# Attach the security group to each named load balancer, keeping the groups it already has.
# The name is looked up through the elbv2 API first; if that reports LoadBalancerNotFound, the classic
# elb API is tried instead. Any other ClientError is handed to expect_error_codes().
for elb_name in args.elb_name:
    try:
        found = clients.elbv2.describe_load_balancers(Names=[elb_name])["LoadBalancers"]
        assert len(found) == 1
        balancer = found[0]
        if sg.id in balancer["SecurityGroups"]:
            continue  # already attached, nothing to do
        clients.elbv2.set_security_groups(
            LoadBalancerArn=balancer["LoadBalancerArn"],
            SecurityGroups=balancer["SecurityGroups"] + [sg.id],
        )
    except ClientError as e:
        expect_error_codes(e, "LoadBalancerNotFound")
        found = clients.elb.describe_load_balancers(LoadBalancerNames=[elb_name])["LoadBalancerDescriptions"]
        assert len(found) == 1
        balancer = found[0]
        if sg.id in balancer["SecurityGroups"]:
            continue  # already attached, nothing to do
        clients.elb.apply_security_groups_to_load_balancer(
            LoadBalancerName=elb_name,
            SecurityGroups=balancer["SecurityGroups"] + [sg.id],
        )
