#!/usr/bin/env python3

import random
import string
import subprocess
import argparse
import os
import stat
import boto3

from cryptography.hazmat.primitives import serialization as crypto_serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend as crypto_default_backend

# Argument Parser
# Arguments are parsed before the AWS clients are created so that `--help`
# and usage errors are reported by argparse even when client construction
# would fail (for example when no AWS region is configured).
parser = argparse.ArgumentParser("aws-ic")
parser.add_argument("ssh_user", help="User to establish SSH connection with.", type=str)
parser.add_argument("instance_id", help="Instance ID to connect to.", type=str)
args = parser.parse_args()

# Clients
ec2_client = boto3.client('ec2')
ec2_connect_client = boto3.client('ec2-instance-connect')


def random_string(string_length=10):
    """Generate a random string of lowercase ASCII letters.

    The characters are drawn from the operating system's entropy source
    (``random.SystemRandom``) rather than the default pseudo-random
    generator, because the result is used as the name of a file that
    holds a private key and should not be guessable.

    Args:
        string_length: Number of characters to generate (default 10).

    Returns:
        A string of exactly ``string_length`` lowercase letters; the empty
        string when ``string_length`` is 0.
    """
    letters = string.ascii_lowercase
    secure_rng = random.SystemRandom()
    return ''.join(secure_rng.choice(letters) for _ in range(string_length))


# Create a throwaway 2048-bit RSA key pair for this session only
key = rsa.generate_private_key(
    public_exponent=65537,
    key_size=2048,
    backend=crypto_default_backend(),
)

# Private half: unencrypted PKCS#8, PEM-encoded (bytes)
private_key = key.private_bytes(
    encoding=crypto_serialization.Encoding.PEM,
    format=crypto_serialization.PrivateFormat.PKCS8,
    encryption_algorithm=crypto_serialization.NoEncryption(),
)

# Public half: single-line OpenSSH representation (bytes)
public_key = key.public_key().public_bytes(
    encoding=crypto_serialization.Encoding.OpenSSH,
    format=crypto_serialization.PublicFormat.OpenSSH,
)

# Write private key to file and set POSIX permissions properly.
# The file is created with O_EXCL and an owner-only mode in a single step:
#   * O_CREAT | O_EXCL fails if the path already exists (including a symlink
#     planted in /tmp) instead of writing the key through it;
#   * the 0600 mode is applied at creation time, so the key is never on disk
#     with umask-default (possibly world-readable) permissions.
filename = "/tmp/" + random_string()
key_fd = os.open(
    filename,
    os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    stat.S_IRUSR | stat.S_IWUSR,
)
# The context manager closes the descriptor even if the write fails.
with os.fdopen(key_fd, "w") as pri:
    pri.write(private_key.decode('utf8'))
# Tighten to owner read-only now that the content is written.
os.chmod(filename, stat.S_IREAD)

# Get instance info
response = ec2_client.describe_instances(InstanceIds=[args.instance_id])
instance = response['Reservations'][0]['Instances'][0]
# Availability zone is required later when pushing the public key.
az = instance['Placement']['AvailabilityZone']
# If no public IP, use private IP. The 'PublicIpAddress' key may be missing
# entirely for such instances, so .get() is used: direct indexing would raise
# KeyError before the private-IP fallback could ever be reached.
ip = instance.get('PublicIpAddress') or instance['PrivateIpAddress']

# Send public key to instance, open the SSH session, and always clean up.
# The try/finally guarantees the temporary private key file is removed even
# if the API call raises or the SSH session is interrupted.
try:
    ec2_connect = ec2_connect_client.send_ssh_public_key(
        InstanceId=args.instance_id,
        InstanceOSUser=args.ssh_user,
        SSHPublicKey=public_key.decode('utf8'),
        AvailabilityZone=az
    )

    if not ec2_connect['Success']:
        print("EC2 Connect Failure")
    else:
        # Pass an argument list (no shell) so that the user-supplied
        # ssh_user value cannot be interpreted as shell syntax.
        subprocess.call(["ssh", "-i", filename, args.ssh_user + "@" + ip])
finally:
    os.remove(filename)
