import argparse
from functools import partial

from neorl.rl.baselines.shared import logger
from neorl.rl.baselines.shared.monitor import Monitor
from neorl.rl.baselines.shared import set_global_seeds
from neorl.rl.baselines.shared.atari_wrappers import make_atari
from neorl.rl.baselines.deepq import DQN, wrap_atari_dqn, CnnPolicy


def main():
    """
    Train a DQN agent on an Atari environment, configured from the command line.

    Parses the command-line options, configures the logger, seeds the RNGs,
    builds the Atari environment (wrapped in ``Monitor`` and ``wrap_atari_dqn``),
    then trains a ``DQN`` model using ``CnnPolicy`` for ``--num-timesteps`` steps.

    The environment is closed on exit even when setup or training raises or is
    interrupted (e.g. with Ctrl-C).
    """
    # ArgumentDefaultsHelpFormatter only appends "(default: ...)" to options
    # that carry a help string, so every option gets one.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1,
                        help='use prioritized replay (0 disables, any non-zero value enables)')
    parser.add_argument('--dueling', type=int, default=1,
                        help='use the dueling policy variant (1 enables, any other value disables)')
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6,
                        help='alpha value passed to the prioritized replay buffer')
    parser.add_argument('--num-timesteps', type=int, default=int(10e6),
                        help='total number of timesteps to train for')

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)

    env = make_atari(args.env)
    try:
        env = Monitor(env, logger.get_dir())
        env = wrap_atari_dqn(env)
        # Dueling is switched on only for the exact value 1 (unlike
        # --prioritized, which treats any non-zero value as "on").
        policy = partial(CnnPolicy, dueling=args.dueling == 1)

        # Hyperparameters are fixed here; only the replay settings and the
        # dueling switch are taken from the command line.
        model = DQN(
            env=env,
            policy=policy,
            learning_rate=1e-4,
            buffer_size=10000,
            exploration_fraction=0.1,
            exploration_final_eps=0.01,
            train_freq=4,
            learning_starts=10000,
            target_network_update_freq=1000,
            gamma=0.99,
            prioritized_replay=bool(args.prioritized),
            prioritized_replay_alpha=args.prioritized_replay_alpha,
        )
        model.learn(total_timesteps=args.num_timesteps)
    finally:
        # Runs on success, on exceptions and on KeyboardInterrupt, so the
        # (possibly wrapped) environment is never left open.
        env.close()


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
