# -*- coding: utf-8 -*-
"""
Created on Mon Mar  7 15:32:58 2022

@author: talha
"""

import math

import numpy as np

class EMPatches(object):
    """Split an image into overlapping square patches and stitch them back.

    Patch origins lie on a regular grid. An extra row/column of patches
    anchored to the bottom/right edge is added whenever the grid does not
    land exactly on it, so the whole image is always covered.
    """

    def __init__(self):
        pass

    @staticmethod
    def _axis_offsets(length, window, step):
        """Return patch start offsets along one axis so [0, length) is covered."""
        last = length - window
        # window <= length, so `last` >= 0 and the range always contains 0.
        offsets = list(range(0, last + 1, step))
        # Unless `last` falls on the step grid, add one window flush with the
        # far edge so the trailing pixels are covered.
        if offsets[-1] != last:
            offsets.append(last)
        return offsets

    def extract_patches(self, img, patchsize, overlap):
        '''
        Parameters
        ----------
        img : image to extract patches from, in [H W] or [H W Ch] format.
        patchsize : side length of the square patches to extract (only square
                    patches are supported). Clipped to the image height/width
                    when the image is smaller than the patch.
        overlap : fractional overlap between neighbouring patches, a float
                  in [0, 1).

        Returns
        -------
        img_patches : list of extracted patches (slices of ``img``; for a
                      NumPy array these are views, not copies).
        indices : list of (y_start, y_end, x_start, x_end) tuples, one per
                  patch and in the same order, which can be passed to
                  'merge_patches'.

        Raises
        ------
        ValueError : if ``overlap`` is not in [0, 1), or if the (clipped)
                     patch size is smaller than 1 pixel.
        '''
        # overlap >= 1 would give a zero/negative step and negative overlap
        # would leave gaps between patches, so reject both up front.
        if not 0 <= overlap < 1:
            raise ValueError(
                "overlap must be a float in [0, 1), got {!r}".format(overlap))

        height, width = img.shape[0], img.shape[1]

        # If the image is smaller than the requested patch, clip the window to
        # the image size on that dimension.
        window_x = min(patchsize, width)
        window_y = min(patchsize, height)
        if window_x < 1 or window_y < 1:
            raise ValueError(
                "patchsize and image height/width must all be at least 1")

        # Distance between consecutive patch origins. It is >= 1 because
        # floor(window * overlap) < window when 0 <= overlap < 1.
        step_x = window_x - int(math.floor(window_x * overlap))
        step_y = window_y - int(math.floor(window_y * overlap))

        x_offsets = self._axis_offsets(width, window_x, step_x)
        y_offsets = self._axis_offsets(height, window_y, step_y)

        img_patches = []
        indices = []
        # Column-major order (x outer, y inner) is part of the output
        # contract: the last entry is always the bottom-right patch.
        for x0 in x_offsets:
            for y0 in y_offsets:
                y1, x1 = y0 + window_y, x0 + window_x
                img_patches.append(img[y0:y1, x0:x1])
                indices.append((y0, y1, x0, x1))

        return img_patches, indices

    def merge_patches(self, img_patches, indices):
        '''
        Parameters
        ----------
        img_patches : list of image patches ([H W] or [H W Ch] arrays) to be
                      joined; each is cast to uint8 before being written.
        indices : list of indices generated by 'extract_patches', one per
                  patch, each of the form
                  (yOffset, yOffset+windowSizeY, xOffset, xOffset+windowSizeX)

        Returns
        -------
        Stitched uint8 image whose height/width is the furthest patch edge
        and whose channel axis (if any) matches the patches. Where patches
        overlap, later patches overwrite earlier ones.

        Raises
        ------
        ValueError : if ``indices`` is empty or its length differs from
                     ``img_patches``.
        '''
        if len(indices) == 0:
            raise ValueError("indices must contain at least one patch location")
        if len(img_patches) != len(indices):
            raise ValueError(
                "got {} patches but {} indices".format(
                    len(img_patches), len(indices)))

        # The canvas must reach the furthest patch edge on each axis. Taking
        # the max (rather than the last entry) works for any patch order.
        orig_h = max(idx[1] for idx in indices)
        orig_w = max(idx[3] for idx in indices)

        # Keep the patches' own channel layout (none, 1, 3, 4, ...) instead of
        # assuming 3-channel RGB.
        channels = tuple(img_patches[0].shape[2:])
        merged = np.zeros((orig_h, orig_w) + channels, dtype=np.uint8)

        for patch, (y0, y1, x0, x1) in zip(img_patches, indices):
            merged[y0:y1, x0:x1] = patch.astype(np.uint8)

        return merged
