#!/usr/bin/env python3


# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent
from parlai.mturk.core.mturk_manager import MTurkManager
import parlai.mturk.core.mturk_utils as mturk_utils

from parlai_internal.mturk.tasks.wizard_model_eval.mtdont import MTDONT_LIST
from parlai_internal.mturk.tasks.pairwise_dialogue_eval.conversation_and_likert_collection.worlds import ControllableDialogEval, PersonasGenerator, PersonaAssignWorld
from parlai_internal.mturk.tasks.pairwise_dialogue_eval.conversation_and_likert_collection.task_config import task_config
import parlai_internal.mturk.tasks.pairwise_dialogue_eval.conversation_and_likert_collection.model_configs as mcf

from threading import Lock
import gc
import datetime
import json
import logging
import os
import sys
import time
import copy
import random
import pprint
import numpy


# Extra worker IDs to soft block in addition to MTDONT_LIST; main() applies
# both lists when the task is not running in the sandbox.
# Per the original author's note: the top half of this block list is spammers,
# the bottom half is people who should not do any more of the task.
# The string is split on whitespace, so IDs may share a line or sit one per
# line, and blank lines are ignored.
BLOCKLIST2 = """
A110SC5K5Y3IHS A2P3WROT1JFU0C A30H28F41UBJGK A30QWXG142B3I7 A3U7EIF0P9GQKI
A18B1TWURSJFOL A25KM3PUOUZFN1 A30GPAEVFFIAIW A7SXWHGK8B40R ADYMIMV1PE14V
APQU1QQ738JV9 A12VQHPT0YHMIS A1GKOXB766FPYV A2YRYS17DU3A8R A3A8P4UR9A0DWQ
AZOK0NYBJ9ZCS ARB80JPV0HQ5O AZXHJOJDNOKY9 A3F51C49T9A34D A3H6KECCE83132
ALJMVJ1L3NUTE A1KEAHVVML6319
A3V8RTCR7QOGTU A2D30R5ZNQS01W


A2YQBBAK7NBZHU
A1QEQOI98976S0
ALQ7GPHT431Q2
AKVQTFH16ICGZ
A9KPD7Y55PU9Q
A33X8AKI34NP47
A2TZAXWOB3JMNV
A2ALYW6W16PUYQ
A207CEBUIVZ7IW
A1YHIQHLLLQIIQ
A17WS5TP6OPH7R
A2A4UAFZ5LW71K
A1JWKT0IS06YKL
A1ATL3G98SFW4V
A36IU5EX4AZD04
A1AGMI79B0958M
A1HCW1Z6KPJ6QX
AMOO2P3A36ULC
A3ZWMVK6GNTJ8
A1EQIKZRSGNNP
A1XATJA7ENOGWB
A3G8OON0TDPN1E
A3FKC2UR61C4T7
A3A865DIHN6C6
AWAW665TQQP2F
A1TYMXIYUUUL6F
A1LOD3LNX7FUPJ
AEWT71858NL5V
A1WKF2VH7TV0H2
A4158R4Y06ZB4
A2H50O04Y8V8MM
A1Z8AOIDT5IV43
A1MZ1EU82NE0PQ
A1EQ1LHEEIQ3UA
AHYVYJPFB6XGO
A355AKNHYKLVMD
A2JTV44ZIGMS5D
AFNG3YFO9EKK
A195R8LWDEDB6K
A30H28C1CN9YY1
A2UTM9YZ8J8IGX
AL5YST63LLYWK
A39CMUIYRDAW9P
A36Z48FR1LQXI9
A36WMK9XR6NT3N
A8WUIM7FKXJJ0
A3PTFE5MJV5DU1
A3FNC8ELMK8YJA
""".strip().split()
#
# MASTER_QUALIF = {
#     'QualificationTypeId': '2F1QJWKUDD8XADTFD2Q0G6UTO95ALH',
#     'Comparator': 'Exists',
#     'RequiredToPreview': True
# }

# Names of the model configurations to evaluate in this run. Each entry is
# looked up as an attribute of the `model_configs` module (imported as `mcf`).
SETTINGS_TO_RUN = [
    'polyencoder',
]
def make_flags(start_time):
    """Build the command-line parser for this task and parse the arguments.

    Args:
        start_time: string identifying this run; it is substituted into the
            default ``--mturk-log`` file name.

    Returns:
        The options produced by ``ParlaiParser.parse_args()``.
    """
    argparser = ParlaiParser(False, add_model_args=True)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument('--max-resp-time', default=240,
                           type=int,
                           help='time limit for entering a dialog message')
    argparser.add_argument('--max-choice-time', type=int,
                           default=300, help='time limit for turker '
                           'choosing the topic')
    # Forwarded by main() to the evaluation world as `agent_timeout_shutdown`.
    argparser.add_argument('--ag-shutdown-time', default=120,
                           type=int,
                           help='agent shutdown timeout, forwarded to the '
                                'eval world as agent_timeout_shutdown')
    argparser.add_argument('--num-turns', default=6, type=int,
                           help='number of turns of dialogue')
    argparser.add_argument('--human-eval', type='bool', default=False,
                           help='human vs human eval, no models involved')
    argparser.add_argument('--auto-approve-delay', type=int,
                           default=3600 * 24 * 2,
                           help='how long to wait for auto approval')
    argparser.add_argument('--only-masters', type='bool', default=False,
                           help='Set to true to use only master turks for '
                                'this test eval')
    argparser.add_argument('--create-model-qualif', type='bool', default=True,
                           help='Create model qualif so unique eval between '
                                'models.')
    # NOTE: main() overwrites this option with len(SETTINGS_TO_RUN), so a
    # value given on the command line currently has no effect.
    argparser.add_argument('--limit-workers', type=int,
                           default=len(SETTINGS_TO_RUN),
                           help='max HITs a worker can complete')
    argparser.add_argument('--mturk-log', type=str,
                           default=(
                                '$HOME/ParlAI/data/mturklogs/likert_convs/{}.log'
                                .format(start_time)),
                           help='path of the log file; $HOME is replaced '
                                'with the HOME environment variable')
    argparser.add_argument('--short-eval', type='bool', default=True,
                           help='Only ask engagingness question and persona '
                                'question.')
    # persona specific arguments
    argparser.add_argument('--persona-type', type=str, default='self',
                           choices=['self', 'other', 'none'])
    argparser.add_argument('--persona-datatype', type=str, default='valid',
                           choices=['train', 'test', 'valid'])
    argparser.add_argument('--max-persona-time', type=int, default=360,
                           help='max time to view persona')

    return argparser.parse_args()


def main(start_opt, start_time):
    """Run the MTurk task in which an MTurk agent evaluates a dialog model.

    One model agent is created per entry of ``SETTINGS_TO_RUN`` and its shared
    parameters are reused for every conversation. Each conversation pairs the
    worker with a model that worker has not been assigned before; choices are
    biased towards the models with the fewest conversations so far.

    Args:
        start_opt: parsed options (as returned by ``make_flags``). Updated in
            place with the task settings and the contents of ``task_config``.
        start_time: string identifying this run; used to name each model's
            ``save_data_path`` folder.
    """

    def get_logger(opt):
        """Configure the root logger and log the command line and options.

        A console handler is always added. A file handler is added when
        ``opt`` contains ``mturk_log``; the directory of that file must
        already exist, otherwise ``OSError`` is raised.
        """
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        fmt = logging.Formatter(
            '%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
        console = logging.StreamHandler()
        console.setFormatter(fmt)
        logger.addHandler(console)
        if 'mturk_log' in opt:
            logfn = opt['mturk_log'].replace('$HOME', os.environ['HOME'])
            log_dir = os.path.dirname(logfn)
            if not os.path.isdir(log_dir):
                raise OSError("Please run `mkdir -p {}`".format(log_dir))
            logfile = logging.FileHandler(logfn, 'a')
            logfile.setFormatter(fmt)
            logger.addHandler(logfile)
        logger.info('COMMAND: %s', ' '.join(sys.argv))
        logger.info('-' * 100)
        logger.info('CONFIG:\n%s', json.dumps(opt, indent=4, sort_keys=True))

        return logger

    # Fill the reward into the task description shown to workers. This mutates
    # the imported `task_config` dict.
    task_config['task_description'] = task_config['task_description'].format(
        start_opt['reward']
    )

    # set options
    # A worker may do at most one HIT per model setting, so the worker limit
    # is forced to the number of settings regardless of the command line.
    start_opt['limit_workers'] = len(SETTINGS_TO_RUN)
    start_opt['allowed_conversations'] = 1
    start_opt['max_hits_per_worker'] = start_opt['limit_workers']
    start_opt['task'] = os.path.basename(
        os.path.dirname(os.path.abspath(__file__)))

    start_opt.update(task_config)

    get_logger(start_opt)

    # Per-setting state, keyed by the names in SETTINGS_TO_RUN.
    model_share_params = {}  # setting -> result of bot.share()
    worker_models_seen = {}  # worker id -> set of settings assigned so far
    model_opts = {}          # setting -> options used to build that model
    model_counts = {}        # setting -> number of conversations assigned

    # Guards worker_models_seen and model_counts, which are accessed from the
    # eligibility and conversation callbacks.
    lock = Lock()

    for setup in SETTINGS_TO_RUN:
        assert 'human' not in setup
        model_counts[setup] = 0
        agent_config = getattr(mcf, setup)
        combined_config = copy.deepcopy(start_opt)
        # The model config wins over the task options, both as a plain value
        # and in the 'override' dict.
        for k, v in agent_config.items():
            combined_config[k] = v
            combined_config['override'][k] = v
        folder_name = '{}-{}'.format(setup, start_time)
        combined_config['save_data_path'] = os.path.join(
            os.getcwd(),
            'data',
            # 'fanout_controllable_dialog',
            folder_name
        )
        model_opts[setup] = combined_config
        bot = create_agent(combined_config, True)
        model_share_params[setup] = bot.share()

    if not start_opt.get('human_eval'):
        mturk_agent_ids = ['PERSON_1']
    else:
        mturk_agent_ids = ['PERSON_1', 'PERSON_2']

    mturk_manager = MTurkManager(
        opt=start_opt,
        mturk_agent_ids=mturk_agent_ids
    )

    personas_generator = PersonasGenerator(start_opt)

    random.seed(42)
    numpy.random.seed(42)

    directory_path = os.path.dirname(os.path.abspath(__file__))

    mturk_manager.setup_server(task_directory_path=directory_path)

    try:
        mturk_manager.start_new_run()
        agent_qualifications = []
        # assign qualifications

        qual_name = start_opt['block_qualification']
        qual_desc = (
            'Qualification to ensure workers complete only a certain '
            'number of these HITs'
        )
        # NOTE(review): the third argument is hard-coded to False even when
        # start_opt['is_sandbox'] is set -- confirm this is intended.
        qualification_id = mturk_utils.find_or_create_qualification(
            qual_name, qual_desc, False
        )
        print('Created qualification: ', qualification_id)
        # start_opt['unique_qualif_id'] = qualification_id

        if not start_opt['is_sandbox']:
            # ADD BLOCKED WORKERS HERE
            blocked_worker_list = MTDONT_LIST + BLOCKLIST2
            for w in blocked_worker_list:
                # Best effort: a failure to block one worker is reported and
                # must not stop the run.
                try:
                    print('Soft Blocking {}\n'.format(w))
                    mturk_manager.soft_block_worker(w)
                except Exception:
                    print('Did not soft block worker:', w)
                time.sleep(0.1)

        def run_onboard(worker):
            """Run the persona assignment world once for a new worker."""
            worker.personas_generator = personas_generator
            world = PersonaAssignWorld(start_opt, worker)
            world.parley()
            world.shutdown()

        def check_worker_eligibility(worker):
            """Return True while the worker has model settings left to do."""
            worker_id = worker.worker_id
            with lock:
                seen = worker_models_seen.get(worker_id, [])
                return len(seen) < len(SETTINGS_TO_RUN)

        def assign_worker_roles(workers):
            """Give each worker an agent id, cycling through the id list."""
            for index, worker in enumerate(workers):
                worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()
        mturk_manager.create_hits(qualifications=agent_qualifications)

        def run_conversation(mturk_manager, opt, workers):
            """Run one evaluation conversation between a worker and a model.

            Returns None without starting a conversation when the worker has
            already been assigned every setting in SETTINGS_TO_RUN.
            """
            conv_idx = mturk_manager.conversation_index

            # gotta find a bot this worker hasn't seen yet
            assert len(workers) == 1
            worker_id = workers[0].worker_id
            model_choice = None
            # `with` guarantees the lock is released even if something in the
            # critical section raises, so other conversations cannot deadlock.
            with lock:
                seen = worker_models_seen.setdefault(worker_id, set())
                print("MODELCOUNTS:")
                print(pprint.pformat(model_counts))
                logging.info("MODELCOUNTS\n" + pprint.pformat(model_counts))
                # Score = conversations so far plus random jitter in [0, 10);
                # the lowest score is picked, which favours less-used models
                # without making the choice fully deterministic.
                model_options = [
                    (model_counts[setup_name] + 10 * random.random(), setup_name)
                    for setup_name in SETTINGS_TO_RUN
                    if setup_name not in seen
                ]
                if model_options:
                    _, model_choice = min(model_options)
                    # Reserve the model now; undone below if the conversation
                    # does not finish.
                    seen.add(model_choice)
                    model_counts[model_choice] += 1

            if model_choice is None:
                logging.error(
                    "Worker {} already finished all settings! Returning none"
                    .format(worker_id)
                )
                return None

            world = ControllableDialogEval(
                opt=model_opts[model_choice],
                agents=workers,
                num_turns=start_opt['num_turns'],
                max_resp_time=start_opt['max_resp_time'],
                model_agent_opt=model_share_params[model_choice],
                world_tag='conversation t_{}'.format(conv_idx),
                agent_timeout_shutdown=opt['ag_shutdown_time'],
                model_config=model_choice,
            )
            world.reset_random()
            while not world.episode_done():
                world.parley()
            world.save_data()

            with lock:
                if not world.convo_finished:
                    # Unfinished conversations do not count: release the
                    # reservation so this model can be assigned again.
                    model_counts[model_choice] -= 1
                    worker_models_seen[worker_id].remove(model_choice)

            world.shutdown()
            gc.collect()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation
        )

    finally:
        # Always clean up, whether the task completed or an error/interrupt
        # propagated out of the block above.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()


if __name__ == '__main__':
    # Minute-resolution timestamp identifying this run; it is passed to both
    # the flag parser and the task itself.
    launch_stamp = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M')
    main(make_flags(launch_stamp), launch_stamp)
