#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2020 - 2022

"""
iDDS CLI
"""

from __future__ import print_function

import argparse
import argcomplete
# import json
import logging
import os
import sys
import time
# import traceback

from idds.client.version import release_version
from idds.client.clientmanager import ClientManager


def setup(args):
    """
    Set up the local client configuration from the parsed options.

    :param args: parsed command line options; reads host, local_config_root,
                 config, auth_type, auth_type_host, x509_proxy, vo and
                 oidc_token.
    :returns: the ClientManager instance the configuration was applied to.
    """
    manager = ClientManager(host=args.host, setup_client=False)
    # Collect the settings first so the call below stays readable.
    settings = {
        'local_config_root': args.local_config_root,
        'config': args.config,
        'host': args.host,
        'auth_type': args.auth_type,
        'auth_type_host': args.auth_type_host,
        'x509_proxy': args.x509_proxy,
        'vo': args.vo,
        'oidc_token': args.oidc_token
    }
    manager.setup_local_configuration(**settings)
    return manager


def setup_oidc_token(args):
    """
    Run ClientManager.setup_oidc_token() for the selected host.

    :param args: parsed command line options; only ``host`` is read.
    """
    # Unlike the other token commands, this one builds the manager with
    # setup_client=True.
    ClientManager(host=args.host, setup_client=True).setup_oidc_token()


def clean_oidc_token(args):
    """
    Run ClientManager.clean_oidc_token() for the selected host.

    :param args: parsed command line options; only ``host`` is read.
    """
    ClientManager(host=args.host, setup_client=False).clean_oidc_token()


def check_oidc_token_status(args):
    """
    Run ClientManager.check_oidc_token_status() for the selected host.

    :param args: parsed command line options; only ``host`` is read.
    """
    ClientManager(host=args.host, setup_client=False).check_oidc_token_status()


def refresh_oidc_token(args):
    """
    Run ClientManager.refresh_oidc_token() for the selected host.

    :param args: parsed command line options; only ``host`` is read.
    """
    ClientManager(host=args.host, setup_client=False).refresh_oidc_token()


def ping(args):
    """
    Ping the server and print whatever ClientManager.ping() returns.

    :param args: parsed command line options; only ``host`` is read.
    """
    client = ClientManager(host=args.host, setup_client=False)
    print(client.ping())


def get_requests_status(args):
    """
    Print the result of ClientManager.get_status() for the selected requests.

    :param args: parsed command line options; reads host, request_id,
                 workload_id and with_detail.
    """
    client = ClientManager(host=args.host, setup_client=True)
    status = client.get_status(request_id=args.request_id,
                               workload_id=args.workload_id,
                               with_detail=args.with_detail)
    print(status)


def abort_requests(args):
    """
    Pass the selected request/workload ids to ClientManager.abort().

    :param args: parsed command line options; reads host, request_id and
                 workload_id.
    """
    client = ClientManager(host=args.host, setup_client=True)
    client.abort(request_id=args.request_id,
                 workload_id=args.workload_id)


def suspend_requests(args):
    """
    Pass the selected request/workload ids to ClientManager.suspend().

    :param args: parsed command line options; reads host, request_id and
                 workload_id.
    """
    client = ClientManager(host=args.host, setup_client=True)
    client.suspend(request_id=args.request_id,
                   workload_id=args.workload_id)


def resume_requests(args):
    """
    Pass the selected request/workload ids to ClientManager.resume().

    :param args: parsed command line options; reads host, request_id and
                 workload_id.
    """
    client = ClientManager(host=args.host, setup_client=True)
    client.resume(request_id=args.request_id,
                  workload_id=args.workload_id)


def retry_requests(args):
    """
    Pass the selected request/workload ids to ClientManager.retry().

    :param args: parsed command line options; reads host, request_id and
                 workload_id.
    """
    client = ClientManager(host=args.host, setup_client=True)
    client.retry(request_id=args.request_id,
                 workload_id=args.workload_id)


def finish_requests(args):
    """
    Pass the selected request/workload ids to ClientManager.finish().

    :param args: parsed command line options; reads host, request_id,
                 workload_id and set_all_finished.
    """
    client = ClientManager(host=args.host, setup_client=True)
    client.finish(request_id=args.request_id,
                  workload_id=args.workload_id,
                  set_all_finished=args.set_all_finished)


def download_logs(args):
    """
    Call ClientManager.download_logs() for the selected request/workload.

    :param args: parsed command line options; reads host, request_id,
                 workload_id, dest_dir and dest_filename (forwarded as the
                 ``filename`` keyword).
    """
    client = ClientManager(host=args.host, setup_client=True)
    client.download_logs(request_id=args.request_id,
                         workload_id=args.workload_id,
                         dest_dir=args.dest_dir,
                         filename=args.dest_filename)


def upload_to_cacher(args):
    """
    Hand ``args.filename`` to ClientManager.upload_to_cacher().

    :param args: parsed command line options; reads host and filename.
    """
    client = ClientManager(host=args.host, setup_client=True)
    source_file = args.filename
    client.upload_to_cacher(source_file)


def download_from_cacher(args):
    """
    Hand ``args.filename`` to ClientManager.download_from_cacher().

    :param args: parsed command line options; reads host and filename.
    """
    client = ClientManager(host=args.host, setup_client=True)
    target_file = args.filename
    client.download_from_cacher(target_file)


def get_hyperparameters(args):
    """
    Print each item returned by ClientManager.get_hyperparameters().

    :param args: parsed command line options; reads host, workload_id,
                 request_id, id, status and limit.
    """
    client = ClientManager(host=args.host, setup_client=True)
    hyperparameters = client.get_hyperparameters(workload_id=args.workload_id,
                                                 request_id=args.request_id,
                                                 id=args.id,
                                                 status=args.status,
                                                 limit=args.limit)
    # print(json.dumps(hyperparameters, sort_keys=True, indent=4))
    # One item per output line.
    for hyperparameter in hyperparameters:
        print(hyperparameter)


def update_hyperparameter(args):
    """
    Print the result of ClientManager.update_hyperparameter().

    :param args: parsed command line options; reads host, workload_id,
                 request_id, id and loss.
    """
    client = ClientManager(host=args.host, setup_client=True)
    outcome = client.update_hyperparameter(workload_id=args.workload_id,
                                           request_id=args.request_id,
                                           id=args.id,
                                           loss=args.loss)
    print(outcome)


def get_messages(args):
    """
    Print the status and messages returned by ClientManager.get_messages().

    The return value is unpacked as a ``(status, messages)`` pair.

    :param args: parsed command line options; reads host, request_id and
                 workload_id.
    """
    client = ClientManager(host=args.host, setup_client=True)
    status, messages = client.get_messages(request_id=args.request_id,
                                           workload_id=args.workload_id)
    print("status: %s" % status)
    print("messages: %s" % str(messages))


def _add_request_arguments(parser):
    """
    Add the --request_id and --workload_id options shared by most sub-commands.

    :param parser: the sub-command parser to extend.
    """
    parser.add_argument('--request_id', dest='request_id', action='store', type=int, help='The request id')
    parser.add_argument('--workload_id', dest='workload_id', action='store', type=int, help='The workload id')


def get_parser():
    """
    Return the argparse parser.

    Global options live on the top-level parser. Every sub-command stores
    its handler in the ``function`` default of its own sub-parser.

    :returns: the configured argparse.ArgumentParser.
    """
    oparser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), add_help=True)
    subparsers = oparser.add_subparsers()

    # common items
    oparser.add_argument('--version', action='version', version='%(prog)s ' + release_version)
    oparser.add_argument('--local_config_root', dest="local_config_root", default=None, help="The root path of local configurations. Default is ~/.idds/.")
    # setup() reads this value as args.config, so name the destination explicitly.
    oparser.add_argument('--config', dest="config", help="The iDDS configuration file to use. Default is ~/.idds/idds.cfg.")
    oparser.add_argument('--verbose', '-v', default=False, action='store_true', help="Print more verbose output.")
    oparser.add_argument('-H', '--host', dest="host", metavar="ADDRESS", help="The iDDS Rest host. For example: https://hostname:443/idds")

    # setup
    setup_parser = subparsers.add_parser('setup', help='Setup local configuration')
    setup_parser.set_defaults(function=setup)
    # This --host is stored separately (auth_type_host) from the global -H/--host.
    setup_parser.add_argument('--host', dest="auth_type_host", metavar="ADDRESS", help="The iDDS Rest host for the current auth type. For example: https://hostname:443/idds")
    setup_parser.add_argument('--auth_type', dest='auth_type', action='store', choices=['x509_proxy', 'oidc'], default=None, help='The auth_type in [x509_proxy, oidc]. Default is x509_proxy.')
    setup_parser.add_argument('--x509_proxy', dest='x509_proxy', action='store', default=None, help='The x509 proxy path. Default is /tmp/x509up_u%d.' % os.geteuid())
    setup_parser.add_argument('--vo', dest='vo', action='store', default=None, help='The virtual organization for authentication.')
    setup_parser.add_argument('--oidc_token', dest='oidc_token', default=None, help='The oidc token path. Default is {local_config_root}/.oidc_token.')

    # setup token
    token_setup_parser = subparsers.add_parser('setup_oidc_token', help='Setup authentication token')
    token_setup_parser.set_defaults(function=setup_oidc_token)
    # token_setup_parser.add_argument('--oidc_refresh_lifetime', dest='oidc_refresh_lifetime', default=None, help='The oidc refresh lifetime')
    # token_setup_parser.add_argument('--oidc_issuer', dest='oidc_issuer', default=None, help='The oidc issuer')
    # token_setup_parser.add_argument('--oidc_audience', dest='oidc_audience', default=None, help='The oidc audience')
    # token_setup_parser.add_argument('--oidc_token', dest='oidc_token', default=None, help='The oidc token path. Default is {local_config_root}/.oidc_token.')
    # token_setup_parser.add_argument('--oidc_auto', dest='oidc_auto', default=False, action='store_true', help='Get oidc token automatically, requiring oidc_username and oidc_password')
    # token_setup_parser.add_argument('--oidc_username', dest='oidc_username', default=None, help='The oidc username for getting oidc token, with --oidc_auto')
    # token_setup_parser.add_argument('--oidc_password', dest='oidc_password', default=None, help='The oidc password for getting oidc token, with --oidc_auto')
    # token_setup_parser.add_argument('--oidc_scope', dest='oidc_scope', default=None, help='The oidc scope. Default is openid profile.')
    # token_setup_parser.add_argument('--oidc_polling', dest='oidc_polling', default=False, help='whether polling oidc')
    # token_setup_parser.add_argument('--saml_username', dest='saml_username', default=None, help='The SAML username')
    # token_setup_parser.add_argument('--saml_password', dest='saml_password', default=None, help='The saml password')

    # clean token
    token_clean_parser = subparsers.add_parser('clean_oidc_token', help='Clean authentication token')
    token_clean_parser.set_defaults(function=clean_oidc_token)

    # check token status
    token_check_parser = subparsers.add_parser('get_oidc_token_info', help='Check authentication token information')
    token_check_parser.set_defaults(function=check_oidc_token_status)

    # refresh token
    token_refresh_parser = subparsers.add_parser('refresh_oidc_token', help='Refresh authentication token')
    token_refresh_parser.set_defaults(function=refresh_oidc_token)

    # ping
    ping_parser = subparsers.add_parser('ping', help='Ping idds server')
    ping_parser.set_defaults(function=ping)

    # get request status
    req_status_parser = subparsers.add_parser('get_requests_status', help='Get the requests status')
    req_status_parser.set_defaults(function=get_requests_status)
    _add_request_arguments(req_status_parser)
    req_status_parser.add_argument('--with_detail', dest='with_detail', default=False, action='store_true', help='To show detail status')

    # abort requests
    abort_parser = subparsers.add_parser('abort_requests', help='Abort requests')
    abort_parser.set_defaults(function=abort_requests)
    _add_request_arguments(abort_parser)

    # suspend requests
    suspend_parser = subparsers.add_parser('suspend_requests', help='Suspend requests')
    suspend_parser.set_defaults(function=suspend_requests)
    _add_request_arguments(suspend_parser)

    # resume requests
    resume_parser = subparsers.add_parser('resume_requests', help='Resume requests')
    resume_parser.set_defaults(function=resume_requests)
    _add_request_arguments(resume_parser)

    # retry requests
    retry_parser = subparsers.add_parser('retry_requests', help='Retry requests')
    retry_parser.set_defaults(function=retry_requests)
    _add_request_arguments(retry_parser)

    # finish requests
    finish_parser = subparsers.add_parser('finish_requests', help='Finish requests')
    finish_parser.set_defaults(function=finish_requests)
    _add_request_arguments(finish_parser)
    finish_parser.add_argument('--set_all_finished', default=False, action='store_true', help="Mark unfinished transformations as finished")

    # download logs
    log_parser = subparsers.add_parser('download_logs', help='Download logs')
    log_parser.set_defaults(function=download_logs)
    _add_request_arguments(log_parser)
    log_parser.add_argument('--dest_dir', dest='dest_dir', action='store', default='./', help='The destination directory')
    log_parser.add_argument('--dest_filename', dest='dest_filename', action='store', default=None, help='The destination filename')

    # upload a file to the cacher
    upload_parser = subparsers.add_parser('upload_to_cacher', help='Upload a file to the iDDS cacher on the server')
    upload_parser.set_defaults(function=upload_to_cacher)
    upload_parser.add_argument('--filename', dest='filename', action='store', default=None, help='The source filename. The destination filename on the server will be the base name of the file')

    # download a file from the cacher
    download_parser = subparsers.add_parser('download_from_cacher', help='Download a file from the iDDS cacher on the server')
    download_parser.set_defaults(function=download_from_cacher)
    download_parser.add_argument('--filename', dest='filename', action='store', default=None, help='The destination filename. The source filename on the server will be the base name of the file')

    # get hyperparameters
    hp_get_parser = subparsers.add_parser('get_hyperparameters', help='Get hyperparameters')
    hp_get_parser.set_defaults(function=get_hyperparameters)
    _add_request_arguments(hp_get_parser)
    hp_get_parser.add_argument('--id', dest='id', action='store', type=int, help='The id of the hyperparameter')
    hp_get_parser.add_argument('--status', dest='status', action='store', help='Retrieve hyperparameters with defined status')
    hp_get_parser.add_argument('--limit', dest='limit', action='store', type=int, help='Limit number of hyperparameters')

    # update hyperparameter
    hp_update_parser = subparsers.add_parser('update_hyperparameter', help='Update the hyperparameter result')
    hp_update_parser.set_defaults(function=update_hyperparameter)
    _add_request_arguments(hp_update_parser)
    hp_update_parser.add_argument('--id', dest='id', action='store', type=int, help='The id of the hyperparameter')
    hp_update_parser.add_argument('--loss', dest='loss', action='store', type=float, help='The loss result to be updated')

    # get messages
    get_messages_parser = subparsers.add_parser('get_messages', help='Get messages')
    get_messages_parser.set_defaults(function=get_messages)
    _add_request_arguments(get_messages_parser)

    return oparser


def main(arguments=None):
    """
    Run the iDDS command line client.

    Exits the process with status 0 when the selected sub-command completes,
    and with -1 when no arguments or no sub-command are given or when the
    sub-command raises an exception.

    :param arguments: command line tokens without the program name.
                      Defaults to ``sys.argv[1:]``.
    """
    if arguments is None:
        arguments = sys.argv[1:]

    # set the configuration before anything else, if the config parameter is present
    for position, argument in enumerate(arguments):
        if argument == '--config' and position + 1 < len(arguments):
            os.environ['IDDS_CONFIG'] = arguments[position + 1]
        elif argument.startswith('--config='):
            # argparse also accepts the "--config=<path>" spelling.
            os.environ['IDDS_CONFIG'] = argument.split('=', 1)[1]

    oparser = get_parser()
    argcomplete.autocomplete(oparser)

    if not arguments:
        oparser.print_help()
        sys.exit(-1)

    args = oparser.parse_args(arguments)

    handler = getattr(args, 'function', None)
    if handler is None:
        # Only global options were given (e.g. just --verbose), so there is
        # no sub-command to run: show the usage instead of a cryptic error.
        oparser.print_help()
        sys.exit(-1)

    try:
        if args.verbose:
            # The logging module has no module-level setLevel(); the level
            # is set on a logger object, here the root logger.
            logging.getLogger().setLevel(logging.DEBUG)
        start_time = time.time()

        handler(args)
        end_time = time.time()
        if args.verbose:
            print("Completed in %-0.4f sec." % (end_time - start_time))
        # SystemExit is not an Exception subclass, so the handler below
        # does not intercept this exit.
        sys.exit(0)
    except Exception as error:
        logging.error("Strange error: {0}".format(error))
        # logging.error(traceback.format_exc())
        sys.exit(-1)


if __name__ == '__main__':
    main()
