#!/usr/bin/env python3
import argparse
import sys, os
import errno, pathlib, re
import datetime, time
import json, requests, urllib.request
import sqlite3
import logging


### urlscan's config directory (lives directly under the user's home)
urlscan_dir = str(pathlib.Path.home()) + '/.urlscan'

### default directory retrieved scan results are saved to
saved_scan_dir = urlscan_dir + '/saved_scans'

### urlscan's default local database
urlscan_default_db = urlscan_dir + '/urlscan.db'


### make sure the config directory exists before anything tries to use it;
### exist_ok avoids the race between checking for the directory and creating it
os.makedirs(urlscan_dir, exist_ok=True)


## argparse arguments
parser = argparse.ArgumentParser(description="Wrapper for urlscan.io's API")
subparsers = parser.add_subparsers(help='commands', dest='command')

## Init subparser
parser_init = subparsers.add_parser('init', help='initialize urlscan-py with API key')
parser_init.add_argument('--api', help='urlscan API key', metavar='KEY')
parser_init.add_argument('--db', help='specify different database file to search',
                         metavar='FILE', default=urlscan_default_db)

## Scan parser
parser_scan = subparsers.add_parser('scan', help='scan a url')
parser_scan.add_argument('--url', help='URL(s) to scan', nargs='+', metavar='URL')
parser_scan.add_argument('--db', help='specify different database file initiated scans will be saved to',
                         metavar='FILE', default=urlscan_default_db)
parser_scan.add_argument('-p', '--public', help='submit a public urlscan scan', action="store_true")
parser_scan.add_argument('-f', '--file', help='file with url(s) to scan')
parser_scan.add_argument('-q', '--quiet', help='suppress output', action="store_true")
parser_scan.add_argument('--api', help='urlscan API key', metavar='KEY')

## Search parser
## NOTE: `required` takes a real boolean; a string only works by being truthy
## (the string 'False' would make the option required as well).
parser_search = subparsers.add_parser('search', help='search database for UUID of url')
parser_search.add_argument('--url', help='url(s) to search for matching UUID', nargs='+',
                           metavar='<URL>, <all>', required=True)
parser_search.add_argument('--db', help='specify different database file to search',
                           metavar='FILE', default=urlscan_default_db)
parser_search.add_argument('--web', help='search urlscan.io for URL (public)', action="store_true")

## Retrieve parser
parser_retrieve = subparsers.add_parser('retrieve', help='retrieve scan results')
parser_retrieve.add_argument('--uuid', help='UUID(s) to retrieve scans for', nargs='+',
                             metavar='UUID', required=True)
parser_retrieve.add_argument('--db', help='specify different database file to query',
                             metavar='FILE', default=urlscan_default_db)
parser_retrieve.add_argument('--api', help='urlscan API key', metavar='KEY')
parser_retrieve.add_argument('-s', '--summary', help='print summary of result', action="store_true")
parser_retrieve.add_argument('-d', '--dir', help='directory to save scans to',
                             metavar='DIRECTORY', default=saved_scan_dir)
parser_retrieve.add_argument('--dom', help='save dom file from retrieved result', action="store_true")
parser_retrieve.add_argument('--png', help='save screenshot as png', action="store_true")
parser_retrieve.add_argument('-q', '--quiet', help='suppress output', action="store_true")

## Parsed once at import time; every function below reads this module-level namespace
args = parser.parse_args()


def connect_db():
    """Open the SQLite database named by ``args.db``.

    Rebinds the module-level ``conn`` (connection) and ``c`` (cursor) that
    the other database helpers operate on. A previously bound connection is
    replaced, not closed.
    """
    global conn, c
    conn = sqlite3.connect(args.db)
    c = conn.cursor()


def add_key_value():
    """Store the urlscan.io API key in the local database, then exit.

    The key is taken from ``args.api`` when present; otherwise the user is
    prompted for it. The value is also published as the module-level
    ``urlscan_api``.

    This function does not return: once the key is committed the process
    exits with status 0.
    """
    global urlscan_api
    connect_db()
    # IF NOT EXISTS replaces a blanket "except OperationalError: pass", which
    # would also have hidden unrelated database errors.
    c.execute('''CREATE TABLE IF NOT EXISTS api (key TEXT PRIMARY KEY)''')
    if args.api:
        urlscan_api = args.api
    else:
        urlscan_api = input('Please enter API key: ')
    # `key` is the primary key, so INSERT OR REPLACE only replaces a row that
    # holds the *same* key. Clear the table first so the new key is the only
    # one stored and later lookups cannot pick up a stale key.
    c.execute("DELETE FROM api")
    c.execute("INSERT OR REPLACE INTO api(key) VALUES (?)", (urlscan_api,))
    conn.commit()
    conn.close()
    sys.exit(0)


def get_key_value():
    """Load the stored API key from the database into ``urlscan_api``.

    If the lookup query fails with ``sqlite3.OperationalError`` (most likely
    because the ``api`` table has not been created yet), ``add_key_value()``
    is called to store a key. If no usable row is found, an error is logged
    and the process exits with status 1.
    """
    global urlscan_api
    connect_db()
    try:
        c.execute("SELECT * FROM api")
    except sqlite3.OperationalError:
        add_key_value()
    stored_row = c.fetchone()
    try:
        # fetchone() yields None for an empty table, which join() rejects
        urlscan_api = ''.join(stored_row)
    except TypeError:
        logging.error('Invalid API entry in database.')
        sys.exit(1)

def initialize():
    """Make the API key available, or run the interactive ``init`` flow.

    For every command except ``init`` the key comes from ``args.api`` when
    given, otherwise from the database. For ``init`` the stored key is looked
    up, the user is asked whether to overwrite it, and the process then exits
    with status 0.
    """
    global urlscan_api
    if args.command != 'init':
        if not args.api:
            get_key_value()
        else:
            urlscan_api = args.api
        return

    try:
        get_key_value()
        if input('API already exists in database. Overwrite? (y/n)') == 'y':
            add_key_value()
    except sqlite3.OperationalError:
        add_key_value()
    sys.exit(0)


def submit():
    """Submit every requested URL to urlscan.io for scanning.

    URLs are read from ``args.file`` (one per line, blank lines ignored) when
    it is given, otherwise from ``args.url``. Each raw API response is printed
    unless ``args.quiet`` is set, and recorded through ``save_history`` when
    ``args.db`` is set. The function pauses three seconds after each
    submission.

    Exits with status 1 when there is no URL to submit.
    """
    if args.file:
        # The context manager closes the file; strip() also drops a trailing
        # '\r' so files with Windows line endings work.
        with open(args.file) as url_file:
            urls_to_scan = [line.strip() for line in url_file if line.strip()]
    else:
        urls_to_scan = args.url

    if not urls_to_scan:
        logging.error('No URL to scan. Use --url or --file.')
        sys.exit(1)

    # Identical for every request, so built once outside the loop
    headers = {
        'Content-Type': 'application/json',
        'API-Key': urlscan_api,
    }

    for target_url in urls_to_scan:
        payload = {"url": target_url}
        if args.public:
            payload["public"] = "on"

        # json.dumps escapes quotes and backslashes in the URL, which plain
        # string interpolation would let through and corrupt the request body.
        response = requests.post('https://urlscan.io/api/v1/scan/',
                                 headers=headers, data=json.dumps(payload))

        r = response.content.decode("utf-8")

        if not args.quiet:
            print(r)

        if args.db:
            save_history(target_url, r)

        time.sleep(3)


def remote_search(url):
    """Query the urlscan.io search API for ``url`` and print the raw reply.

    Anything up to and including the first ``://`` is dropped before the
    value is sent as a ``domain:`` query. Sleeps two seconds afterwards.
    """
    scheme_separator = "://"
    if scheme_separator in url:
        url = url.split(scheme_separator)[1]

    query_params = {'q': 'domain:%s' % url}
    response = requests.get('https://urlscan.io/api/v1/search/', params=query_params)
    print(response.content.decode("utf-8"))
    time.sleep(2)


def _print_scan_rows(rows):
    """Print one ``date || url: uuid`` line for each complete history row."""
    for scan_url, scan_uuid, scan_date in rows:
        # Skip rows with any empty column; concatenating None would raise.
        if scan_url and scan_uuid and scan_date:
            print(scan_date + ' || ' + scan_url + ': ' + scan_uuid)


def search(search_urls, remote_true):
    """Look up previously scanned URLs.

    Args:
        search_urls: URLs (or fragments of URLs) to look for. The special
            value ``'all'`` lists the entire local scan history.
        remote_true: when true, each URL is searched on urlscan.io through
            ``remote_search`` instead of in the local database.

    Exits with status 1 if the local history cannot be queried (for example
    when no scan has been recorded yet and the table does not exist).
    """
    connect_db()
    try:
        for url in search_urls:
            if remote_true:
                remote_search(url)
                continue
            try:
                if url == 'all':
                    c.execute('SELECT url, uuid, datetime FROM scanned_urls')
                else:
                    c.execute("SELECT url, uuid, datetime FROM scanned_urls WHERE url LIKE ?",
                              ('%' + url + '%',))
            except sqlite3.OperationalError as err:
                logging.error('Unable to search scan history: %s', err)
                sys.exit(1)
            _print_scan_rows(c.fetchall())
    finally:
        # Always release the connection, including on the early exit above
        conn.close()


    

def download_dom(target_uuid, target_dir, save_template):
    """Download the DOM snapshot of a scan to ``<save_template>.dom``.

    Args:
        target_uuid: UUID of the scan whose DOM is fetched.
        target_dir: directory that must exist before the file is written;
            it is created (with parents) when missing.
        save_template: destination path without extension.

    An existing destination file is overwritten.
    """
    dom_url = 'https://urlscan.io/dom/' + target_uuid + '/'
    # exist_ok replaces the manual try/except around directory creation
    os.makedirs(target_dir, exist_ok=True)
    # urlretrieve opens its target for writing and overwrites it, so there is
    # no FileExistsError to guard against here.
    urllib.request.urlretrieve(dom_url, save_template + '.dom')
    

def download_png(target_uuid, target_dir, save_template):
    """Download the screenshot of a scan to ``<save_template>.png``.

    Args:
        target_uuid: UUID of the scan whose screenshot is fetched.
        target_dir: directory that must exist before the file is written;
            it is created (with parents) when missing.
        save_template: destination path without extension.

    An existing destination file is overwritten.
    """
    png_url = 'https://urlscan.io/screenshots/' + target_uuid + '.png'
    # exist_ok replaces the manual try/except around directory creation
    os.makedirs(target_dir, exist_ok=True)
    # urlretrieve opens its target for writing and overwrites it, so there is
    # no FileExistsError to guard against here.
    urllib.request.urlretrieve(png_url, save_template + '.png')

def print_summary(content):
    """Print a short human-readable summary of a urlscan.io result.

    Args:
        content: the decoded result JSON (a dict). The sections read are
            ``page``, ``stats``, ``verdicts``, ``meta.processors.wappa`` and
            ``data.requests``. Any of them may be missing or null; the
            corresponding values are then shown as ``None`` or as an empty
            list instead of raising.

    Nothing is printed when the result carries no page IP address.
    """
    page_info = content.get("page") or {}
    page_ip = page_info.get("ip")
    if page_ip is None:
        return

    ### relevant aggregate data
    stats_info = content.get("stats") or {}
    verdict_info = content.get("verdicts") or {}
    processors = (content.get("meta") or {}).get("processors") or {}
    request_info = (content.get("data") or {}).get("requests") or []

    ### enumerate web apps
    web_apps_info = processors.get("wappa") or {}
    web_apps = [app.get("app") for app in web_apps_info.get("data") or []]

    ### enumerate domains pointing to ip (first-seen order, no duplicates)
    pointed_domains = []
    seen_domains = set()
    for ip in stats_info.get("ipStats") or []:
        for domain in ip.get("domains") or []:
            if domain not in seen_domains:
                seen_domains.add(domain)
                pointed_domains.append(domain)

    ### data for summary, in print order
    summary = [
        ("Domain", page_info.get("domain")),
        ("IP Address", page_ip),
        ("Country", page_info.get("country")),
        ("Server", page_info.get("server")),
        ("Web Apps", web_apps),
        ("Number of Requests", len(request_info)),
        ("Ads Blocked", stats_info.get("adBlocked")),
        ("HTTPS Requests", str(stats_info.get("securePercentage")) + "%"),
        ("IPv6", str(stats_info.get("IPv6Percentage")) + "%"),
        ("Unique Country Count", stats_info.get("uniqCountries")),
        ("Malicious", (verdict_info.get("overall") or {}).get("malicious")),
        ("Malicious Requests", (verdict_info.get("engines") or {}).get("maliciousTotal")),
        ("Pointed Domains", pointed_domains),
    ]

    ### print data; format() stringifies every value, so a missing (None)
    ### country or server no longer breaks string concatenation
    for label, value in summary:
        print("{}: {}".format(label, value))



def query(uuid):
    """Retrieve, display and save the scan result for each UUID.

    Args:
        uuid: iterable of scan UUIDs to fetch from the urlscan.io result API.

    For every UUID the raw JSON is printed (unless ``args.quiet`` or
    ``args.summary`` is set) and saved under ``args.dir/<url>/``. The DOM,
    the screenshot and a summary are handled as well when ``args.dom``,
    ``args.png`` and ``args.summary`` ask for them. The function pauses three
    seconds after each result.

    Exits with status 5 as soon as a result is not available yet.
    """
    null_response_string = '"status": 404'

    for target_uuid in uuid:
        response = requests.get("https://urlscan.io/api/v1/result/%s" % target_uuid)
        r = response.content.decode("utf-8")

        # The HTTP status is the reliable "not ready" signal; the body text
        # match is kept as a fallback because it depends on JSON formatting.
        if response.status_code == 404 or null_response_string in r:
            print('Results not processed. Please check again later.')
            sys.exit(5)

        if not (args.quiet or args.summary):
            print(r)

        # Decode the (possibly large) JSON document once and reuse it
        result = response.json()
        task = result.get("task")
        url = task.get("url").split("://")[1]
        submission_time = task.get("time")
        target_dir = args.dir + '/' + url + '/'

        save_template = target_dir + submission_time + '_' + target_uuid

        save_to_dir(target_dir, save_template, r)

        if args.dom:
            download_dom(target_uuid, target_dir, save_template)
        if args.png:
            download_png(target_uuid, target_dir, save_template)
        if args.summary:
            print_summary(result)

        time.sleep(3)

def save_history(target_url, r):
    """Record a submitted scan (URL, UUID, timestamp) in the local database.

    Args:
        target_url: the URL that was submitted.
        r: the submission API response body, as JSON text.

    When the response carries no ``uuid`` (for example an error reply) a
    warning is logged and nothing is written.
    """
    ### extract UUID from json
    # Parsing the JSON works regardless of how the response is laid out;
    # scraping it line by line crashed on replies without a uuid.
    try:
        uuid = json.loads(r).get("uuid")
    except (ValueError, AttributeError):
        # not JSON at all, or JSON that is not an object
        uuid = None
    if not uuid:
        logging.warning('No scan UUID in response for %s; nothing saved.', target_url)
        return
    uuid = str(uuid)
    ### end UUID extraction

    target_url = str(target_url)
    # truncated to whole seconds before formatting
    human_readable_time = str(datetime.datetime.fromtimestamp(int(time.time())))
    connect_db()
    try:
        c.execute('''CREATE TABLE IF NOT EXISTS scanned_urls (url, uuid, datetime TEXT PRIMARY KEY)''')
        c.execute("INSERT OR REPLACE INTO scanned_urls VALUES (?, ?, ?)", (target_url, uuid, human_readable_time))
        conn.commit()
    finally:
        # release the connection even if one of the statements fails
        conn.close()


def save_to_dir(target_dir, save_template, r):
    """Write a scan result to ``<save_template>.json`` unless it already exists.

    Args:
        target_dir: directory that must exist before writing; it is created
            (with parents) when missing.
        save_template: destination path without extension.
        r: the text to store.

    An existing result file is left untouched.
    """
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(target_dir, exist_ok=True)

    save_file_name = save_template + '.json'
    if pathlib.Path(save_file_name).is_file():
        return

    # Explicit UTF-8 keeps non-ASCII page content from failing under a
    # non-UTF-8 default locale encoding.
    with open(save_file_name, 'w', encoding='utf-8') as out:
        out.write(r)

    

def main():
    """Dispatch to the handler for the sub-command given on the command line.

    ``init``, ``scan`` and ``retrieve`` go through ``initialize()`` first;
    ``search`` does not. When no sub-command is given the usage help is
    printed instead of silently doing nothing.
    """
    command = args.command

    if command is None:
        parser.print_help()
        return

    if command == 'init':
        initialize()
    elif command == 'scan':
        initialize()
        submit()
    elif command == 'search':
        search(args.url, args.web)
    elif command == 'retrieve':
        initialize()
        query(args.uuid)



# Run the command dispatcher only when this file is executed as a script.
if __name__ == '__main__':
    main()
