#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2020

"""
iDDS CLI
"""

from __future__ import print_function

import argparse
import argcomplete
# import json
import logging
import os
import sys
import time

from idds.client.version import release_version
from idds.client.clientmanager import ClientManager


def get_requests_status(args):
    """
    Handler of the 'get_requests_status' sub-command.

    Creates a ClientManager for ``args.host`` and calls its ``get_status``
    with the request id, the workload id and the ``with_detail`` flag taken
    from the parsed arguments. Whatever ``get_status`` returns is discarded.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    query = {
        'request_id': args.request_id,
        'workload_id': args.workload_id,
        'with_detail': args.with_detail,
    }
    client.get_status(**query)


def abort_requests(args):
    """
    Handler of the 'abort_requests' sub-command.

    Creates a ClientManager for ``args.host`` and calls its ``abort`` with
    the request id and the workload id taken from the parsed arguments.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    selection = {'request_id': args.request_id, 'workload_id': args.workload_id}
    client.abort(**selection)


def suspend_requests(args):
    """
    Handler of the 'suspend_requests' sub-command.

    Creates a ClientManager for ``args.host`` and calls its ``suspend`` with
    the request id and the workload id taken from the parsed arguments.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    selection = {'request_id': args.request_id, 'workload_id': args.workload_id}
    client.suspend(**selection)


def resume_requests(args):
    """
    Handler of the 'resume_requests' sub-command.

    Creates a ClientManager for ``args.host`` and calls its ``resume`` with
    the request id and the workload id taken from the parsed arguments.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    selection = {'request_id': args.request_id, 'workload_id': args.workload_id}
    client.resume(**selection)


def download_logs(args):
    """
    Handler of the 'download_logs' sub-command.

    Creates a ClientManager for ``args.host`` and calls its ``download_logs``
    with the request id, the workload id, the destination directory and the
    destination filename (``args.dest_filename``, passed as ``filename``).

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    client.download_logs(
        request_id=args.request_id,
        workload_id=args.workload_id,
        dest_dir=args.dest_dir,
        filename=args.dest_filename,
    )


def upload_to_cacher(args):
    """
    Handler of the 'upload_to_cacher' sub-command.

    Creates a ClientManager for ``args.host`` and hands ``args.filename`` to
    its ``upload_to_cacher`` method.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    source_file = args.filename
    client.upload_to_cacher(source_file)


def download_from_cacher(args):
    """
    Handler of the 'download_from_cacher' sub-command.

    Creates a ClientManager for ``args.host`` and hands ``args.filename`` to
    its ``download_from_cacher`` method.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    target_file = args.filename
    client.download_from_cacher(target_file)


def get_hyperparameters(args):
    """
    Handler of the 'get_hyperparameters' sub-command.

    Creates a ClientManager for ``args.host``, calls its
    ``get_hyperparameters`` with the workload id, request id, hyperparameter
    id, status and limit taken from the parsed arguments, and prints every
    item of the returned iterable on its own line.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    hyperparameters = client.get_hyperparameters(
        workload_id=args.workload_id,
        request_id=args.request_id,
        id=args.id,
        status=args.status,
        limit=args.limit,
    )
    for hyperparameter in hyperparameters:
        print(hyperparameter)


def update_hyperparameter(args):
    """
    Handler of the 'update_hyperparameter' sub-command.

    Creates a ClientManager for ``args.host``, calls its
    ``update_hyperparameter`` with the workload id, request id,
    hyperparameter id and loss taken from the parsed arguments, and prints
    the value that call returns.

    :param args: The parsed command line arguments.
    """
    client = ClientManager(host=args.host)
    outcome = client.update_hyperparameter(
        workload_id=args.workload_id,
        request_id=args.request_id,
        id=args.id,
        loss=args.loss,
    )
    print(outcome)


def get_parser():
    """
    Build the argparse parser of the iDDS command line client.

    The top-level parser carries the global options (--version, --config,
    --verbose/-v and -H/--host). Every sub-command stores the function that
    serves it in the ``function`` attribute of the parsed namespace.

    :returns: The fully populated top-level argparse.ArgumentParser.
    """
    top_parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), add_help=True)
    commands = top_parser.add_subparsers()

    def new_command(name, handler, summary, with_request_ids=True):
        # Register one sub-command, bind its handler and, unless disabled,
        # give it the --request_id/--workload_id pair shared by most commands.
        command = commands.add_parser(name, help=summary)
        command.set_defaults(function=handler)
        if with_request_ids:
            command.add_argument('--request_id', dest='request_id', action='store', type=int, help='The request id')
            command.add_argument('--workload_id', dest='workload_id', action='store', type=int, help='The workload id')
        return command

    # options understood by the top-level parser
    top_parser.add_argument('--version', action='version', version='%(prog)s ' + release_version)
    top_parser.add_argument('--config', dest="config", help="The iDDS configuration file to use.")
    top_parser.add_argument('--verbose', '-v', default=False, action='store_true', help="Print more verbose output.")
    top_parser.add_argument('-H', '--host', dest="host", metavar="ADDRESS", help="The iDDS Rest host. For example: https://iddsserver.cern.ch:443/idds")

    # request status
    status_cmd = new_command('get_requests_status', get_requests_status, 'Get the requests status')
    status_cmd.add_argument('--with_detail', dest='with_detail', default=False, action='store_true', help='To show detail status')

    # request life cycle: abort / suspend / resume
    new_command('abort_requests', abort_requests, 'Abort requests')
    new_command('suspend_requests', suspend_requests, 'Suspend requests')
    new_command('resume_requests', resume_requests, 'Resume requests')

    # log download
    logs_cmd = new_command('download_logs', download_logs, 'Download logs')
    logs_cmd.add_argument('--dest_dir', dest='dest_dir', action='store', default='./', help='The destination directory')
    logs_cmd.add_argument('--dest_filename', dest='dest_filename', action='store', default=None, help='The destination filename')

    # cacher upload
    upload_cmd = new_command('upload_to_cacher', upload_to_cacher, 'Upload a file to the iDDS cacher on the server', with_request_ids=False)
    upload_cmd.add_argument('--filename', dest='filename', action='store', default=None, help='The source filename. The destination filename on the server will be the base name of the file')

    # cacher download
    fetch_cmd = new_command('download_from_cacher', download_from_cacher, 'Download a file from the iDDS cacher on the server', with_request_ids=False)
    fetch_cmd.add_argument('--filename', dest='filename', action='store', default=None, help='The destination filename. The source filename on the server will be the base name of the file')

    # hyperparameter retrieval
    hp_list_cmd = new_command('get_hyperparameters', get_hyperparameters, 'Get hyperparameters')
    hp_list_cmd.add_argument('--id', dest='id', action='store', type=int, help='The id of the hyperparameter')
    hp_list_cmd.add_argument('--status', dest='status', action='store', help='Retrieve hyperparameters with defined status')
    hp_list_cmd.add_argument('--limit', dest='limit', action='store', type=int, help='Limit number of hyperparameters')

    # hyperparameter result update
    hp_set_cmd = new_command('update_hyperparameter', update_hyperparameter, 'Update the hyperparameter result')
    hp_set_cmd.add_argument('--id', dest='id', action='store', type=int, help='The id of the hyperparameter')
    hp_set_cmd.add_argument('--loss', dest='loss', action='store', type=float, help='The loss result to be updated')

    return top_parser


if __name__ == '__main__':
    # Script entry point: export the configuration file, parse the command
    # line, run the selected sub-command and map the outcome to an exit code
    # (0 on success, -1 on failure or when no sub-command was selected).
    arguments = sys.argv[1:]
    # set the configuration before anything else, if the config parameter is present
    # (both the "--config PATH" and the "--config=PATH" spellings accepted by argparse)
    for position, argument in enumerate(arguments):
        if argument == '--config' and (position + 1) < len(arguments):
            os.environ['IDDS_CONFIG'] = arguments[position + 1]
        elif argument.startswith('--config='):
            os.environ['IDDS_CONFIG'] = argument[len('--config='):]

    oparser = get_parser()
    argcomplete.autocomplete(oparser)

    if not arguments:
        oparser.print_help()
        sys.exit(-1)

    args = oparser.parse_args(arguments)

    # On Python 3 sub-commands are optional by default, so a command line with
    # only global options parses successfully but selects no handler.
    if not hasattr(args, 'function'):
        oparser.print_help()
        sys.exit(-1)

    try:
        if args.verbose:
            # The logging module has no module-level setLevel(); the level has
            # to be set on the root logger.
            logging.getLogger().setLevel(logging.DEBUG)
        start_time = time.time()
        args.function(args)
        end_time = time.time()
        if args.verbose:
            print("Completed in %-0.4f sec." % (end_time - start_time))
        # SystemExit is not an Exception subclass, so it is not caught below.
        sys.exit(0)
    except Exception as error:
        logging.error("Strange error: {0}".format(error))
        sys.exit(-1)
