import os
import os.path as osp
from typing import Dict, List

import pytorch_lightning as pl
import torch  # do not remove
import torch.nn as nn
from torchvision.utils import make_grid

from ..tools.logger import Logger


class BasePipeline(nn.Module):
    """Base class for a model pipeline driven by ``LightningModel``.

    Subclasses implement ``run_train`` / ``run_validation`` / ``run_test``; each is
    unpacked by ``LightningModel`` as a 3-tuple ``(total_loss, loss_dict, extra)``.
    The ``*_epoch_end`` hooks are optional and do nothing by default.

    The counters and ``_log_dir`` are written from outside (``LightningModel``
    updates them); ``epoch()``, ``train_step()``, ``validation_step()`` and
    ``log_dir()`` let subclasses read them.
    """

    def __init__(self):
        super().__init__()
        self._epoch = 0  # current epoch index, set by the training wrapper
        self._log_dir = ""  # root output directory; empty until assigned
        self._train_step = 0  # number of training batches processed so far
        self._validation_step = 0  # number of validation batches processed so far

    def epoch(self) -> int:
        """Return the current epoch index."""
        return self._epoch

    def train_step(self) -> int:
        """Return the number of training batches processed so far."""
        return self._train_step

    def validation_step(self) -> int:
        """Return the number of validation batches processed so far."""
        return self._validation_step

    def log_dir(self, name: str = "") -> str:
        """Return ``<_log_dir>/<name>``, creating the directory if needed.

        Args:
            name: Optional sub-directory under the log directory.

        Returns:
            The joined path. It is ``""`` (i.e. the current working directory)
            when both the log directory and ``name`` are empty.
        """
        path = osp.join(self._log_dir, name)
        # os.makedirs("") raises FileNotFoundError; an empty path denotes the
        # current working directory, which already exists.
        if path:
            os.makedirs(path, exist_ok=True)
        return path

    def run_train(self, batch):
        """Process one training batch; must return ``(total_loss, loss_dict, extra)``."""
        raise NotImplementedError("not implemented!")

    def run_validation(self, batch):
        """Process one validation batch; must return ``(total_loss, loss_dict, extra)``."""
        raise NotImplementedError("not implemented!")

    def run_test(self, batch):
        """Process one test batch; must return ``(total_loss, loss_dict, extra)``."""
        raise NotImplementedError("not implemented!")

    def training_epoch_end(self):
        """Hook invoked after each training epoch; no-op by default."""
        pass

    def validation_epoch_end(self):
        """Hook invoked after each validation epoch; no-op by default."""
        pass

    def test_epoch_end(self):
        """Hook invoked after each test epoch; no-op by default."""
        pass


class LightningModel(pl.LightningModule):
    """PyTorch Lightning wrapper around a :class:`BasePipeline`.

    Forwards each train/validation/test batch to the pipeline's ``run_*`` method,
    logs every entry of the returned loss dict, and at the end of validation and
    test epochs writes the batch-size-weighted mean of each loss to ``Logger``.
    """

    def __init__(self, model: BasePipeline, optimizer_cfg, scheduler_cfg, log_dir: str) -> None:
        """Store the pipeline and its optimizer/scheduler configs.

        Args:
            model: Pipeline whose ``run_train`` / ``run_validation`` / ``run_test``
                methods do the per-batch work.
            optimizer_cfg: Object with ``target`` (name of a class in ``torch.optim``)
                and ``params`` (keyword arguments passed to that class).
            scheduler_cfg: Object with ``target`` (name of a class in
                ``torch.optim.lr_scheduler``) and ``params`` (keyword arguments).
            log_dir: Output directory; stored on ``model._log_dir``.
        """
        super().__init__()

        self.model: BasePipeline = model
        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg
        self.model._log_dir = log_dir

    @staticmethod
    def _resolve_class(root, target: str):
        """Return attribute ``target`` (dotted names allowed) looked up on ``root``.

        Used instead of ``eval`` on the config string, so a config value can only
        select an existing attribute and can never execute arbitrary code. An
        unknown name raises ``AttributeError``.
        """
        obj = root
        for part in target.split("."):
            obj = getattr(obj, part)
        return obj

    @staticmethod
    def _batch_size(batch) -> int:
        """Return ``shape[0]`` of the first value in a dict-style batch."""
        # assumes all batch entries share the same leading (batch) dimension — TODO confirm
        first_key = list(batch.keys())[0]
        return batch[first_key].shape[0]

    def configure_optimizers(self):
        """Configure the optimizer and learning-rate scheduler from the configs."""

        # optimizer: only parameters that require gradients are handed to it
        optimizer_name = f"torch.optim.{self.optimizer_cfg.target}"
        Logger.info(f"optimizer:{optimizer_name}")
        optimizer_cls = self._resolve_class(torch.optim, self.optimizer_cfg.target)
        optimizer = optimizer_cls(
            [p for p in self.model.parameters() if p.requires_grad],
            **self.optimizer_cfg.params,
        )

        # scheduler
        scheduler_name = f"torch.optim.lr_scheduler.{self.scheduler_cfg.target}"
        Logger.info(f"scheduler:{scheduler_name}")
        scheduler_cls = self._resolve_class(torch.optim.lr_scheduler, self.scheduler_cfg.target)
        scheduler = scheduler_cls(optimizer, **self.scheduler_cfg.params)

        # Lightning's single-optimizer form: a plain dict (not wrapped in a tuple).
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    # def log_img(self, img_dict_list: List[Dict]):
    #     """tensorboard image"""
    #     self.log_img_cnt += 1

    #     if img_dict_list is not None:
    #         tensorboard = self.logger.experiment[0]
    #         for img_dict in img_dict_list:
    #             name = img_dict["name"]
    #             img = img_dict["img"]
    #             nrow = img_dict["nrow"] if "nrow" in img_dict else 8
    #             tensorboard.add_image(name, make_grid(img, nrow=nrow), self.log_img_cnt)

    def training_step(self, batch, batch_idx):
        """Run one training batch through the pipeline and return its total loss."""
        # keep the pipeline's counters in sync so it can read epoch()/train_step()
        self.model._epoch = self.current_epoch
        self.model._train_step += 1

        total_loss, loss_dict, _ = self.model.run_train(batch)

        for name, value in loss_dict.items():
            # per-step value and per-epoch aggregate of the same loss term
            self.log(f"train/{name}", value, on_step=True, on_epoch=False, sync_dist=True)
            self.log(f"train_epoch/{name}", value, on_step=False, on_epoch=True, sync_dist=True)

        return total_loss

    def training_epoch_end(self, loss_list: List[dict]) -> None:
        """Forward the end-of-epoch event to the pipeline's hook."""
        if isinstance(self.model, BasePipeline):
            self.model.training_epoch_end()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log its losses, and return them for aggregation.

        The returned dict is the pipeline's ``loss_dict`` with an extra
        ``"_batch_size"`` entry used to weight the epoch-level mean.
        """
        self.model._validation_step += 1

        batch_size = self._batch_size(batch)
        total_loss, loss_dict, _ = self.model.run_validation(batch)

        for name, value in loss_dict.items():
            self.log(f"validation/{name}", value, on_epoch=True, sync_dist=True)

        self.log("validation_loss", total_loss, prog_bar=True, sync_dist=True)
        loss_dict["_batch_size"] = batch_size
        return loss_dict

    def validation_epoch_end(self, loss_list: List[dict]):
        """Call the pipeline hook, then log the weighted mean validation losses."""
        if isinstance(self.model, BasePipeline):
            self.model.validation_epoch_end()
        self._log_loss_summary(loss_list)

    def test_step(self, batch, batch_idx):
        """Run one test batch, log its losses, and return them for aggregation.

        The returned dict is the pipeline's ``loss_dict`` with an extra
        ``"_batch_size"`` entry used to weight the epoch-level mean.
        """
        batch_size = self._batch_size(batch)
        total_loss, loss_dict, _ = self.model.run_test(batch)

        for name, value in loss_dict.items():
            self.log(f"test/{name}", value, on_epoch=True, sync_dist=True)

        self.log("test_loss", total_loss, prog_bar=True, sync_dist=True)
        loss_dict["_batch_size"] = batch_size
        return loss_dict

    def test_epoch_end(self, loss_list: List[dict]):
        """Call the pipeline hook, then log the weighted mean test losses."""
        if isinstance(self.model, BasePipeline):
            self.model.test_epoch_end()
        self._log_loss_summary(loss_list)

    def _log_loss_summary(self, loss_list: List[dict]) -> None:
        """Write ``epoch=NNN, name=mean, ...`` to ``Logger`` for an epoch's outputs.

        Each entry of ``loss_list`` is a dict returned by ``validation_step`` /
        ``test_step``; every loss is weighted by that entry's ``"_batch_size"``.
        NOTE(review): under multi-process training this presumably covers only the
        local process's batches — confirm if a global summary is needed.
        """
        totals = {}
        total_cnt = 0
        for loss_dict in loss_list:
            batch_size = loss_dict["_batch_size"]
            total_cnt += batch_size
            for name, loss in loss_dict.items():
                if name == "_batch_size":
                    continue
                totals[name] = totals.get(name, 0) + loss * batch_size

        parts = [f"epoch={self.current_epoch:03d}"]
        parts.extend(f"{name}={total / total_cnt:.5f}" for name, total in totals.items())
        Logger.info(", ".join(parts))
