#!/usr/bin/env python3

# This file is part of tf-rddlsim.

# tf-rddlsim is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-rddlsim is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-rddlsim. If not, see <http://www.gnu.org/licenses/>.


import argparse
import numpy as np
import time

import rddlgym

from tfrddlsim.simulation.policy_simulator import PolicySimulator
import tfrddlsim.viz
import tfrddlsim.policy


def parse_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        argparse.Namespace: parsed values exposed as ``rddl``, ``policy``,
        ``viz``, ``horizon``, ``batch_size`` and ``verbose``.
    """
    cli = argparse.ArgumentParser(
        description='RDDL2TensorFlow compiler and simulator'
    )

    # Each entry is (flags, keyword options). Entries are registered in
    # order, which is also the order they show up in the generated help.
    argument_specs = (
        (('rddl',), {
            'type': str,
            'help': 'path to RDDL file or rddlgym problem id',
        }),
        (('--policy',), {
            'default': 'random',
            # Accepted values are whatever iterating the registry yields.
            'choices': tuple(tfrddlsim.policy.policies),
            'help': 'type of policy (default=random)',
        }),
        (('--viz',), {
            'default': 'generic',
            # Accepted values are whatever iterating the registry yields.
            'choices': tuple(tfrddlsim.viz.visualizers),
            'help': 'type of visualizer (default=generic)',
        }),
        (('-hr', '--horizon'), {
            'type': int,
            'default': 40,
            'help': 'number of timesteps of each trajectory (default=40)',
        }),
        (('-b', '--batch_size'), {
            'type': int,
            'default': 75,
            'help': 'number of trajectories in a batch (default=75)',
        }),
        (('-v', '--verbose'), {
            'action': 'store_true',
            'help': 'verbosity mode',
        }),
    )
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


def print_simulation_parameters(args):
    """Print a banner naming the RDDL domain, then the run settings.

    Args:
        args: object whose ``rddl``, ``policy``, ``horizon`` and
            ``batch_size`` attributes are read (e.g. the namespace
            returned by ``parse_args``).
    """
    banner = '*********************************************************'
    report = [
        banner,
        ' RDDL domain = {domain}'.format(domain=args.rddl),
        banner,
        '>> policy = {policy}'.format(policy=args.policy),
        '>> horizon = {horizon}'.format(horizon=args.horizon),
        '>> batch_size = {size}'.format(size=args.batch_size),
        '',  # trailing blank line separating this section from the next
    ]
    for line in report:
        print(line)


def get_policy(compiler, policy_type, batch_size):
    """Instantiate the policy registered under `policy_type`.

    The policy is looked up in ``tfrddlsim.policy.policies``. When
    `policy_type` is not registered, the entry registered under
    ``'random'`` is used instead.

    Args:
        compiler: passed through as the first argument of the policy
            constructor.
        policy_type: key to look up in ``tfrddlsim.policy.policies``.
        batch_size: passed through as the second argument of the policy
            constructor.

    Returns:
        The object produced by calling the selected policy with
        ``(compiler, batch_size)``.

    Raises:
        KeyError: if `policy_type` is unknown and no ``'random'`` policy
            is registered either.
    """
    policies = tfrddlsim.policy.policies
    policy = policies.get(policy_type)
    if policy is None:
        # Fall back to the registered callable, not to the *string*
        # 'random' (a string default would fail when called below).
        policy = policies['random']
    return policy(compiler, batch_size)


def simulate(compiler, policy, horizon, batch_size):
    """Run a policy simulation and measure how long it takes.

    The measured interval covers both the construction of the
    ``PolicySimulator`` and the call to its ``run`` method.

    Args:
        compiler: passed to ``PolicySimulator``.
        policy: passed to ``PolicySimulator``.
        horizon: passed to ``simulator.run``.
        batch_size: passed to ``PolicySimulator``.

    Returns:
        tuple: ``(uptime, trajectories)`` where `uptime` is the elapsed
        time in seconds (float) and `trajectories` is whatever
        ``simulator.run(horizon)`` returned.
    """
    # perf_counter is monotonic, so the measured duration cannot be
    # distorted (or go negative) if the system clock is adjusted mid-run.
    start = time.perf_counter()
    simulator = PolicySimulator(compiler, policy, batch_size)
    trajectories = simulator.run(horizon)
    uptime = time.perf_counter() - start
    return uptime, trajectories


def report_performance_stats(trajectories):
    """Print mean, standard deviation and per-trajectory total reward.

    Args:
        trajectories: sequence of exactly six items; the third is read as
            `states` and the sixth as `rewards`. ``states[0][1]`` must
            expose a 3-element ``shape`` whose first two entries are used
            as ``(batch_size, horizon)``, and `rewards` must contain
            exactly ``batch_size * horizon`` values so it can be reshaped
            to that 2-D layout.
    """
    _, _, states, _, _, rewards = trajectories

    # Batch size and horizon are taken from the array stored in the first
    # `states` entry; the third dimension is ignored.
    batch_size, horizon, _ = states[0][1].shape
    rewards = np.reshape(rewards, [batch_size, horizon])

    # Sum over the time axis: one total reward per trajectory in the batch.
    total_reward = np.sum(rewards, axis=1)
    avg_total_reward = np.mean(total_reward)
    stddev_total_reward = np.std(total_reward)

    print('*********************************************************')
    print(' PERFORMANCE STATS')
    print('*********************************************************')
    print('>> Average total reward = {:.6f}'.format(avg_total_reward))
    print('>> Stddev  total reward = {:.6f}'.format(stddev_total_reward))
    print('>> Total reward = {}'.format(list(total_reward)))
    print()


def report_time_stats(uptime, horizon, batch_size):
    """Print the total simulation time and its per-batch/per-step ratios.

    Args:
        uptime: total elapsed time in seconds.
        horizon: divisor used for the "Time per step" figure.
        batch_size: divisor used for the "Time per batch" figure.

    Raises:
        ZeroDivisionError: if `batch_size` or `horizon` is zero.
    """
    # Both ratios are computed up front so nothing is printed if either
    # division fails.
    per_batch = uptime / batch_size
    per_step = uptime / horizon

    banner = '*********************************************************'
    report = (
        banner,
        ' TIME STATS:',
        banner,
        '>> Simulation done in {elapsed:.6f} sec.'.format(elapsed=uptime),
        '>> Time per batch = {ratio:.6f} sec.'.format(ratio=per_batch),
        '>> Time per step  = {ratio:.6f} sec.'.format(ratio=per_step),
        '',  # trailing blank line
    )
    for line in report:
        print(line)


def display(compiler, trajectories, visualizer_type, verbose):
    """Render `trajectories` with the visualizer registered as `visualizer_type`.

    The visualizer is looked up in ``tfrddlsim.viz.visualizers``. When
    `visualizer_type` is not registered, the entry registered under
    ``'generic'`` is used instead.

    Args:
        compiler: passed through as the first argument of the visualizer
            constructor.
        trajectories: passed to the visualizer's ``render`` method.
        visualizer_type: key to look up in ``tfrddlsim.viz.visualizers``.
        verbose: passed through as the second argument of the visualizer
            constructor.

    Raises:
        KeyError: if `visualizer_type` is unknown and no ``'generic'``
            visualizer is registered either.
    """
    visualizers = tfrddlsim.viz.visualizers
    visualizer = visualizers.get(visualizer_type)
    if visualizer is None:
        # Fall back to the registered callable, not to the *string*
        # 'generic' (a string default would fail when called below).
        visualizer = visualizers['generic']
    viz = visualizer(compiler, verbose)
    viz.render(trajectories)


if __name__ == '__main__':

    # Read the command-line options and echo the run configuration.
    cli_args = parse_args()
    print_simulation_parameters(cli_args)

    # Build the compiler for the requested RDDL problem and turn on its
    # batch mode before anything is simulated.
    rddl_compiler = rddlgym.make(cli_args.rddl, mode=rddlgym.SCG)
    rddl_compiler.batch_mode_on()

    # Instantiate the selected policy, then run and time the simulation.
    selected_policy = get_policy(
        rddl_compiler, cli_args.policy, cli_args.batch_size
    )
    elapsed, batch_trajectories = simulate(
        rddl_compiler, selected_policy, cli_args.horizon, cli_args.batch_size
    )

    # Render the trajectories first, then print the summary reports.
    display(rddl_compiler, batch_trajectories, cli_args.viz, cli_args.verbose)
    report_performance_stats(batch_trajectories)
    report_time_stats(elapsed, cli_args.horizon, cli_args.batch_size)
