# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_utility.ipynb.

# %% auto 0
# Public names exported by `from <module> import *`.
__all__ = [
    'load_cls_labels',
    'one_hot_encode',
    'reverse_one_hot',
    'mask_to_rgb',
    'mask_map_classes',
    'mask_code_to_continues_indices',
    'mask_continues_indices_to_codes',
    'mask_continues_indices_to_rgb',
    'mask_continues_indices_to_idx',
    'to_tensor',
    'expand_to_32px',
]

# %% ../nbs/01_utility.ipynb 3
import numpy as np
import pandas as pd
import torch
from PIL import Image

# %% ../nbs/01_utility.ipynb 5
def load_cls_labels(file_path:str):
    """
    Load class names, codes, RGB values and labels from a CSV file.
    # Arguments
        file_path: path (or file-like object accepted by pd.read_csv) of a CSV
            with `name`, `code` and `color_hex` columns. A `label` column is
            also expected (it is only created here when dummy rows are added).
            `color_hex` is a 6-digit hex colour, with or without a leading '#'.

    # Returns
        cls_names: class names in file order.
        cls_code_values: class codes in file order.
        cls_rgb: float array of RGB values in [0, 1], one row per code sorted
            by code. Codes 0..max_code missing from the file are filled with
            black, so (assuming codes are unique and non-negative) row i is the
            colour of code i and the array can be used as a lookup table.
        cls_labels: `label` column in the same code order as cls_rgb;
            filled-in codes get 'Dummy'.
    """
    dtype = {"name": str, "code": int, "color_hex": str}
    cls_dict = pd.read_csv(file_path, dtype=dtype)
    cls_names = cls_dict['name'].tolist()
    cls_code_values = cls_dict['code'].tolist()

    # Fill gaps in the code range with dummy rows to allow easy color coding of
    # segmentation masks (cls_rgb[mask]). All dummies are appended in one concat
    # instead of enlarging the frame row by row, which copies it every time.
    present = set(cls_code_values)
    missing = [code for code in range(int(max(cls_code_values))) if code not in present]
    if missing:
        dummies = pd.DataFrame({'code': missing, 'color_hex': '000000',
                                'name': 'dummy', 'label': 'Dummy'})
        cls_dict = pd.concat([cls_dict, dummies], ignore_index=True)
    cls_dict = cls_dict.sort_values('code')

    # Tolerate an optional leading '#' in the hex colour.
    hex_values = [str(h).lstrip('#') for h in cls_dict['color_hex']]
    cls_rgb = np.array([tuple(int(h[i:i+2], 16)/255 for i in (0, 2, 4)) for h in hex_values])
    cls_labels = cls_dict['label'].values.tolist()

    return cls_names, cls_code_values, cls_rgb, cls_labels

# %% ../nbs/01_utility.ipynb 10
def one_hot_encode(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes.
    # Arguments
        label: segmentation label, either a 2D single-channel array of class
            values or an array with a trailing channel axis (e.g. H x W x 3
            colours).
        label_values: one entry per class; scalars for a 2D label, or vectors
            matching the trailing channel axis (e.g. RGB triples).

    # Returns
        A boolean array with the same spatial shape as the input plus a last
        axis of size num_classes; entry k is True where the pixel equals
        label_values[k] in every channel.

    # Raises
        ValueError (from np.stack) if label_values is empty.
    """
    label = np.asarray(label)
    if label.ndim == 2:
        # Single-channel mask: add a channel axis so the per-class reduction
        # below runs over that axis rather than over the image width.
        label = label[..., np.newaxis]
    semantic_map = [np.all(np.equal(label, value), axis=-1) for value in label_values]
    return np.stack(semantic_map, axis=-1)

def reverse_one_hot(image):
    """
    Collapse a one-hot (or per-class score) array back to class indices.
    # Arguments
        image: array whose last axis holds one entry per class.

    # Returns
        An array with the last axis removed (same height and width as the
        input), where each value is the index of the largest entry along that
        axis (the first such index on ties).
    """
    return np.argmax(image, axis=-1)



# %% ../nbs/01_utility.ipynb 12
def mask_to_rgb(mask, cls_rgb_values):
    """
    Given a 1-channel array of class keys, color code the segmentation results.
    # Arguments
        mask: single channel array where each value represents the class key;
            values are cast to int and used as row indices into cls_rgb_values.
        cls_rgb_values: array of RGB values in [0, 1], one row per class key.

    # Returns
        Color coded PIL image for segmentation visualization
    """
    x = np.asarray(cls_rgb_values)[np.asarray(mask).astype(int)]
    # Round to the nearest 8-bit level instead of truncating: a colour stored as
    # h/255 can land a hair below h after * 255 and would otherwise lose a level.
    # The clip keeps the uint8 cast well-defined for out-of-range inputs.
    x = np.clip(np.rint(255 * x), 0, 255).astype(np.uint8)
    return Image.fromarray(x)

# %% ../nbs/01_utility.ipynb 16
def mask_map_classes(mask, map):
    """
    Remap class values in `mask` in place according to `map`.
    # Arguments
        mask: array of class values; it is modified in place.
        map: iterable of (source, target) pairs (only the first two items of
            each pair are used), or None for no remapping. All pairs are
            matched against the ORIGINAL mask values, so [(1, 2), (2, 3)]
            turns 1 into 2 (not 3) and 2 into 3.

    # Returns
        The same `mask` object after remapping (also when `map` is None).
    """
    if map is None:
        return mask
    # Select every source region before writing any target, so the output of
    # one pair is never matched again as the input of a later pair.
    selections = [(mask == pair[0], pair[1]) for pair in map]
    for selected, target in selections:
        mask[selected] = target
    return mask

# %% ../nbs/01_utility.ipynb 18
def mask_code_to_continues_indices(mask, cls_codes):
    """
    Remap a mask of (possibly non-contiguous) class codes to contiguous indices.

    input: mask with n non-continuous class codes
    output: uint8 mask with class indices from 0 ... n-1, where index i marks
        pixels equal to cls_codes[i]; pixels matching no code stay 0.
    """
    remapped = np.zeros(mask.shape, dtype=np.uint8)
    for new_index, code in enumerate(cls_codes):
        hit = mask == code
        remapped[hit] = new_index
    return remapped

def mask_continues_indices_to_codes(mask, cls_codes):
    """Map contiguous class indices back to their class codes (as uint8)."""
    lookup = np.asarray(cls_codes)
    codes = lookup[mask]
    return codes.astype(np.uint8)

def mask_continues_indices_to_rgb(mask, cls_codes, cls_rgb_values):
    """Color code a mask of contiguous class indices via their class codes."""
    code_mask = mask_continues_indices_to_codes(mask, cls_codes)
    return mask_to_rgb(code_mask, cls_rgb_values)

# %% ../nbs/01_utility.ipynb 20
def mask_continues_indices_to_idx(mask, select_class_indices):
    """Map contiguous indices i to select_class_indices[i] (as uint8)."""
    lookup = np.asarray(select_class_indices)
    selected = lookup[mask]
    return selected.astype(np.uint8)

# %% ../nbs/01_utility.ipynb 21
def to_tensor(x, **kwargs):
    """
    Cast an image array to float32, moving the last (channel) axis first.

    A 3D array of shape (d0, d1, C) (HWC for the usual image layout) is
    transposed to (C, d0, d1) and returned as a float32 torch.Tensor. A 2D
    array (e.g. a single-channel mask) is only cast and is returned as a
    float32 NumPy array, not a tensor. Extra keyword arguments are accepted
    and ignored.
    """
    if len(x.shape) != 2:
        channels_first = x.transpose(2, 0, 1)
        return torch.tensor(channels_first.astype('float32'))
    return x.astype('float32')

def expand_to_32px(length, multiple=32):
    """
    Return the smallest multiple of `multiple` (default 32) that is >= `length`.

    `length` is returned unchanged when it is already a multiple. Pass another
    `multiple` to pad to a different stride.

    Raises ValueError if `multiple` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    remainder = length % multiple
    return length if remainder == 0 else length + (multiple - remainder)
