#!/usr/bin/env python3

import sys
import argparse
import os
import multiprocessing as mp
import requests

def version():
    """
    Print the package's semantic version string (https://semver.org/).

    The string is read from the ``VERSION`` data file bundled with the
    ``msr`` package; trailing newlines are stripped before printing.
    """
    import pkgutil
    raw = pkgutil.get_data('msr', 'VERSION')
    print(raw.decode().rstrip('\n'))

def register():
    """
    Add ``args.url`` to the registry if it is valid and not already present.

    The registry is the file ``registry`` inside get_data_dir('msr')
    (normally $XDG_DATA_HOME/msr, or $HOME/.local/share/msr) and holds one
    url per line. Reads the module-level ``args`` produced by argparse.

    Exits with status 1 if ``args.url`` fails url_is_valid.
    """
    if not url_is_valid(args.url):
        print(args.url + " is not a valid url. Please ensure it is properly formatted.")
        sys.exit(1)

    registry = os.path.join(get_data_dir('msr'), 'registry')

    # A registry that does not exist yet simply has no entries.
    try:
        with open(registry) as fin:
            lines = fin.readlines()
    except FileNotFoundError:
        lines = []

    # Compare without the line terminator so the final entry still matches
    # when the file does not end with a newline (e.g. after a manual edit).
    if args.url in {line.rstrip('\n') for line in lines}:
        print(args.url + " has already been registered.")
        return

    with open(registry, 'a') as fout:
        # Terminate a dangling last line so the new url gets its own line.
        if lines and not lines[-1].endswith('\n'):
            fout.write('\n')
        fout.write(args.url + '\n')
    print("Added " + args.url + " to registry.")

def measure():
    """
    Print the size (in bytes) of the body received by making a GET request
    to each url in the registry.

    Requests are spread over a process pool, one task per url. A size of -1
    means the request failed (see get_content_size). If the registry does
    not exist yet or contains no urls, a hint is printed instead.
    """
    registry = os.path.join(get_data_dir('msr'), 'registry')
    try:
        with open(registry) as f:
            # Blank lines are not urls; requesting them would only fail.
            urls = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        urls = []

    if not urls:
        print("No urls registered. Add one with: msr register <url>")
        return

    # map() blocks until every task is done, so leaving the with-block
    # (which tears the pool down) cannot cut work short.
    with mp.Pool(mp.cpu_count()) as pool:
        sizes = pool.map(get_content_size, urls)

    # TODO: pretty print
    # Print the stripped urls so they line up index-for-index with sizes.
    print(urls)
    print(sizes)

def race():
    """
    Print the time (in seconds) it takes to get a response from each url in
    the registry.

    Requests are spread over a process pool, one task per url. A time of -1
    means the request failed (see get_load_time). If the registry does not
    exist yet or contains no urls, a hint is printed instead.
    """
    registry = os.path.join(get_data_dir('msr'), 'registry')
    try:
        with open(registry) as f:
            # Blank lines are not urls; requesting them would only fail.
            urls = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        urls = []

    if not urls:
        print("No urls registered. Add one with: msr register <url>")
        return

    # map() blocks until every task is done, so leaving the with-block
    # (which tears the pool down) cannot cut work short.
    with mp.Pool(mp.cpu_count()) as pool:
        times = pool.map(get_load_time, urls)

    # TODO: pretty print
    # Print the stripped urls so they line up index-for-index with times.
    print(urls)
    print(times)

# -- helper functions --

def get_load_time(url: str) -> float:
    """
    Return ``Response.elapsed`` in seconds for a GET request to ``url``.

    Returns -1 (after printing the error) if the request raises a
    RequestException.
    """
    try:
        r = requests.get(url)
    except requests.exceptions.RequestException as e:
        # str(e) is required: adding the exception object to a str raises
        # TypeError, which would turn a handled failure into a crash.
        print(url + " resulted in RequestException: " + str(e))
        return -1
    return r.elapsed.total_seconds()

def get_content_size(url: str) -> int:
    """
    Return the length in bytes of the body received from a GET request to
    ``url``, following at most 3 redirects.

    Returns -1 (after printing the error) if the request raises a
    RequestException.
    """
    # Closing the session on exit releases its pooled connections.
    with requests.Session() as session:
        session.max_redirects = 3
        try:
            r = session.get(url)
        except requests.exceptions.RequestException as e:
            # str(e) is required: adding the exception object to a str
            # raises TypeError, which would turn a handled failure into a
            # crash.
            print(url + " resulted in RequestException: " + str(e))
            return -1
        return len(r.content)

def url_is_valid(url: str) -> bool:
    """
    Return True if ``url`` can be opened with urllib, False otherwise.

    This makes a real request. Any failure to open the url counts as
    invalid: an unparseable or unsupported url (ValueError), a bad port
    (http.client.InvalidURL), or a connection/HTTP error (URLError and its
    HTTPError subclass, both OSError).
    """
    import http.client
    from urllib.request import urlopen
    try:
        # Only reachability matters; close the response immediately.
        with urlopen(url):
            return True
    except (OSError, ValueError, http.client.HTTPException):
        # OSError also covers socket-level errors that urllib does not wrap
        # in URLError (e.g. a dropped connection); HTTPException covers
        # InvalidURL and malformed responses.
        return False

def get_data_dir(pkg: str) -> str:
    """
    Return the per-user data directory for ``pkg`` (the path is not created).

    Follows the XDG Base Directory Specification: $XDG_DATA_HOME/<pkg> when
    XDG_DATA_HOME is set to an absolute path, otherwise
    <home>/.local/share/<pkg>, where <home> is $HOME if set and
    os.path.expanduser('~') otherwise.
    """
    xdg_data_home = os.environ.get('XDG_DATA_HOME', '')
    # Per the spec, an empty value is treated as unset and a relative path is
    # invalid and ignored; using either would yield a directory relative to
    # the current working directory.
    if os.path.isabs(xdg_data_home):
        return os.path.join(xdg_data_home, pkg)
    home = os.environ['HOME'] if 'HOME' in os.environ else os.path.expanduser('~')
    return os.path.join(home, '.local/share', pkg)

# -- parse CLI arguments --

# Subcommand name -> handler function.
function_map = {
    'version': version,
    'register': register,
    'measure': measure,
    'race': race,
}

# Parse arguments and dispatch only when run as a script. measure() and
# race() start a multiprocessing.Pool; with the "spawn" start method each
# worker re-imports this module, and unguarded top-level code would re-run
# the CLI inside every worker.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Measure web pages.")
    subparsers = parser.add_subparsers(dest='command')

    parser_register = subparsers.add_parser('register')
    parser_register.add_argument('url')
    subparsers.add_parser('measure')
    subparsers.add_parser('race')
    subparsers.add_parser('version')

    # Assigned at module level on purpose: register() reads args.url.
    args = parser.parse_args()

    # No subcommand given: show usage and fail.
    if args.command is None:
        parser.print_help()
        sys.exit(1)

    # Ensure the directory that holds the registry exists.
    os.makedirs(get_data_dir('msr'), exist_ok=True)

    function_map[args.command]()
