#!/usr/bin/env python3

"""
Helper utility that lists and mounts AWS EFS filesystems.
"""

import os, sys, subprocess, argparse, logging, collections

import boto3, requests

# Process-wide setup: INFO-level logging and command-line parsing.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The module docstring doubles as the --help description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--timeout",
    type=int,
    default=10,
    # The value is handed to Popen.wait(timeout=...) for each mount process.
    help="Seconds to wait for each mount command to finish (default: %(default)s)",
)
args = parser.parse_args()

def get_metadata(category):
    """Return the output of ``aegea-imds <category>`` as a stripped string.

    The ``aegea-imds`` command is run with *category* as its only argument;
    its standard output is decoded and surrounding whitespace is removed.
    ``subprocess.CalledProcessError`` propagates if the command exits with a
    non-zero status.
    """
    command = ["aegea-imds", category]
    raw_output = subprocess.check_output(command)
    text = raw_output.decode()
    return text.strip()

# Placement and network details for this machine, looked up via get_metadata().
az = get_metadata("placement/availability-zone")
services_domain = get_metadata("services/domain")
# The subnet id is keyed by the interface MAC address, so fetch the MAC first.
mac_address = get_metadata("mac")
subnet_id = get_metadata("network/interfaces/macs/{}/subnet-id".format(mac_address))

# AWS API clients. The EFS client's region is the availability zone name with
# its last character dropped (e.g. "us-east-1a" -> "us-east-1").
# NOTE(review): `sts` is not referenced anywhere else in this file.
region_name = az[:-1]
sts = boto3.client('sts')
efs = boto3.Session(region_name=region_name).client("efs")

def _iter_file_systems():
    """Yield every EFS filesystem description, following NextMarker pagination."""
    request = {}
    while True:
        page = efs.describe_file_systems(**request)
        for description in page["FileSystems"]:
            yield description
        next_marker = page.get("NextMarker")
        if not next_marker:
            return
        request["Marker"] = next_marker


def accessible(filesystem):
    """Return True if *filesystem* has a mount target in this machine's subnet."""
    mount_targets = efs.describe_mount_targets(FileSystemId=filesystem["FileSystemId"])["MountTargets"]
    return any(mt["SubnetId"] == subnet_id for mt in mount_targets)


# Maps the local mount path (taken from each filesystem's "mountpoint" tag) to
# the filesystem description returned by the EFS API.
mountpoints = {}

for filesystem in _iter_file_systems():
    mountpoint = next((tag["Value"] for tag in filesystem["Tags"] if tag["Key"] == 'mountpoint'), None)

    # Skip filesystems that carry no (or an empty) mountpoint tag, and those
    # with no mount target in our subnet. The tag check comes first so the
    # mount-target API call is only made for tagged filesystems.
    if not mountpoint or not accessible(filesystem):
        continue

    # Two filesystems claiming the same path is a configuration error. Raise
    # explicitly instead of using assert, which is stripped under `python -O`
    # and would let the later filesystem silently replace the earlier one.
    if mountpoint in mountpoints:
        raise RuntimeError("EFS filesystems {} and {} both request mountpoint {}".format(
            mountpoints[mountpoint]["FileSystemId"], filesystem["FileSystemId"], mountpoint))

    mountpoints[mountpoint] = filesystem
    logger.info("Using %s for EFS %s", filesystem["FileSystemId"], mountpoint)

# Nothing to mount: report it and exit with a non-zero status.
if not mountpoints:
    logger.info("No EFS filesystems found")
    sys.exit(-1)

# Start one `mount` process per filesystem without waiting for it to finish.
# The Popen handles are kept, keyed by mount path in launch order.
mount_procs = collections.OrderedDict()
nfs_options = "nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2"
for target_dir, fs_description in mountpoints.items():
    fs_id = fs_description["FileSystemId"]
    logger.info("Mounting %s on %s", fs_id, target_dir)

    # NFS source of the form <az>.<fs-id>.efs.<region>.<services-domain>:/
    url_fields = dict(az=az, fs=fs_id, region=efs.meta.region_name, domain=services_domain)
    fs_url = "{az}.{fs}.efs.{region}.{domain}:/".format(**url_fields)

    # The mount target directory has to exist before mount is invoked.
    if not os.path.isdir(target_dir):
        os.makedirs(target_dir)

    mount_command = ["mount", "-t", "nfs4", "-o", nfs_options, fs_url, target_dir]
    mount_procs[target_dir] = subprocess.Popen(mount_command)

# Wait for every mount process and record a non-zero value for each path that
# did not end up mounted. Paths that mounted successfully contribute nothing.
rvals = collections.defaultdict(int)
for mountpoint, proc in mount_procs.items():
    try:
        rval = proc.wait(timeout=args.timeout)
        mounted = os.path.ismount(mountpoint)
        logger.debug("Path: %s rval: %s ismount: %s timeout: %d", mountpoint, rval, mounted, args.timeout)
        if not mounted:
            # Report the status the mount process actually returned.
            logger.info("Mount exited with status: %d", rval)
            rvals[mountpoint] = rval
    except Exception:
        # Reached when the wait times out (or anything else above fails):
        # record a fixed failure value and make a best-effort attempt to stop
        # the mount process.
        rvals[mountpoint] = 100
        logger.exception("Error while mounting EFS filesystem!")
        try:
            proc.terminate()
        except Exception:
            logger.exception("Caught exception terminating")

# The exit status is the sum of the per-path failure values. Clamp it to 255:
# larger values are truncated to 8 bits by the OS, so e.g. a total of 256
# would otherwise be reported to the caller as success.
sys.exit(min(sum(rvals.values()), 255))
