#!/usr/bin/env python3

# ----------------------------------------------------------------------
# Created: fre jun 18 02:03:33 2021 (+0200)
# Last-Updated:
# Filename: chef
# Author: Yinan Yu
# Description:
# ----------------------------------------------------------------------

import os, sys, glob
import ast
from pprint import pprint
import argparse
from typing import Union

from alex.alex import compare, const, util, dsl_parser
from alex.annotators import code_gen


def diff(configs, to_png, verbose, dpi, mode):
    """Compare two network configurations and print the differing operations.

    Args:
        configs: Pair ``(config1, config2)`` handed to the ``compare`` module.
        to_png: Forwarded to ``compare`` as ``render_to`` when truthy; any
            falsy value is forwarded as ``None``.
        verbose: Whether to print the non-matching operations.  A string is
            first interpreted as a Python literal (``"True"``, ``"False"``,
            ``"0"`` ...); a string that is not a valid literal is treated as
            false only if it is a common "off" spelling such as ``"false"``
            or ``"no"`` (case-insensitive).
        dpi: Forwarded unchanged to ``compare`` as ``dpi``.
        mode: ``"diff"`` to call ``compare.diff`` or ``"dist"`` to call
            ``compare.dist``.

    Raises:
        ValueError: If ``mode`` is neither ``"diff"`` nor ``"dist"``.
    """
    # Fail fast on a bad mode, before any output or comparison work happens.
    if mode not in ("diff", "dist"):
        raise ValueError("Mode %s not recognized" % mode)

    path = to_png if to_png else None

    if isinstance(verbose, str):
        try:
            verbose = ast.literal_eval(verbose)
        except (ValueError, SyntaxError):
            # Spellings like "false" or "yes" are not Python literals; do not
            # crash on them, just decide whether they mean "off".
            verbose = verbose.strip().lower() not in ("false", "no", "off", "n", "f", "")

    config1, config2 = configs
    print("Comparing %s and %s" % (config1, config2))

    compare_fn = compare.diff if mode == "diff" else compare.dist
    # The first element of the result (the cost) is not used here.
    _, operations = compare_fn(config1, config2, render_to=path, dpi=dpi)

    if verbose:
        for operation in operations:
            # Only operations whose tag is not "MATCH" are reported.
            if operation[1][0] != "MATCH":
                print(operation)
        print("Set --verbose False to avoid output")


def ls(option):
    """Pretty-print one of the listings known to Alex.

    Args:
        option: Name of the listing.  ``"components"`` and ``"recipes"``
            print the corresponding ``const`` sequence without its ``"root"``
            entry; ``"ingredients"``, ``"initializers"``, ``"regularizers"``,
            ``"optimizers"`` and ``"losses"`` print the keys of the
            corresponding ``const`` mapping; ``"functions"`` prints a fixed
            header line.  Any other value prints the list of valid options
            instead of silently doing nothing.
    """
    # Listings read from a sequence on ``const``; the "root" entry is hidden.
    sequence_sources = {
        "components": "ALL_COMPONENTS",
        "recipes": "ALL_RECIPES",
    }
    # Listings read from a mapping on ``const``; only the keys are shown.
    mapping_sources = {
        "ingredients": "INGREDIENT_TYPES",
        "initializers": "ALL_INITIALIZERS",
        "regularizers": "REGULARIZERS",
        "optimizers": "OPTIMIZERS",
        "losses": "LOSSES",
    }

    # ``const`` attributes are looked up lazily so that only the requested
    # listing is ever touched.
    if option in sequence_sources:
        entries = getattr(const, sequence_sources[option])
        pprint([entry for entry in entries if entry != "root"])
    elif option in mapping_sources:
        pprint(list(getattr(const, mapping_sources[option]).keys()))
    elif option == "functions":
        print("type", "name", "inputs", "repeat", "hyperparams", "visible")
    else:
        valid = ["functions"] + list(sequence_sources) + list(mapping_sources)
        print("Unknown option %s. Choose one of: %s" % (option, ", ".join(valid)))


def inspect(component):
    """Print the hyperparameters stored for ``component``.

    The file ``<component>.yml`` under ``const.COMPONENT_BASE_PATH`` is read
    with ``util.read_yaml`` and pretty-printed when it exists.  Otherwise a
    message is printed that depends on whether ``component`` is listed in
    ``const.ALL_COMPONENTS``.
    """
    yml_path = os.path.join(const.COMPONENT_BASE_PATH, component + ".yml")

    if os.path.exists(yml_path):
        pprint(util.read_yaml(yml_path))
        return

    # No YAML file: either a known component without hyperparameters, or a
    # name Alex does not know at all.
    if component in const.ALL_COMPONENTS:
        message = "Ingredient %s has no hyperparameter! Alex loves those!" % component
    else:
        message = ("Oh no ingredient %s does not exist in Alex yet!\n"
                   "Help us improve please?" % component)
    print(message)


def codegen(engine, config, out_dir, filename,
            ckpt_from=None, ckpt_to=None):
    """Generate Python source code for a network configuration.

    Args:
        engine: Forwarded to ``code_gen.generate_python`` as ``engine``.
        config: Sequence whose first element is the network configuration
            (argparse supplies a one-element list).
        out_dir: Directory of the generated file; also forwarded as
            ``dirname``.
        filename: Name of the generated file inside ``out_dir``.
        ckpt_from: Optional checkpoint path.  It is split at its last ``"/"``
            into ``[directory, name]`` and forwarded as ``load_ckpt``.
        ckpt_to: Optional checkpoint path, split the same way and forwarded
            as ``save_ckpt``.

    Exits with status 1 when ``config`` is empty, so that calling shells can
    detect the usage error.  Errors raised while generating are printed and
    not re-raised.
    """

    def _split_ckpt(ckpt_path):
        """Return ``[directory, name]`` split at the last "/" (None stays None)."""
        if ckpt_path is None:
            return None
        directory, _, name = ckpt_path.rpartition("/")
        return [directory, name]

    if not config:
        print("Must give a network configuration")
        sys.exit(1)
    config = config[0]

    try:
        ckpt_from = _split_ckpt(ckpt_from)
        ckpt_to = _split_ckpt(ckpt_to)

        filepath = os.path.join(out_dir, filename)
        code_gen.generate_python(filepath,
                                 config,
                                 engine=engine,
                                 dirname=out_dir,
                                 load_ckpt=ckpt_from,
                                 save_ckpt=ckpt_to)
        print("Generated code is written to file %s" % filepath)
    except Exception as err:
        # Deliberately broad: this is the command-line boundary, where a
        # failure is reported as a message rather than a traceback.
        print(err)


def render(config, type, path, level):
    """Render a network configuration to an image file.

    Args:
        config: Sequence whose first element is the network configuration
            (argparse supplies a one-element list).
        type: ``"ast"`` to call ``dsl_parser.make_ast_from_yml`` or
            ``"graph"`` to call ``dsl_parser.make_graph_from_yml``.
        path: Output path handed to the renderer.
        level: Forwarded to ``make_graph_from_yml``; unused for ``"ast"``.

    Raises:
        ValueError: If ``type`` is neither ``"ast"`` nor ``"graph"``.

    Exits with status 1 when ``config`` is empty, so that calling shells can
    detect the usage error.
    """
    if not config:
        print("Must give a network configuration")
        sys.exit(1)
    config = config[0]

    if type == "ast":
        dsl_parser.make_ast_from_yml(config, path)
    elif type == "graph":
        dsl_parser.make_graph_from_yml(config, path, level)
    else:
        # Nothing was rendered, so do not fall through to the success message.
        raise ValueError("Type %s not recognized" % type)

    print("Image written to file %s" % path)


def merge(config, ckpt):
    """Placeholder: accepts a configuration and a checkpoint, does nothing yet."""
    return None


def main(fn, kwargs=None):
    """Call the sub-command handler registered under ``fn``.

    Args:
        fn: Key into the module-level ``fns`` table, which is created in the
            ``__main__`` section of this file.
        kwargs: Keyword arguments passed to the handler; ``None`` means no
            arguments.
    """
    # A None sentinel avoids sharing one mutable dict between calls.
    if kwargs is None:
        kwargs = {}
    fns[fn](**kwargs)


if __name__ == "__main__":

    # Without a sub-command there is nothing to dispatch to: print the hint
    # and exit with a non-zero status so calling shells see the usage error.
    if len(sys.argv) < 2:
        print("Choose one of the following: diff, ls, inspect, codegen, render; or -h for help")
        sys.exit(1)

    # The first command-line token selects the handler; argparse validates it
    # below (and handles -h) before main() is reached.
    fn = sys.argv[1]
    fns = {"diff": diff,
           "ls": ls,
           "inspect": inspect,
           "codegen": codegen,
           "render": render}

    parser = argparse.ArgumentParser(description="Alex network analyzer")
    subparsers = parser.add_subparsers(help="")

    # ---- diff ---------------------------------------------------------
    diff_parser = subparsers.add_parser("diff",
                                        help="Change log between two networks, e.g. alex-nn diff example_config_1.yml example_config_2.yml")
    diff_parser.add_argument("configs", metavar="Networks to compare", type=str, nargs=2,
                             help="The original and modified network configurations")

    # NOTE(review): the default here is the boolean True, which diff()
    # forwards as ``render_to`` -- confirm that ``compare`` accepts that.
    diff_parser.add_argument("--to_png", metavar="Save diff to a png file",
                             type=str, nargs="?", default=True,
                             help="Save diff to the specified png file")

    # With nargs="?" and no ``const``, a bare ``--verbose`` yields None.
    diff_parser.add_argument("--verbose", metavar="Print the diffs (True/False)",
                             type=str, nargs="?", default=True,
                             help="Show the diffs")

    diff_parser.add_argument("--mode", metavar="diff/dist",
                             type=str, nargs="?", default="diff",
                             help="Difference or distance")

    diff_parser.add_argument("--dpi", metavar="dpi of the image",
                             default=800,
                             type=int, nargs="?",
                             help="Resolution of the image")

    # ---- ls -----------------------------------------------------------
    ls_parser = subparsers.add_parser("ls", help="List information, e.g. alex-nn ls [functions, components, recipes, ingredients, initializers, regularizers, optimizers, losses]")
    ls_parser.add_argument("option", metavar="option", type=str, nargs="?",
                           default="components",
                           help="What to ls?")

    # ---- inspect ------------------------------------------------------
    inspect_parser = subparsers.add_parser("inspect", help="Inspect hyperparameters of an ingredient, e.g. alex-nn inspect conv")
    inspect_parser.add_argument("component", metavar="choose a component to inspect, e.g. conv, relu, etc",
                                type=str,
                                help="Which component to inspect?")

    # ---- codegen ------------------------------------------------------
    codegen_parser = subparsers.add_parser("codegen", help="Generate python code, e.g. alex-nn codegen example_config.yml")

    # nargs=1 makes ``config`` a one-element list; codegen() unwraps it.
    codegen_parser.add_argument("config",
                                metavar="Network configuration file path",
                                type=str,
                                nargs=1,
                                help="Network configuration")

    codegen_parser.add_argument("--engine",
                                default="pytorch",
                                metavar="Currently support tf and pytorch",
                                type=str,
                                help="Which framework?")

    codegen_parser.add_argument("--out_dir", metavar="Output dir", default="./",
                                type=str,
                                help="Python file will be written to this dir")

    codegen_parser.add_argument("--filename", metavar="Output file name",
                                type=str,
                                default="generated.py",
                                help="Python file name (not the full path)")

    codegen_parser.add_argument("--ckpt_from", metavar="Checkpoint path to load from",
                                type=str,
                                help="Load from checkpoint")

    codegen_parser.add_argument("--ckpt_to", metavar="Checkpoint path to save to",
                                type=str,
                                help="Save to checkpoint")

    # ---- render -------------------------------------------------------
    render_parser = subparsers.add_parser("render",
                                          help="Render ast or graph from a config")

    # nargs=1 makes ``config`` a one-element list; render() unwraps it.
    render_parser.add_argument("config",
                               metavar="Network configuration file path",
                               nargs=1,
                               type=str,
                               help="Network configuration")

    render_parser.add_argument("--type", metavar="ast/graph",
                               type=str,
                               nargs="?",
                               default="ast",
                               help="What to render?")

    render_parser.add_argument("--path", metavar="Path",
                               type=str, nargs="?", default="./network.png",
                               help="Path where the png file goes")

    render_parser.add_argument("--level", metavar="Level of the graph",
                               type=int, nargs="?", default=2,
                               help="For graph you can hide the details of the recipe")

    # The parsed namespace is passed to the handler as keyword arguments, so
    # every argument name above must match a parameter of its handler.
    args = parser.parse_args()
    main(fn, vars(args))
