#!/usr/bin/env python
"""Fit an MscaleDNN model (pugna) to .npy training/validation data from the CLI."""

import argparse
import datetime
import os
import subprocess
from distutils import util  # NOTE(review): deprecated (PEP 632), removed in Python 3.12

import matplotlib as mpl
# Select the non-interactive backend BEFORE pyplot is imported so the script
# works on headless machines regardless of matplotlib version.
mpl.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import pugna.logger
import pugna.model_utils
import pugna.models.mscalednn

mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update({"font.size": 16})


def plot_1d_prediction(x, y, model, output_dir, tag):
    """Plot data and model prediction against a 1-d input and save to PNG.

    The figure is written to ``output_dir/plot_1d_prediction_<tag>.png``.
    Uses the module-level ``logger`` set up in ``__main__``.
    """
    outname = os.path.join(output_dir, f"plot_1d_prediction_{tag}.png")
    logger.info(f"saving plot_1d_prediction: {outname}")
    prediction = model.predict(x)
    plt.figure(figsize=(14, 7))
    # data as crosses, prediction as hollow circles (no connecting line)
    plt.plot(x, y, 'x', label='data')
    plt.plot(x, prediction, ' o', markerfacecolor='none', label='model')
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.savefig(outname, bbox_inches='tight')
    plt.close()


def plot_prediction(x, y, model, output_dir, tag):
    """Plot data and model prediction against sample index and save to PNG.

    The figure is written to ``output_dir/plot_prediction_<tag>.png``.
    Uses the module-level ``logger`` set up in ``__main__``.
    """
    outname = os.path.join(output_dir, f"plot_prediction_{tag}.png")
    logger.info(f"saving plot_prediction: {outname}")
    prediction = model.predict(x)
    plt.figure(figsize=(14, 7))
    # no x-axis values: both series are plotted against their index
    plt.plot(y, 'x', label='data')
    plt.plot(prediction, ' o', markerfacecolor='none', label='model')
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.savefig(outname, bbox_inches='tight')
    plt.close()


def plot_residual(x, y, model, output_dir, tag):
    """Plot residuals (data minus model prediction) and save to PNG.

    The figure is written to ``output_dir/plot_residuals_<tag>.png``.
    Uses the module-level ``logger`` set up in ``__main__``.
    """
    outname = os.path.join(output_dir, f"plot_residuals_{tag}.png")
    logger.info(f"saving plot_residual: {outname}")
    prediction = model.predict(x)
    residuals = y - prediction
    plt.figure(figsize=(14, 7))
    plt.plot(residuals, 'o')
    plt.savefig(outname, bbox_inches='tight')
    plt.close()


def plot_learning_rate(output_dir, history):
    """Plot the per-epoch learning rate from a Keras History (log y-scale).

    The figure is written to ``output_dir/lr.png``.  Assumes the fit used a
    learning-rate callback so ``history.history`` contains an ``'lr'`` key.
    """
    outname = os.path.join(output_dir, "lr.png")
    logger.info(f"saving learning rate plot: {outname}")
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(history.history['lr'])
    ax.set_yscale('log')
    fig.savefig(outname, bbox_inches='tight')
    plt.close(fig)


def plot_loss(output_dir, history):
    """Plot training and validation loss curves (log y-scale) and save to PNG.

    The figure is written to ``output_dir/loss.png``.
    """
    outname = os.path.join(output_dir, "loss.png")
    logger.info(f"saving loss plot: {outname}")
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(history.history['loss'], label='loss')
    ax.plot(history.history['val_loss'], label='val_loss')
    ax.set_yscale('log')
    ax.legend(loc='center left', fancybox=True,
              framealpha=0., bbox_to_anchor=(1.05, 0.5))
    fig.savefig(outname, bbox_inches='tight')
    plt.close(fig)


def broadcast_inputs_to_length(inputs, desired_length):
    """
    Broadcast a length-1 list to `desired_length` elements.

    If len(inputs) equals desired_length the list is returned unchanged.

    If len(inputs) is 1 (and desired_length > 1) the single value is
    repeated desired_length times.

    Any other length mismatch raises ValueError.

    Args:
        inputs (list): list of length 1 or desired_length
        desired_length (int): target number of elements

    Returns:
        [list]: a list of length 'desired_length' where the value has been
        duplicated.

    Raises:
        ValueError: if len(inputs) is neither 1 nor desired_length
    """
    # NOTE: was `(... ) & (...)` — bitwise AND on bools; use boolean `and`,
    # which also short-circuits and has sane precedence.
    if len(inputs) == 1 and desired_length > 1:
        # repeat the single value desired_length times
        return [inputs[0]] * desired_length
    if len(inputs) != desired_length:
        raise ValueError("length of input list doesn't match desired length")
    return inputs


def convert_strlist_to_bool(strlist):
    """Convert a list of truth-value strings to a list of bools.

    e.g. ['True', 'False', 'False'] -> [True, False, False]

    Accepts the same spellings (case-insensitive) as the now-removed
    distutils.util.strtobool (deprecated by PEP 632, gone in Python 3.12):
    true values are 'y', 'yes', 't', 'true', 'on', '1'; false values are
    'n', 'no', 'f', 'false', 'off', '0'.

    https://stackoverflow.com/a/56574765/12840171

    Args:
        strlist (list[str]): strings to convert

    Returns:
        list[bool]: the converted values

    Raises:
        ValueError: if an element is not a recognised truth value
    """
    true_values = {'y', 'yes', 't', 'true', 'on', '1'}
    false_values = {'n', 'no', 'f', 'false', 'off', '0'}

    def to_bool(value):
        lowered = value.lower()
        if lowered in true_values:
            return True
        if lowered in false_values:
            return False
        raise ValueError(f"invalid truth value {value!r}")

    return [to_bool(s) for s in strlist]


def check_gpu():
    """Log the GPUs visible to TensorFlow and, if present, nvidia-smi's view.

    Best-effort only: a missing nvidia-smi binary is logged, not raised.
    """
    logger.info("running 'tf.config.list_physical_devices('GPU')'")
    gpus = tf.config.list_physical_devices("GPU")
    logger.info(gpus)

    try:
        logger.info("running 'nvidia-smi -L'")
        subprocess.call(["nvidia-smi", "-L"])
    except FileNotFoundError:
        # nvidia-smi not installed / not on PATH — e.g. a CPU-only machine
        logger.info("could not run 'nvidia-smi -L'")


def set_threads():
    """Pin TensorFlow op parallelism to one thread and ensure OMP_NUM_THREADS.

    Sets both inter- and intra-op parallelism to 1, logs the values actually
    in effect, defaults OMP_NUM_THREADS to "1" when unset, and warns if it
    is set to anything other than 1.
    """
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)

    n_inter = tf.config.threading.get_inter_op_parallelism_threads()
    n_intra = tf.config.threading.get_intra_op_parallelism_threads()
    logger.info(
        f"tf using {n_inter} inter_op_parallelism_threads thread(s)"
    )
    logger.info(
        f"tf using {n_intra} intra_op_parallelism_threads thread(s)"
    )

    if "OMP_NUM_THREADS" not in os.environ:
        logger.info("'OMP_NUM_THREADS' not set. Setting it now.")
        os.environ["OMP_NUM_THREADS"] = "1"
    logger.info(f"OMP_NUM_THREADS: {os.environ['OMP_NUM_THREADS']}")

    omp_threads = int(os.environ["OMP_NUM_THREADS"])
    if omp_threads != 1:
        logger.warning(
            f"OMP_NUM_THREADS is not 1! value: {os.environ['OMP_NUM_THREADS']}"
        )


if __name__ == "__main__":
    # ------------------------------------------------------------------ CLI
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--output-dir", type=str,
                        help="output directory", required=True)

    parser.add_argument("--X-data-train", type=str,
                        help="path to X.npy (training)", required=True)
    parser.add_argument("--y-data-train", type=str,
                        help="path to y.npy (training)", required=True)

    parser.add_argument("--X-data-val", type=str,
                        help="path to X.npy (validation)", required=True)
    parser.add_argument("--y-data-val", type=str,
                        help="path to y.npy (validation)", required=True)

    parser.add_argument("-v", "--verbose",
                        help="""
                        increase output verbosity
                        no -v: WARNING
                        -v: INFO
                        -vv: DEBUG
                        """,
                        action='count', default=0)

    # pugna fit NN args
    parser.add_argument("--nlayers", type=int, required=True,
                        help='number of hidden layers')
    parser.add_argument("--units", type=int, required=True,
                        nargs='+', help='list of units for each layer')
    parser.add_argument("--nscales", type=int, required=True,
                        nargs='+', help='list of Mscale scales for each layer')
    parser.add_argument("--scale-names", type=str, required=True, choices=['linear', 'base2'],
                        nargs='+', help='list of Mscale names for each layer')
    parser.add_argument("--activations", type=str, required=True,
                        nargs='+', help='list of activations for each layer')
    parser.add_argument("--dropouts", type=float, required=True,
                        nargs='+', help='list dropout probability for each layer')
    parser.add_argument("--batch-norms", type=str, required=True, nargs='+',
                        help='list of True/False to turn on/off BatchNormalization for each layer')
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of epochs to run for")
    parser.add_argument("--batch-size", type=int,
                        help="mini-batch size. defaults to full dataset")
    parser.add_argument("--learning-rate", type=float, default=0.001,
                        help="learning rate. If using lrs then this is also the initial learning rate.")
    parser.add_argument("--lrs-name", type=str, default=None,
                        choices=["CosineDecayRestarts", "ReduceLROnPlateau"],
                        help="""name of tf.keras learning rate schedular (lrs) to use. If not given then
                        no lrs will be used.""")

    parser.add_argument("--lrs-cosine-first-decay-steps", type=int,
                        help="""See https://www.tensorflow.org/api_docs/python/tf/keras/experimental/CosineDecayRestarts
                        """)

    parser.add_argument("--lrs-reduceLRP-factor", type=float, default=0.8,
                        help="""See https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROnPlateau
                        """)
    parser.add_argument("--lrs-reduceLRP-patience", type=int, default=20,
                        help="""See https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROnPlateau
                        """)
    parser.add_argument("--lrs-reduceLRP-min_lr", type=float, default=1e-5,
                        help="""See https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROnPlateau
                        """)

    parser.add_argument("--adam-amsgrad", action='store_true',
                        help='turn on or off amsgrad for adam optimizer')

    parser.add_argument("--gpu-devices", type=str,
                        help="set CUDA_VISIBLE_DEVICES. e.g. '0' or for multiple gpus '0,1,2,3' ")

    args = parser.parse_args()
    args_dict = vars(args)

    # ------------------------------------------------------------- logging
    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = pugna.logger.init_logger(level=level)
    logger.info("running pugna_fit")
    logger.info(f"pugna version: {pugna.__version__}")
    logger.info(f"verbosity turned on at level: {level}")

    logger.info("==========")
    logger.info("printing command line args")
    for k in args_dict.keys():
        logger.info(f"{k}: {args_dict[k]}")
    logger.info("==========")

    # ----------------------------------------------------------------- GPU
    # BUGFIX: --gpu-devices is optional, so args.gpu_devices may be None.
    # os.environ values must be str; assigning None raised TypeError.
    if args.gpu_devices is not None:
        logger.info("setting CUDA_VISIBLE_DEVICES")
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    else:
        logger.info("--gpu-devices not given: leaving CUDA_VISIBLE_DEVICES unchanged")

    check_gpu()

    # sometimes multithreading can cause things to halt... this might help
    # set_threads()

    # ------------------------------------------------------------ load data
    logger.info(f"making output dir: {args.output_dir}")
    os.makedirs(f"{args.output_dir}", exist_ok=True)

    logger.info(f"loading: {args.X_data_train}")
    X_train = np.load(args.X_data_train)

    logger.info(f"loading: {args.y_data_train}")
    y_train = np.load(args.y_data_train)

    logger.info(f"loading: {args.X_data_val}")
    X_val = np.load(args.X_data_val)

    logger.info(f"loading: {args.y_data_val}")
    y_val = np.load(args.y_data_val)

    # expected 2-D arrays: (n_samples, n_features) / (n_samples, n_outputs)
    logger.info(f"X_train.shape: {X_train.shape}")
    logger.info(f"y_train.shape: {y_train.shape}")
    logger.info(f"X_val.shape: {X_val.shape}")
    logger.info(f"y_val.shape: {y_val.shape}")

    # --------------------------------------------------------- build model
    logger.info("setting up model")

    input_dim = X_train.shape[1]
    output_dim = y_train.shape[1]

    logger.info("converting args.batch_norms string list to list of bools")
    batch_norms = convert_strlist_to_bool(args.batch_norms)

    # each per-layer hyperparameter may be given once and broadcast to all
    # nlayers layers, or given explicitly per layer
    logger.info("broadcast_inputs_to_length: units")
    units = broadcast_inputs_to_length(args.units, args.nlayers)
    logger.info(f"units: {units}")
    logger.info("broadcast_inputs_to_length: nscales")
    nscales = broadcast_inputs_to_length(args.nscales, args.nlayers)
    logger.info(f"nscales: {nscales}")
    logger.info("broadcast_inputs_to_length: activations")
    activations = broadcast_inputs_to_length(args.activations, args.nlayers)
    logger.info(f"activations: {activations}")
    logger.info("broadcast_inputs_to_length: dropouts")
    dropouts = broadcast_inputs_to_length(args.dropouts, args.nlayers)
    logger.info(f"dropouts: {dropouts}")
    logger.info("broadcast_inputs_to_length: batch_norms")
    batch_norms = broadcast_inputs_to_length(batch_norms, args.nlayers)
    logger.info(f"batch_norms: {batch_norms}")
    logger.info("broadcast_inputs_to_length: scale_names")
    scale_names = broadcast_inputs_to_length(args.scale_names, args.nlayers)
    logger.info(f"scale_names: {scale_names}")

    logger.info("calling 'build_model'")
    model = pugna.models.mscalednn.build_model(
        input_dim, output_dim, args.nlayers, units, nscales, activations, dropouts, batch_norms, scale_names)
    logger.info("'build_model' complete")

    # ------------------------------------------------------------- compile
    logger.info("setting up optimizer")
    opt = tf.keras.optimizers.Adam(
        learning_rate=args.learning_rate, amsgrad=args.adam_amsgrad)

    loss = 'mse'
    logger.info(f"loss: {loss}")

    logger.info("compiling model")
    model.compile(loss=loss, optimizer=opt)
    logger.info("compiling model: complete")

    # ----------------------------------------------------------- callbacks
    logger.info("setting up callbacks")
    callbacks = []

    if args.lrs_name:
        logger.info(f"using learning rate schedular: {args.lrs_name}")

        if args.lrs_name == "CosineDecayRestarts":
            # NOTE(review): tf.keras.experimental.CosineDecayRestarts is a
            # deprecated alias of tf.keras.optimizers.schedules.CosineDecayRestarts
            learning_rate_fn = tf.keras.experimental.CosineDecayRestarts(args.learning_rate,
                                                                         args.lrs_cosine_first_decay_steps)
            logger.info("appending LRS to callbacks")
            callbacks.append(
                tf.keras.callbacks.LearningRateScheduler(learning_rate_fn))
        elif args.lrs_name == "ReduceLROnPlateau":
            learning_rate_fn = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                                    factor=args.lrs_reduceLRP_factor,
                                                                    patience=args.lrs_reduceLRP_patience,
                                                                    min_lr=args.lrs_reduceLRP_min_lr)
            logger.info("appending LRS to callbacks")
            callbacks.append(learning_rate_fn)
        else:
            # unreachable given argparse 'choices', kept as a safety net
            raise ValueError(f"lrs_name: {args.lrs_name} not valid")

    if args.batch_size is None:
        logger.info("batch-size is None using entire dataset")
        batch_size = X_train.shape[0]
    else:
        batch_size = args.batch_size

    logger.info(f"batch_size: {batch_size}")

    # ------------------------------------------------------------------ fit
    starttime = datetime.datetime.now()

    logger.info("running model.fit")
    history = model.fit(
        X_train,
        y_train,
        epochs=args.epochs,
        batch_size=batch_size,
        verbose=True,
        callbacks=callbacks,
        validation_data=(X_val, y_val)
    )

    endtime = datetime.datetime.now()
    duration = endtime - starttime

    logger.info("fit complete")
    logger.info(f"The time cost: {duration}")

    # ------------------------------------------------------- save artefacts
    logger.info("saving model")
    outname = os.path.join(args.output_dir, "model")
    pugna.model_utils.save_model_json(model, outname)
    pugna.model_utils.save_model_h5(model, outname)

    logger.info("saving history")
    outname = os.path.join(args.output_dir, "history.pickle")
    pugna.model_utils.save_history(history.history, outname)

    logger.info("saving duration")
    outname = os.path.join(args.output_dir, "duration.pickle")
    pugna.model_utils.save_datetime(duration, outname)

    last_loss = history.history['loss'][-1]
    logger.info(f"last loss: {last_loss}")
    last_val_loss = history.history['val_loss'][-1]
    logger.info(f"last val_loss: {last_val_loss}")

    # ---------------------------------------------------------------- plots
    plot_loss(args.output_dir, history)

    if args.lrs_name:
        # 'lr' key only present in history when a learning-rate callback ran
        plot_learning_rate(args.output_dir, history)

    # 1-d prediction plots only make sense for a single input feature
    if X_train.shape[1] == 1:
        plot_1d_prediction(X_train, y_train, model, args.output_dir, "train")
    plot_prediction(X_train, y_train, model, args.output_dir, "train")
    plot_residual(X_train, y_train, model, args.output_dir, "train")

    # BUGFIX: was gated on X_train.shape[1]; check the validation set itself
    if X_val.shape[1] == 1:
        plot_1d_prediction(X_val, y_val, model, args.output_dir, "val")
    plot_prediction(X_val, y_val, model, args.output_dir, "val")
    plot_residual(X_val, y_val, model, args.output_dir, "val")

    logger.info("done!")
