import argparse
import os
import os.path as osp
import shlex
import subprocess
import sys

import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer

from ..tools import Logger, get_time_str
from ..tools.system_info import get_cpu_info, get_gpu_info, get_package_info
from .config import (get_checkpoint_dir, get_log_dir, instantiate_from_config,
                     load_config, save_config)
from .dataset import LightningDataset
from .model import LightningModel
from ..files import save_json


def get_parameters_num(model) -> dict:
    """Count the parameter elements of ``model``.

    Args:
        model: any object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        dict: ``{'Total': all elements, 'Trainable': elements with requires_grad}``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}


def export_model_parameters(model: nn.Module) -> dict:
    """Build a parameter-count report for ``model`` and its direct submodules.

    Every entry maps a name to ``{'Total', 'Trainable', 'Rate(Trainable)'}``,
    where ``Rate(Trainable)`` is the entry's trainable count as a percentage of
    the whole model's trainable count (formatted like ``"12.34%"``).

    Args:
        model (nn.Module): model to inspect.

    Returns:
        dict: one entry per direct child that owns parameters (alphabetical
        order), followed by the whole model under the key ``"[ALL]"``.
    """
    parameters_dict = dict()
    # Only registered submodules are inspected. Probing every attribute via
    # dir()/getattr would evaluate arbitrary properties, which may raise or
    # have side effects. Sorting by name keeps the alphabetical key order.
    for name, child in sorted(model.named_children(), key=lambda item: item[0]):
        counts = get_parameters_num(child)
        if counts["Total"] > 0:
            parameters_dict[name] = counts
    parameters_dict["[ALL]"] = get_parameters_num(model)

    all_trainable = parameters_dict["[ALL]"]["Trainable"]
    for counts in parameters_dict.values():
        # A fully frozen model has no trainable parameters: report 0% instead
        # of raising ZeroDivisionError.
        if all_trainable:
            rate = counts["Trainable"] / all_trainable * 100.0
        else:
            rate = 0.0
        counts["Rate(Trainable)"] = f"{rate:.2f}%"
    return parameters_dict


def _export_conda_env(filename: str) -> None:
    """Best-effort dump of ``conda env export`` output into ``filename``.

    Failures (conda not installed, non-zero exit status) are logged as warnings
    and otherwise ignored, so environment bookkeeping never blocks training.
    """
    try:
        with open(filename, "w", encoding="utf-8") as fw:
            # List form: no shell involved, so spaces or metacharacters in the
            # path cannot break or inject into the command.
            result = subprocess.run(["conda", "env", "export"], stdout=fw, check=False)
    except OSError as e:
        Logger.warn(f"conda env export failed: {e}")
        return
    if result.returncode != 0:
        Logger.warn(f"conda env export exited with code {result.returncode}")


def train(config_filename: str, ckpt_path: str) -> None:
    """Launch training from a config file.

    Args:
        config_filename (str): path to the config file. The part of its
            absolute path before ``"configs"`` is treated as the project root
            and prepended to ``sys.path``.
        ckpt_path (str): checkpoint to resume from; a falsy value is passed to
            Lightning as ``None``.

    The run is aborted (after logging an error) if the log directory derived
    from the config already exists.
    """
    root_path = osp.abspath(config_filename).split("configs")[0]
    sys.path.insert(0, root_path)
    Logger.info(f"sys.path.insert(0, {root_path})")

    config = load_config(config_filename)

    # Every run gets a fresh log directory; never reuse an existing one.
    log_dir = get_log_dir(config_filename)
    if osp.exists(log_dir):
        Logger.warn(f"exists: {log_dir}")
        Logger.error("please create a new config!")
        return
    os.makedirs(log_dir, exist_ok=True)
    Logger.info(f"mkdir: {log_dir}")

    # Save the environment next to the other run artefacts. The file was
    # previously joined onto root_path, which pointed outside the created
    # directory whenever log_dir is relative and the CWD is not the root.
    _export_conda_env(osp.join(log_dir, "environments.yaml"))

    Logger.logfile(osp.join(log_dir, "train.log"), clear=True)

    Logger.info(f"task: {config.task}")

    # Snapshot of the config used for this run.
    new_config = osp.join(log_dir, "config.yaml")
    save_config(new_config, config)

    # Experiment metadata: key hyper-parameters plus machine/package info.
    exp_info = {
        "machine": config.machine,
        "time": get_time_str(),
        "seed": config.seed,
        "config": config_filename,
        "precision": config.precision,
        "batch_size": config.data.train.data_loader.batch_size,
    }
    exp_info.update(get_gpu_info())
    exp_info.update(get_cpu_info())
    exp_info.update(get_package_info())
    save_json(osp.join(log_dir, "exp_info.json"), exp_info)

    # Seed and log the main run settings.
    seed_everything(config.seed)
    Logger.info(f"seed:{config.seed}")
    Logger.info(f"strategy:{config.strategy}")
    Logger.info(f"accelerator:{config.accelerator}")
    Logger.info(f"devices:{config.devices}")
    Logger.info(f"precision:{config.precision}")
    Logger.info(f"batch_size:{config.data.train.data_loader.batch_size}")
    Logger.info(f"num_workers:{config.data.train.data_loader.num_workers}")
    Logger.info(f"save_top_k: {config.checkpoint.save_top_k}")
    Logger.info(f"save_last: {config.checkpoint.save_last}")

    # log_dir = <save_dir>/<task_name>/<version>; TensorBoard takes the parts.
    path, version = osp.split(log_dir)
    save_dir, task_name = osp.split(path)

    trainer = Trainer(
        precision=config.precision,
        accelerator=config.accelerator,
        devices=config.devices,
        sync_batchnorm=True,
        check_val_every_n_epoch=config.check_val_every_n_epoch,
        accumulate_grad_batches=config.accumulate_grad_batches,
        max_epochs=config.max_epochs,
        logger=[
            pl.loggers.TensorBoardLogger(
                save_dir=save_dir,
                name=task_name,
                version=version,
            ),
        ],
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            ModelCheckpoint(
                dirpath=get_checkpoint_dir(config_filename),
                monitor="validation_loss",
                # NOTE(review): the filename formats "eval_loss" while the
                # monitored metric is "validation_loss"; confirm the model
                # actually logs "eval_loss", otherwise the value in the
                # filename will not reflect the monitored loss.
                filename="{epoch:03d}_{eval_loss:.5f}",
                save_last=config.checkpoint.save_last,
                mode=config.checkpoint.mode,
                save_top_k=config.checkpoint.save_top_k,
                save_weights_only=config.checkpoint.save_weights_only,
            ),
        ],
        strategy=config.strategy,
        log_every_n_steps=config.log_every_n_steps,
        profiler="simple",
    )

    Logger.info("prepare to build model...")
    model = instantiate_from_config(config.model)

    # Save the per-submodule parameter-count report.
    save_json(osp.join(log_dir, "parameters.json"), export_model_parameters(model))

    lightning_model = LightningModel(
        model=model, optimizer_cfg=config.optimizer, scheduler_cfg=config.scheduler, log_dir=log_dir)

    Logger.info("prepare to build dataset...")
    data = LightningDataset(**config.data)

    Logger.info("prepare to train...")
    Logger.warn(
        f"[Please Run] tensorboard --logdir={osp.join(save_dir, task_name)}")

    # cProfile support is deliberately disabled; this hard switch overrides
    # config.cProfile_enable.
    cProfile_enable = False
    if cProfile_enable and config.cProfile_enable:
        import cProfile

        Logger.info(f"cProfile_enable: {config.cProfile_enable}")
        if config.max_epochs != 1:
            Logger.warn(
                f"Proposed max_epochs=1, now max_epochs={config.max_epochs}")
        profiler = cProfile.Profile()
        profiler.enable()

    if ckpt_path:
        Logger.warn(f"resume from:{ckpt_path}")
    # Pass None (not "") when no checkpoint was given: None is Lightning's
    # documented "start from scratch" value.
    trainer.fit(lightning_model, data, ckpt_path=ckpt_path or None)

    if cProfile_enable and config.cProfile_enable:
        profiler.disable()
        filename = osp.join(log_dir, "out.prof")
        profiler.dump_stats(filename)
        Logger.info(f"save: {filename}")
        Logger.warn(f"Please Run: snakeviz {filename}")
        # Previously replace(".prof", "svg") produced "outsvg" (missing dot).
        svg = osp.splitext(filename)[0] + ".svg"
        Logger.warn(f"Please Run: flameprof {filename} > {svg}")


def validate(config_filename: str, ckpt_path: str, mode: str, batch_size: int = -1) -> None:
    """Evaluate a model on the validation or the test split.

    Args:
        config_filename (str): path to the config file. The part of its
            absolute path before ``"configs"`` is prepended to ``sys.path``.
        ckpt_path (str): checkpoint to load; a falsy value is passed to
            Lightning as ``None`` (a warning is logged).
        mode (str): ``"test"`` runs ``trainer.test``; any other value runs
            ``trainer.validate``. Also names the log file ``<mode>.log``.
        batch_size (int): when > 0, overrides the data-loader batch size of
            the split that is evaluated.
    """
    root_path = osp.abspath(config_filename).split("configs")[0]
    sys.path.insert(0, root_path)
    Logger.info(f"sys.path.insert(0, {root_path})")

    config = load_config(config_filename)

    # Evaluation reuses (or creates) the run's log directory.
    log_dir = get_log_dir(config_filename)
    os.makedirs(log_dir, exist_ok=True)
    Logger.logfile(osp.join(log_dir, f"{mode}.log"), clear=True)

    Logger.info(f"config: {config_filename}")

    seed_everything(config.seed)
    Logger.info(f"seed:{config.seed}")

    trainer = Trainer(
        precision=config.precision,
        accelerator=config.accelerator,
        devices=config.devices,
        sync_batchnorm=True,
        logger=False,
    )

    Logger.info("prepare to build model...")
    model = instantiate_from_config(config.model)
    lightning_model = LightningModel(
        model=model, optimizer_cfg=config.optimizer, scheduler_cfg=config.scheduler, log_dir=log_dir)

    Logger.info("prepare to build dataset...")
    if batch_size > 0:
        # Override the batch size of the split that is actually evaluated.
        # (The validation branch previously wrote into the *test* loader.)
        if mode == "test":
            old_batch_size = config.data.test.data_loader.batch_size
            config.data.test.data_loader.batch_size = batch_size
        else:
            old_batch_size = config.data.validation.data_loader.batch_size
            config.data.validation.data_loader.batch_size = batch_size
        if batch_size != old_batch_size:
            Logger.warn(f"change batch size: {old_batch_size} to {batch_size}")
    data = LightningDataset(**config.data)

    if ckpt_path:
        Logger.info(f"load ckpt from:{ckpt_path}")
    else:
        Logger.warn("no checkpoint file")
    # None (not "") is Lightning's documented "no checkpoint" value.
    ckpt = ckpt_path or None
    if mode == "test":
        trainer.test(lightning_model, data, ckpt_path=ckpt)
    else:
        trainer.validate(lightning_model, data, ckpt_path=ckpt)


def run_main() -> None:
    """Command-line entry point: dispatch to train, validate (val) or test.

    Exits with status 1 when the config or checkpoint file does not exist, or
    when ``--mode`` is not recognised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", default="config.yaml",
                        type=str, help="config filename")
    parser.add_argument("--mode", default="train", type=str,
                        help="train or validate(val) or test")
    parser.add_argument("--ckpt", default="", type=str, help="checkpoint path")
    parser.add_argument("--bs", default=-1, type=int,
                        help="batch size(only for validate(val) or test)")

    args = parser.parse_args()

    # Error paths exit non-zero so shells and schedulers see the failure
    # (they previously exited with status 0, i.e. "success").
    if not osp.exists(args.config):
        Logger.error(f"not exists: {args.config}")
        sys.exit(1)

    mode = args.mode
    ckpt_path = args.ckpt
    batch_size = args.bs
    if ckpt_path and not osp.exists(ckpt_path):
        Logger.error(f"not exists: {ckpt_path}")
        sys.exit(1)

    if mode == "train":
        Logger.info("[mode]: train")
        train(args.config, ckpt_path)
    elif mode in ("validate", "val"):
        Logger.info("[mode]: validate")
        validate(args.config, ckpt_path,
                 mode="validate", batch_size=batch_size)
    elif mode == "test":
        Logger.info("[mode]: test")
        validate(args.config, ckpt_path, mode="test", batch_size=batch_size)
    else:
        Logger.error("bad mode")
        sys.exit(1)


def run_hpc() -> None:
    """Run on the HUST (Huazhong University of Science and Technology) cluster.

    Parses the same CLI arguments as ``run_main``. It writes a SLURM job script
    ``<hpc.job_name>.job`` to the current directory and submits it with
    ``sbatch``. The script ``cd``s into ``hpc.path`` and invokes the ``run``
    command with those arguments. Exits with status 1 if the config file does
    not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", default="config.yaml",
                        type=str, help="config filename")
    parser.add_argument("--mode", default="train", type=str,
                        help="train or validate(val) or test")
    parser.add_argument("--ckpt", default="", type=str, help="checkpoint path")
    parser.add_argument("--bs", default=-1, type=int,
                        help="batch size(only for validate(val) or test)")

    args = parser.parse_args()

    if not osp.exists(args.config):
        Logger.error(f"not exists: {args.config}")
        sys.exit(1)

    config_filename = args.config
    config = load_config(config_filename)
    hpc_cfg = config["hpc"]
    devices = config["devices"]
    # ``devices`` may be a plain GPU count or a list of GPU ids; len() on an
    # int previously raised TypeError.
    gpus = devices if isinstance(devices, int) else len(devices)

    # Quote every shell-interpolated value. In particular an empty --ckpt used
    # to render as "--ckpt  --bs -1", which argparse rejects in the job.
    work_dir = shlex.quote(str(hpc_cfg.path))
    run_cmd = " ".join([
        "run", shlex.quote(config_filename),
        "--mode", shlex.quote(args.mode),
        "--ckpt", shlex.quote(args.ckpt),
        "--bs", str(args.bs),
    ])

    text = f"""#!/bin/bash

#SBATCH --job-name={hpc_cfg.job_name}
#SBATCH --nodes=1
#SBATCH --ntasks={hpc_cfg.cpus}
#SBATCH --gres=gpu:{gpus}
#SBATCH --comment={hpc_cfg.comment}
#SBATCH --partition=gpu

export LD_LIBRARY_PATH=/usr/local/nvidia/lib
module load gcc-7.5.0-gcc-4.8.5-of6wn6o

cd {work_dir}
{run_cmd}

"""
    job_filename = f"{hpc_cfg.job_name}.job"
    with open(job_filename, "w", encoding="utf-8") as fw:
        fw.write(text)

    # List form avoids shell interpretation of the job filename.
    try:
        subprocess.run(["sbatch", job_filename], check=False)
    except FileNotFoundError:
        Logger.error(f"sbatch not found; job script written to {job_filename}")
