#!/usr/bin/env python3

# This file is part of tf-mdp.

# tf-mdp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-mdp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-mdp. If not, see <http://www.gnu.org/licenses/>.


import argparse
import time

import rddlgym
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def parse_args():
    """Parse the command-line arguments of the planner script.

    Returns:
        argparse.Namespace: parsed options exposing the attributes ``rddl``,
        ``layers``, ``activation``, ``input_layer_norm``, ``batch_size``,
        ``horizon``, ``epochs``, ``learning_rate``, ``optimizer``,
        ``loss_fn``, ``logdir`` and ``verbose``.

    Note:
        ``--layers`` declares no ``type``, so its values are returned as a
        list of strings (an empty list when the option is omitted).
    """
    description = 'Probabilistic planning in continuous state-action MDPs using TensorFlow.'
    parser = argparse.ArgumentParser(description=description)

    # Problem definition.
    parser.add_argument(
        'rddl',
        type=str,
        help='RDDL file or rddlgym domain id'
    )

    # Policy network architecture.
    parser.add_argument(
        '-l', '--layers',
        nargs='+',
        default=[],
        help='number of units in each hidden layer in policy network'
    )
    parser.add_argument(
        '-a', '--activation',
        type=str,
        choices=[
            'none', 'sigmoid', 'tanh', 'relu', 'relu6',
            'crelu', 'elu', 'selu', 'softplus', 'softsign'
        ],
        default='elu',
        help='activation function for hidden layers in policy network'
    )
    parser.add_argument(
        '-iln', '--input-layer-norm',
        action='store_true',
        help='input layer normalization flag'
    )

    # The options below are currently disabled.
    # parser.add_argument(
    #     '-hln', '--hidden-layer-norm',
    #     action='store_true',
    #     help='hidden layer normalization flag'
    # )
    # parser.add_argument(
    #     '-kl1', '--kernel-l1-regularizer',
    #     type=float, default=0.0,
    #     help='kernel L1-regularizer constant (default=0.0)'
    # )
    # parser.add_argument(
    #     '-kl2', '--kernel-l2-regularizer',
    #     type=float, default=0.0,
    #     help='kernel L2-regularizer constant (default=0.0)'
    # )
    # parser.add_argument(
    #     '-bl1', '--bias-l1-regularizer',
    #     type=float, default=0.0,
    #     help='bias L1-regularizer constant (default=0.0)'
    # )
    # parser.add_argument(
    #     '-bl2', '--bias-l2-regularizer',
    #     type=float, default=0.0,
    #     help='bias L2-regularizer constant (default=0.0)'
    # )

    # Training hyperparameters.
    parser.add_argument(
        '-b', '--batch-size',
        type=int, default=256,
        help='number of trajectories in a batch (default=256)'
    )
    parser.add_argument(
        '-hr', '--horizon',
        type=int, default=40,
        help='number of timesteps (default=40)'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int, default=200,
        # Previously duplicated the --horizon help text ("number of timesteps").
        help='number of training epochs (default=200)'
    )
    parser.add_argument(
        '-lr', '--learning-rate',
        type=float, default=0.001,
        help='optimizer learning rate (default=0.001)'
    )

    # Optimization setup.
    parser.add_argument(
        '-opt', '--optimizer',
        type=str,
        choices=[
            'Adadelta', 'Adagrad', 'Adam', 'GradientDescent',
            'ProximalGradientDescent', 'ProximalAdagrad', 'RMSProp'
        ],
        default='RMSProp',
        help='loss optimizer (default=RMSProp)'
    )
    parser.add_argument(
        '-lfn', '--loss-fn',
        type=str, choices=['linear', 'mse'],
        default='linear',
        help='loss function (default=linear)'
    )

    # Logging and output.
    parser.add_argument(
        '-ld', '--logdir',
        type=str, default='/tmp/tfmdp',
        help='log directory for data summaries (default=/tmp/tfmdp)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    return parser.parse_args()


def print_parameters(args):
    """Print the run configuration to stdout when verbose mode is on.

    Does nothing unless ``args.verbose`` is truthy. Otherwise imports
    ``tfmdp`` (for its ``__version__``) and prints the RDDL id, log
    directory, policy network settings, hyperparameters and optimization
    settings read from ``args``.

    Args:
        args: parsed command-line options (see ``parse_args``).
    """
    if not args.verbose:
        return

    import tfmdp

    def _report_lines():
        # Lines are produced lazily so each one is formatted right before
        # it is printed; an empty string stands for a blank line.
        yield ''
        yield 'Running tf-mdp v{} ...'.format(tfmdp.__version__)
        yield ''
        yield '>> RDDL:   {}'.format(args.rddl)
        yield '>> logdir: {}'.format(args.logdir)
        yield ''
        yield '>> Policy Net:'
        # args.layers holds strings, so it can be joined directly.
        yield 'layers = [{}]'.format(','.join(args.layers))
        yield 'activation = {}'.format(args.activation)
        yield 'input  layer norm = {}'.format(args.input_layer_norm)
        # yield 'hidden layer norm = {}'.format(args.hidden_layer_norm)
        yield ''
        yield '>> Hyperparameters:'
        yield 'epochs        = {}'.format(args.epochs)
        yield 'learning rate = {}'.format(args.learning_rate)
        yield 'batch size    = {}'.format(args.batch_size)
        yield 'horizon       = {}'.format(args.horizon)
        yield ''
        yield '>> Optimization:'
        yield 'optimizer     = {}'.format(args.optimizer)
        yield 'loss function = {}'.format(args.loss_fn)
        # yield 'kernel l1-regularization = {}'.format(args.kernel_l1_regularizer)
        # yield 'kernel l2-regularization = {}'.format(args.kernel_l2_regularizer)
        # yield 'bias l1-regularization = {}'.format(args.bias_l1_regularizer)
        # yield 'bias l2-regularization = {}'.format(args.bias_l2_regularizer)
        yield ''

    for text in _report_lines():
        print(text)


def load_model(args):
    """Create and initialize the model identified by ``args.rddl``.

    Args:
        args: parsed command-line options; ``rddl`` and ``batch_size``
            are read.

    Returns:
        The object returned by ``rddlgym.make`` in ``rddlgym.SCG`` mode,
        after calling its ``init()`` method and setting its ``batch_size``.
    """
    model = rddlgym.make(args.rddl, mode=rddlgym.SCG)
    model.init()
    # The batch size is assigned only after init(); this order is kept as-is.
    model.batch_size = args.batch_size
    return model


def build_policy(compiler, args):
    """Instantiate and build the feedforward policy network.

    Args:
        compiler: model object as returned by ``load_model``.
        args: parsed command-line options; ``layers``, ``activation`` and
            ``input_layer_norm`` are forwarded as the policy configuration.

    Returns:
        The ``FeedforwardPolicy`` instance after ``build()`` has been called.
    """
    from tfmdp.policy.feedforward import FeedforwardPolicy

    policy_config = dict(
        layers=args.layers,
        activation=args.activation,
        input_layer_norm=args.input_layer_norm,
    )
    network = FeedforwardPolicy(compiler, policy_config)
    network.build()
    return network


def solve(compiler, policy, args):
    """Optimize ``policy`` with the pathwise planner and return its rewards.

    Args:
        compiler: model object as returned by ``load_model``.
        policy: policy object as returned by ``build_policy``.
        args: parsed command-line options; ``batch_size``, ``horizon``,
            ``learning_rate`` and ``logdir`` form the planner configuration,
            while ``loss_fn``, ``optimizer`` and ``epochs`` drive the build
            and run steps.

    Returns:
        The second element of the pair returned by ``planner.run``.
    """
    from tfmdp.planning.pathwise import PathwiseOptimizationPlanner

    planner_config = dict(
        batch_size=args.batch_size,
        horizon=args.horizon,
        learning_rate=args.learning_rate,
        logdir=args.logdir,
    )
    pathwise = PathwiseOptimizationPlanner(compiler, planner_config)
    pathwise.build(policy, loss=args.loss_fn, optimizer=args.optimizer)

    # run() yields a pair; only its second element is of interest here.
    _ignored, reward_history = pathwise.run(args.epochs)
    return reward_history


def report_performance(rewards, horizon):
    """Print the final total reward and its per-timestep average.

    Args:
        rewards: non-empty sequence whose last item is indexable; the value
            at index 1 of that last item is reported as the total reward.
        horizon: number of timesteps used to average the total reward.
    """
    final_entry = rewards[-1]
    total = final_entry[1]
    template = 'total reward = {:.4f}, reward per timestep = {:.4f}\n'
    print('>> Performance:')
    # The average is computed after the header is printed, as before.
    print(template.format(total, total / horizon))


if __name__ == '__main__':

    # Read the command line and, in verbose mode, echo the configuration.
    args = parse_args()
    print_parameters(args)

    # Stage 1: load the model, measuring wall-clock time.
    print('>> Loading model ...')
    load_began = time.time()
    compiler = load_model(args)
    load_elapsed = time.time() - load_began
    print('Done in {:.6f} sec.'.format(load_elapsed))
    print()

    # Stage 2: build the policy and optimize it; both are timed together.
    print('>> Optimizing...')
    solve_began = time.time()
    policy = build_policy(compiler, args)
    rewards = solve(compiler, policy, args)
    solve_elapsed = time.time() - solve_began
    print()
    print('Done in {:.6f} sec.'.format(solve_elapsed))
    print()

    # Stage 3: summarize the result.
    report_performance(rewards, args.horizon)

    # Print a ready-to-copy tensorboard command for the log directory.
    print('tensorboard --logdir {}\n'.format(args.logdir))
