#!/usr/bin/env python

"""
dora the explorer

fits data using ANNs defined in explorer.py

has command line interface to experiment with hyperparameters
"""

import pugna.pugna_fit_cli
import pugna.explorer
import pugna.callbacks
import pugna.data
import pugna.model_utils
import pugna.logger
import pugna.learning_rate_schedulers
import pandas as pd
import numpy as np
import tensorflow as tf
import datetime
import os
import subprocess
import argparse
import matplotlib as mpl
import matplotlib.pyplot as plt
# Select the 'agg' matplotlib backend; this script only ever writes its
# figures to files with savefig().
mpl.use("agg")

# Reset every rc setting to matplotlib's defaults, then enlarge the font used
# in the saved plots.
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams["font.size"] = 16

# Names of the tf.keras.optimizers classes this script can select by name.
SUPPORTED_OPTIMIZERS = (
    "Adam", "Nadam", "Adadelta", "Adagrad", "Adamax", "RMSprop", "SGD")

# argparse destinations of the fit settings echoed to the log at start-up.
# Each one is logged under its dashed name (e.g. n_blocks -> "n-blocks").
NN_SETTING_NAMES = (
    "epochs", "n_blocks", "units_per_layer", "layers_per_block", "activation",
    "leaky_relu_alpha", "learning_rate", "use_lrs", "lrs_final_learning_rate",
    "lrs_decay_rate", "lrs_decay_steps", "batch_norm", "optimizer", "loss",
    "batch_size_factor")


def build_parser():
    """
    build the command line parser

    the data/output/verbosity options are defined here; the network fit
    options are added by pugna.pugna_fit_cli.insert_pugna_fit_option_group
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--output-dir", type=str,
                        help="output directory", required=True)

    scale_choices = ["None", "MinMaxScaler", "StandardScaler"]
    parser.add_argument("--X-scale-method", type=str, default="None",
                        help="sklearn method to scale X data",
                        choices=scale_choices)
    parser.add_argument("--y-scale-method", type=str, default="None",
                        help="sklearn method to scale y data",
                        choices=scale_choices)

    parser.add_argument("--X-data-train", type=str,
                        help="path to X.npy (training)", required=True)
    parser.add_argument("--y-data-train", type=str,
                        help="path to y.npy (training)", required=True)

    parser.add_argument("--X-data-val", type=str,
                        help="path to X.npy (validation)", required=True)
    parser.add_argument("--y-data-val", type=str,
                        help="path to y.npy (validation)", required=True)

    parser.add_argument("-v", "--verbose",
                        help="""
                        increase output verbosity
                        no -v: WARNING
                        -v: INFO
                        -vv: DEBUG
                        """,
                        action='count', default=0)

    # pugna fit NN args
    pugna.pugna_fit_cli.insert_pugna_fit_option_group(parser)

    return parser


def compute_batch_size(n_samples, batch_size_factor):
    """
    return the batch size: n_samples / batch_size_factor, truncated

    the result is clamped to at least 1 so that a factor larger than the
    number of samples does not produce a batch size of 0

    raises ValueError if batch_size_factor is not positive
    """
    if batch_size_factor <= 0:
        raise ValueError(
            f"batch size factor must be positive, got: {batch_size_factor}")
    return max(1, int(n_samples / batch_size_factor))


def resolve_loss(name):
    """
    map a loss name ('mse' or 'mape') to the object handed to compile_model

    raises ValueError for any other name
    """
    if name == 'mse':
        return 'mse'
    if name == 'mape':
        return tf.keras.losses.MeanAbsolutePercentageError()
    raise ValueError(f"unsupported loss: {name!r}")


def resolve_optimizer(name):
    """
    map an optimizer name to the tf.keras.optimizers class (not an instance)

    raises ValueError if the name is not in SUPPORTED_OPTIMIZERS
    """
    if name not in SUPPORTED_OPTIMIZERS:
        raise ValueError(
            f"unsupported optimizer: {name!r}; "
            f"expected one of {SUPPORTED_OPTIMIZERS}")
    return getattr(tf.keras.optimizers, name)


def extract_lr_history(history_dict):
    """
    return the per-epoch learning rates from a keras history dict

    the LearningRateScheduler callback records the rate under 'lr' or,
    in newer keras versions, under 'learning_rate'; returns None if neither
    key is present
    """
    for key in ('lr', 'learning_rate'):
        if key in history_dict:
            return history_dict[key]
    return None


def scale_train_val(train, val, method, output_dir, prefix):
    """
    scale a train/validation pair of arrays

    method "None" returns copies of the inputs. otherwise scalers are made
    from the training array only, applied to both arrays and saved to
    <output_dir>/<prefix>_scalers.npy

    returns (train_scaled, val_scaled)
    """
    if method == "None":
        return train.copy(), val.copy()

    scalers = pugna.data.make_scalers(train, method=method)
    train_scaled = pugna.data.apply_scaler(train, scalers)
    val_scaled = pugna.data.apply_scaler(val, scalers)

    outname = os.path.join(output_dir, f"{prefix}_scalers.npy")
    pugna.data.save_scalers(Scalers=scalers, filename=outname)

    return train_scaled, val_scaled


def configure_threads(logger):
    """
    restrict tensorflow to one inter-op and one intra-op thread and make
    sure OMP_NUM_THREADS is set (defaulting it to "1"), logging the result
    """
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
    logger.info(
        f"tf using {tf.config.threading.get_inter_op_parallelism_threads()} inter_op_parallelism_threads thread(s)")
    logger.info(
        f"tf using {tf.config.threading.get_intra_op_parallelism_threads()} intra_op_parallelism_threads thread(s)")

    if "OMP_NUM_THREADS" not in os.environ:
        logger.info("'OMP_NUM_THREADS' not set. Setting it now.")
        os.environ["OMP_NUM_THREADS"] = "1"
    omp_num_threads = os.environ["OMP_NUM_THREADS"]
    logger.info(f"OMP_NUM_THREADS: {omp_num_threads}")

    # OMP_NUM_THREADS may legally be a comma separated list (e.g. "4,2"),
    # which int() rejects; treat anything that is not the integer 1 as
    # "not 1" and warn instead of crashing.
    try:
        single_threaded = int(omp_num_threads) == 1
    except ValueError:
        single_threaded = False
    if not single_threaded:
        logger.warning(
            f"OMP_NUM_THREADS is not 1! value: {omp_num_threads}")


def log_gpu_info(logger):
    """
    log the GPUs visible to tensorflow and run 'nvidia-smi -L' if available
    """
    logger.info("running 'tf.config.list_physical_devices('GPU')'")
    logger.info(tf.config.list_physical_devices('GPU'))

    # best effort: the output of nvidia-smi goes straight to stdout
    try:
        logger.info("running 'nvidia-smi -L'")
        subprocess.call(["nvidia-smi", "-L"])
    except FileNotFoundError:
        logger.info("could not run 'nvidia-smi -L'")


def log_nn_settings(logger, args):
    """
    log the fit settings listed in NN_SETTING_NAMES
    """
    logger.info("NN SETTINGS:")
    for dest in NN_SETTING_NAMES:
        logger.info(f"{dest.replace('_', '-')}: {getattr(args, dest)}")


def load_array(logger, path):
    """
    log and load a .npy file
    """
    logger.info(f"loading: {path}")
    return np.load(path)


def save_lr_plot(lr_history, outname):
    """
    save the learning rate per epoch on a log scale to outname
    """
    plt.figure(figsize=(14, 7))
    plt.plot(lr_history)
    plt.yscale('log')
    plt.savefig(outname, bbox_inches='tight')
    plt.close()


def save_loss_plot(history_dict, outname):
    """
    save the training and validation loss per epoch on a log scale to outname
    """
    plt.figure(figsize=(14, 7))
    plt.plot(history_dict['loss'], label='loss')
    plt.plot(history_dict['val_loss'], label='val_loss')
    plt.yscale('log')
    plt.legend(loc='center left', fancybox=True,
               framealpha=0., bbox_to_anchor=(1.05, 0.5))
    plt.savefig(outname, bbox_inches='tight')
    plt.close()


def main():
    """
    parse the command line, fit the network and write the model, history,
    fit duration and diagnostic plots to the output directory
    """
    args = build_parser().parse_args()

    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = pugna.logger.init_logger(level=level)
    logger.info("running dora")
    logger.info(f"verbosity turned on at level: {level}")

    configure_threads(logger)
    log_gpu_info(logger)
    log_nn_settings(logger, args)

    logger.info(f"making output dir: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    X_train = load_array(logger, args.X_data_train)
    y_train = load_array(logger, args.y_data_train)
    X_val = load_array(logger, args.X_data_val)
    y_val = load_array(logger, args.y_data_val)

    logger.info(f"X_train.shape: {X_train.shape}")
    logger.info(f"y_train.shape: {y_train.shape}")
    logger.info(f"X_val.shape: {X_val.shape}")
    logger.info(f"y_val.shape: {y_val.shape}")

    logger.info(f"X scale method: {args.X_scale_method}")
    logger.info(f"y scale method: {args.y_scale_method}")

    X_train_scaled, X_val_scaled = scale_train_val(
        X_train, X_val, args.X_scale_method, args.output_dir, "X")
    y_train_scaled, y_val_scaled = scale_train_val(
        y_train, y_val, args.y_scale_method, args.output_dir, "y")

    batch_size = compute_batch_size(X_train.shape[0], args.batch_size_factor)

    logger.info(f"X_train.shape: {X_train.shape[0]}")
    logger.info(f"batch size factor: {args.batch_size_factor}")
    logger.info(f"batch size: {batch_size}")

    loss = resolve_loss(args.loss)
    optimizer = resolve_optimizer(args.optimizer)

    logger.info("begining fits")

    model = pugna.explorer.build_model(
        input_dim=X_train_scaled.shape[1],
        output_dim=y_train_scaled.shape[1],
        units_per_layer=args.units_per_layer,
        layers_per_block=args.layers_per_block,
        n_blocks=args.n_blocks,
        activation=args.activation,
        batch_norm=args.batch_norm,
        summary=True)
    model = pugna.explorer.compile_model(
        model,
        learning_rate=args.learning_rate,
        loss=loss,
        optimizer=optimizer)

    callbacks = [pugna.callbacks.PrintDot()]
    if args.use_lrs:
        # the schedule is only built when it is actually used
        learning_rate_fn = pugna.learning_rate_schedulers.InverseTimeDecay_WithFinalLR(
            args.learning_rate,
            args.lrs_final_learning_rate,
            args.lrs_decay_steps,
            args.lrs_decay_rate,
            staircase=True)
        callbacks.append(
            tf.keras.callbacks.LearningRateScheduler(learning_rate_fn))

    starttime = datetime.datetime.now()
    history = model.fit(
        X_train_scaled,
        y_train_scaled,
        epochs=args.epochs,
        batch_size=batch_size,
        verbose=0,
        callbacks=callbacks,
        validation_data=(X_val_scaled, y_val_scaled)
    )
    duration = datetime.datetime.now() - starttime

    logger.info("fits complete")
    logger.info(f"The time cost: {duration}")

    logger.info("saving model")
    outname = os.path.join(args.output_dir, "model")
    pugna.model_utils.save_model_json(model, outname)
    pugna.model_utils.save_model_h5(model, outname)

    logger.info("saving history")
    outname = os.path.join(args.output_dir, "history.pickle")
    pugna.model_utils.save_history(history.history, outname)

    logger.info("saving duration")
    outname = os.path.join(args.output_dir, "duration.pickle")
    pugna.model_utils.save_datetime(duration, outname)

    last_loss = history.history['loss'][-1]
    logger.info(f"last loss: {last_loss}")
    last_val_loss = history.history['val_loss'][-1]
    logger.info(f"last val_loss: {last_val_loss}")

    if args.use_lrs:
        lr_history = extract_lr_history(history.history)
        if lr_history is None:
            # do not lose the already saved results over a missing plot
            logger.warning(
                "no learning rate found in history; skipping learning rate plot")
        else:
            logger.info("saving learning rate plot")
            save_lr_plot(lr_history, os.path.join(args.output_dir, "lr.png"))

    logger.info("saving loss plot")
    save_loss_plot(history.history, os.path.join(args.output_dir, "loss.png"))


if __name__ == "__main__":
    main()
