import os
import shutil
import sys
import zipfile

import gdown
from tqdm import tqdm

from vilmedic.constants import DATA_DIR

# Downloadable data folders, keyed by the name shown in the selection menu.
# Each value is [Google Drive id, human-readable size]. The id is passed to
# gdown.download_folder by download_data, and the size string is only used for
# display when listing the available downloads.
DATA_ZOO = {
    "RRS": ["1O7UebL2bcVF-vmoY2yyuoBAlG7M5KAKG", "15 Mb"],
    "RRG": ["1xUqztfFJz9pPYCVFJNSqkVDsUWyFqeId", "15 Mb"],
    "CLIP": ["1QdCGl8u5Q1__GezW_zHpwkHNvHkWbce5", "40 Mb"],
    "SELFSUP": ["1es-SrZaUdyyKf-1XnoirQ_Q0nZtQXtv0", "676.7 MB"],
    "MVQA": ["1YOr9fmc-zm8YwhWJIgjkjmF2Y0G_MSsV", "1 MB"],
}

# Downloadable image archives, keyed by the name shown in the selection menu.
# Each value is [Google Drive id, human-readable size]. The id is downloaded as
# a single zip file by download_images.
# NOTE: the script entry point routes a key to download_images only when the
# key contains the substring 'images', so every key here must contain it.
IMAGE_ZOO = {
    "indiana-images-512": ["1Cv9u3ZoOb8hZVcObMqHRNxsoAdZpYGt4", "1.27 GB"],
    "imageclef-vqa-images-512": ["1ByPQ2TdXGz17pYDJKhrV05hka_XYH_pw", "327 MB"],
}


def download_images(data_name, file_id, unzip_dir):
    """Download a zipped image archive from Google Drive and extract it.

    The archive is fetched to ``data/<data_name>.zip`` (relative to the current
    working directory), extracted into ``unzip_dir`` and then deleted. Nothing
    is downloaded when ``<unzip_dir>/<data_name>`` already exists.

    Args:
        data_name: Name of the dataset. Used for the temporary zip file name
            and for the directory whose presence marks the dataset as already
            downloaded (presumably the archive's top-level folder; verify
            against the archive contents).
        file_id: Google Drive id of the zip file.
        unzip_dir: Directory the archive is extracted into.

    Raises:
        FileNotFoundError: If the download did not produce the zip file.
    """
    zip_name = os.path.join("data", data_name) + ".zip"
    target_dir = os.path.join(unzip_dir, data_name)

    if os.path.exists(target_dir):
        print('{} already exists'.format(target_dir))
        return

    # Make sure the directory receiving the temporary zip exists before the
    # download starts writing into it.
    os.makedirs(os.path.dirname(zip_name), exist_ok=True)

    try:
        downloaded = gdown.download(url="https://drive.google.com/uc?id=" + file_id,
                                    output=zip_name,
                                    quiet=False)
        # Checking the return value as well as the file guards against
        # extracting a stale zip left behind by an earlier, interrupted run.
        if downloaded is None or not os.path.exists(zip_name):
            raise FileNotFoundError(
                "Download of {} (id {}) failed: {} was not created".format(
                    data_name, file_id, zip_name))

        print("Unzipping...")
        try:
            with zipfile.ZipFile(zip_name, 'r') as zf:
                for member in tqdm(zf.infolist(), desc='Extracting '):
                    zf.extract(member, unzip_dir)
        except BaseException:
            # target_dir did not exist on entry, so anything there now is a
            # partial extraction. Remove it, otherwise the next run would
            # report the dataset as already present.
            shutil.rmtree(target_dir, ignore_errors=True)
            raise
    finally:
        # Never leave the (possibly large) temporary archive behind.
        if os.path.exists(zip_name):
            os.remove(zip_name)


def download_data(file_id, unzip_dir):
    """Download a Google Drive folder into ``unzip_dir``.

    Nothing is downloaded when ``unzip_dir`` already exists. If the download
    fails or is interrupted, the freshly created directory is removed again so
    that a later call retries instead of treating the data as present.

    Args:
        file_id: Google Drive id of the folder to download.
        unzip_dir: Destination directory; it is created by this function.

    Raises:
        RuntimeError: If gdown reports the folder download as failed.
    """
    if os.path.exists(unzip_dir):
        print('{} already exists'.format(unzip_dir))
        return

    os.makedirs(unzip_dir, exist_ok=True)
    try:
        files = gdown.download_folder(id=file_id,
                                      output=unzip_dir,
                                      quiet=False)
        # gdown signals a failed folder download by returning None.
        if files is None:
            raise RuntimeError(
                "Download of folder {} into {} failed".format(file_id, unzip_dir))
    except BaseException:
        # unzip_dir did not exist on entry, so it only holds partial results.
        shutil.rmtree(unzip_dir, ignore_errors=True)
        raise


if __name__ == '__main__':
    # Command-line entry point.
    #   - With exactly one argument: a comma-separated list of download names
    #     (keys of IMAGE_ZOO / DATA_ZOO).
    #   - Otherwise: an interactive menu where downloads are picked by number.
    ALL_ZOO = {**IMAGE_ZOO, **DATA_ZOO}
    list_files = list(ALL_ZOO.keys())

    from_argv = len(sys.argv) == 2
    if from_argv:
        res = sys.argv[1]
    else:
        for i, k in enumerate(list_files):
            print("{}. {} ({})".format(i + 1, k, ALL_ZOO[k][1]))

        res = input(
            "\nEnter the file number (1 or 2 for eg.) to download, or multiple numbers separated by a comma (1,3 for eg.):")

    # split(',') also handles the single-item case; blank entries (such as the
    # one produced by a trailing comma) are dropped.
    res = [r.strip() for r in res.split(',') if r.strip()]
    if not res:
        sys.exit("No download selected")

    # if we come from argv
    if from_argv:
        # Names are translated to the same 1-based numbers the menu uses.
        try:
            res = [list_files.index(r) + 1 for r in res]
        except ValueError as e:
            sys.exit("{} of available downloads".format(e))
    else:
        try:
            res = [int(r) for r in res]
        except ValueError:
            sys.exit("Numbers must be between 1 and {}".format(len(list_files)))

    # Explicit check rather than assert, which is stripped under `python -O`.
    if not all(0 < r <= len(list_files) for r in res):
        sys.exit("Numbers must be between 1 and {}".format(len(list_files)))

    print("Selected downloads:")
    for r in res:
        print("\t{} [{}] in {}".format(list_files[r - 1],
                                       ALL_ZOO[list_files[r - 1]][1],
                                       DATA_DIR))

    for r in res:
        key = list_files[r - 1]
        file_id = ALL_ZOO[key][0]
        # Image archives are single zip files; everything else is a Drive folder.
        if 'images' in key:
            download_images(data_name=key, file_id=file_id, unzip_dir=os.path.join(DATA_DIR, "images"))
        else:
            download_data(file_id=file_id, unzip_dir=os.path.join(DATA_DIR, key))

    print("done")
