from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
import torch.utils

from roheboam.engine.core.data import Data
from roheboam.engine.vision.utils.image import imread_grayscale, imread_rgb, plot_mask_on_image, to_uint8_image


class ImageSegmentationSample:
    """
    One image-segmentation sample: an image with an optional mask and an
    optional weight map, identified by a name.

    ``image``, ``mask`` and ``weight_map`` are objects exposing a ``.data``
    attribute (``create`` builds them as ``Data`` instances). ``mask`` and
    ``weight_map`` are ``None`` when the sample does not have them.

    :param image: wrapper holding the image
    :param mask: wrapper holding the mask, or ``None``
    :param weight_map: wrapper holding the weight map, or ``None``
    :param name: sample identifier; a random UUID string is used when ``None``
    """

    def __init__(self, image, mask=None, weight_map=None, name=None):
        self.image = image
        self.mask = mask
        self.weight_map = weight_map
        self.name = str(uuid4()) if name is None else name

    @classmethod
    def create(
        cls,
        image_data=None,
        image_path=None,
        load_image_fn=imread_rgb,
        mask_data=None,
        mask_path=None,
        load_mask_fn=imread_grayscale,
        weight_map_data=None,
        weight_map_path=None,
        load_weight_map_fn=None,
        name=None,
    ):
        """
        Build a sample from in-memory data and/or file paths.

        Each of image / mask / weight map is wrapped in a ``Data`` object
        constructed from its ``*_data``, ``*_path`` and ``load_*_fn`` values.
        The mask (resp. weight map) is left as ``None`` when neither its data
        nor its path is supplied.

        :param name: sample name; when ``None`` it defaults to the stem of
            ``image_path`` if a path is given, otherwise to a random UUID
        :param load_weight_map_fn: has no default loader -- presumably required
            whenever only ``weight_map_path`` is supplied (verify against ``Data``)
        :return: the constructed sample
        :rtype: ImageSegmentationSample
        """
        if name is None:
            name = Path(image_path).stem if image_path is not None else str(uuid4())

        image = Data(image_data, image_path, load_image_fn, name=name)

        if mask_data is None and mask_path is None:
            mask = None
        else:
            mask = Data(mask_data, mask_path, load_mask_fn, name=name)

        if weight_map_data is None and weight_map_path is None:
            weight_map = None
        else:
            weight_map = Data(weight_map_data, weight_map_path, load_weight_map_fn, name=name)

        return cls(image, mask, weight_map, name)

    @property
    def is_inference(self):
        """``True`` when the sample carries neither a mask nor a weight map."""
        return not self.has_mask and not self.has_weight_map

    @property
    def has_mask(self):
        """``True`` when a mask is attached to this sample."""
        return self.mask is not None

    @property
    def has_weight_map(self):
        """``True`` when a weight map is attached to this sample."""
        return self.weight_map is not None

    @property
    def data(self):
        """
        Return ``(image, mask, weight_map, name)``, where the first three are
        the ``.data`` payloads and missing mask / weight map are ``None``.
        """
        return (self.image.data, self.mask.data if self.has_mask else None, self.weight_map.data if self.has_weight_map else None, self.name)

    def plot(self, overlay_mask_with_image=True, figsize=(18, 18)):
        """
        Show the sample as one row of panels: image, mask (if present),
        mask overlaid on the image (if requested and a mask is present) and
        weight map (if present).

        :param overlay_mask_with_image: add the mask-on-image overlay panel;
            ignored when the sample has no mask
        :param figsize: figure size forwarded to ``plt.subplots``
        """
        # An overlay cannot be drawn without a mask, so do not reserve a panel for it.
        show_overlay = overlay_mask_with_image and self.has_mask
        n_cols = 1 + int(self.has_mask) + int(show_overlay) + int(self.has_weight_map)

        # squeeze=False always yields a 2-D array of axes, so indexing works
        # even when only a single panel is drawn.
        _, axs = plt.subplots(1, n_cols, figsize=figsize, squeeze=False)
        axs = axs[0]

        # Read each payload once and reuse it for the panels below.
        image = self.image.data
        mask = self.mask.data if self.has_mask else None

        plot_idx = 0
        axs[plot_idx].imshow(image)
        axs[plot_idx].set_title(self.name)
        plot_idx += 1

        if self.has_mask:
            axs[plot_idx].imshow(to_uint8_image(mask))
            plot_idx += 1

        if show_overlay:
            axs[plot_idx].imshow(plot_mask_on_image(to_uint8_image(image), to_uint8_image(mask)))
            plot_idx += 1

        if self.has_weight_map:
            axs[plot_idx].imshow(self.weight_map.data)
            plot_idx += 1


# Registry mapping a class name to the class object defined in this module.
lookup = {
    "ImageSegmentationSample": ImageSegmentationSample,
}
