#pyrsgis/ml

from copy import deepcopy
import numpy as np
from ..raster import read
from sklearn.feature_extraction import image

# define a function to create image chips from single band array
def array2d_to_chips(data_arr, y_size=5, x_size=5):
    """
    Image chips from 2D array

    This function generates image chips from single band arrays. The image chips can
    be used as a direct input to deep learning models (eg. Convolutional Neural Network).

    Every cell of the input array gets exactly one chip. The array is first padded
    by reflection so that chips along the edges are complete. For odd chip sizes the
    source cell sits at the centre of its chip. For even sizes it sits at row index
    ``y_size // 2`` and column index ``x_size // 2`` of the chip, i.e. just
    below/right of the geometric centre.

    Parameters
    ----------
    data_arr        : array
                      A 2D array from which image chips will be created.

    y_size          : integer
                      The height of the image chips. Ideally an odd number.

    x_size          : integer
                      The width of the image chips. Ideally an odd number.

    Returns
    -------
    image_chips     : array
                      A 3D array containing stacked image chips. The first index
                      represents each image chip and the size is equal to total number
                      of cells in the input array. The 2nd and 3rd index represent the
                      height and the width of the image chips.

    Raises
    ------
    ValueError
                      If ``data_arr`` is not two dimensional, or if ``y_size`` or
                      ``x_size`` is smaller than 1.

    Examples
    --------
    >>> from pyrsgis import raster, ml
    >>> infile = r'E:/path_to_your_file/your_file.tif'
    >>> ds, data_arr = raster.read(infile)
    >>> image_chips = ml.array2d_to_chips(data_arr, y_size=5, x_size=5)
    >>> print('Shape of input array:', data_arr.shape)
    >>> print('Shape of generated image chips:', image_chips.shape)
    Shape of input array: (2054, 2044)
    Shape of generated image chips: (4198376, 5, 5)

    """
    if np.ndim(data_arr) != 2:
        raise ValueError("data_arr must be a two dimensional array.")
    if y_size < 1 or x_size < 1:
        raise ValueError("y_size and x_size must be at least 1.")

    # np.pad returns a new array, so the caller's array is never modified and
    # no defensive copy is needed.
    padded = np.pad(data_arr, _chip_pad_width(y_size, x_size), 'reflect')
    image_chips = image.extract_patches_2d(padded, (y_size, x_size))

    return image_chips


def _chip_pad_width(y_size, x_size):
    """Return per-axis ``np.pad`` widths that yield one chip per input cell."""
    # Each axis needs (size - 1) extra cells in total, split as evenly as
    # possible (one fewer after than before for even sizes). The widths must be
    # given per axis: a flat ``(a, b)`` pair would pad *every* axis with ``a``
    # before and ``b`` after, misaligning chips whenever y_size != x_size.
    return ((y_size // 2, (y_size - 1) // 2),
            (x_size // 2, (x_size - 1) // 2))

def imageChipsFromSingleBandArray(data_arr, y_size=5, x_size=5):
    """
    Legacy alias of ``array2d_to_chips``.

    Kept for backward compatibility; new code should call
    ``pyrsgis.ml.array2d_to_chips``, which takes the same arguments and
    returns the same 3D array of stacked image chips.
    """
    # Delegate instead of duplicating the padding/patch logic so both names
    # always produce identical chips.
    return array2d_to_chips(data_arr, y_size=y_size, x_size=x_size)
    
# define a function to create image chips from array
def array_to_chips(data_arr, y_size=5, x_size=5):
    """
    Image chips from raster array

    This function generates image chips from single or multi band raster arrays. The image
    chips can be used as a direct input to deep learning models (eg. Convolutional Neural Network).

    Parameters
    ----------
    data_arr        : array
                      A 2D or 3D raster array from which image chips will be created. This
                      should be similar to the one generated by ``pyrsgis.raster.read`` function,
                      i.e. a 3D array is indexed as (band, row, column).

    y_size          : integer
                      The height of the image chips. Ideally an odd number.

    x_size          : integer
                      The width of the image chips. Ideally an odd number.

    Returns
    -------
    image_chips     : array
                      A 3D or 4D array containing stacked image chips. The first index
                      represents each image chip and the size is equal to total number
                      of cells in the input array. The 2nd and 3rd index represent the
                      height and the width of the image chips. If the input array is a
                      3D array, then image_chips will be 4D where the 4th index will
                      represent the number of bands.

    Raises
    ------
    ValueError
                      If the input is not two or three dimensional, or is a 3D array
                      with no bands.

    Examples
    --------
    >>> from pyrsgis import raster, ml
    >>> infile = r'E:/path_to_your_file/your_file.tif'
    >>> ds, data_arr = raster.read(infile)
    >>> image_chips = ml.array_to_chips(data_arr, y_size=7, x_size=7)
    >>> print('Shape of input array:', data_arr.shape)
    >>> print('Shape of generated image chips:', image_chips.shape)
    Shape of input array: (6, 2054, 2044)
    Shape of generated image chips: (4198376, 7, 7, 6)

    """
    n_dims = len(data_arr.shape)

    # if array is a single band image
    if n_dims == 2:
        return array2d_to_chips(data_arr, y_size=y_size, x_size=x_size)

    # Anything other than (band, row, column) cannot be chipped band by band.
    # ValueError subclasses Exception, so existing ``except Exception`` callers
    # keep working.
    if n_dims != 3:
        raise ValueError("Sorry, only two or three dimensional arrays allowed.")

    n_bands = data_arr.shape[0]
    if n_bands == 0:
        raise ValueError("Input array has no bands.")

    # Chip each band and write it into a single preallocated output with a
    # trailing band axis. This avoids repeatedly concatenating (and copying)
    # the growing result, which matters for large rasters. The input is only
    # read, so no defensive copy is needed.
    image_chips = None
    for band in range(n_bands):
        band_chips = array2d_to_chips(data_arr[band], y_size=y_size, x_size=x_size)
        if image_chips is None:
            image_chips = np.empty(band_chips.shape + (n_bands,), dtype=band_chips.dtype)
        image_chips[..., band] = band_chips

    return image_chips
    

def imageChipsFromArray(data_array, x_size=5, y_size=5):
    """
    Legacy alias of ``array_to_chips``.

    Kept for backward compatibility; new code should call
    ``pyrsgis.ml.array_to_chips``. Note that this alias takes ``x_size``
    before ``y_size`` when the sizes are passed positionally.
    """
    # The previous copy of the multi band logic called ``copy.copy`` although
    # only ``deepcopy`` is imported, so every multi band call raised NameError.
    # Delegating with keyword arguments fixes that and preserves this alias's
    # swapped positional order.
    return array_to_chips(data_array, y_size=y_size, x_size=x_size)

# define a function to create image chips from TIF file
def raster_to_chips(file, y_size=5, x_size=5):
    """
    Image chips from raster file

    This function generates image chips from single or multi band GeoTIFF file. The image
    chips can be used as a direct input to deep learning models (eg. Convolutional Neural Network).

    This is built on the ``pyrsgis.ml.array_to_chips`` function.

    Parameters
    ----------
    file            : string
                      Name or path of the GeoTIFF file from which image chips will be created.

    y_size          : integer
                      The height of the image chips. Ideally an odd number.

    x_size          : integer
                      The width of the image chips. Ideally an odd number.

    Returns
    -------
    image_chips     : array
                      A 3D or 4D array containing stacked image chips. The first index
                      represents each image chip and the size is equal to total number
                      of cells in the input array. The 2nd and 3rd index represent the
                      height and the width of the image chips. If the input file is a
                      multiband raster, then image_chips will be 4D where the 4th index will
                      represent the number of bands.

    Examples
    --------
    >>> from pyrsgis import ml
    >>> infile_2d = r'E:/path_to_your_file/your_2d_file.tif'
    >>> image_chips = ml.raster_to_chips(infile_2d, y_size=7, x_size=7)
    >>> print('Shape of single band generated image chips:', image_chips.shape)
    Shape of single band generated image chips: (4198376, 7, 7)

    Note that here the shape of the input raster file is 2054 rows by 2044 columns.
    If the raster file is multiband:

    >>> infile_3d = r'E:/path_to_your_file/your_3d_file.tif'
    >>> image_chips = ml.raster_to_chips(infile_3d, y_size=7, x_size=7)
    >>> print('Shape of multiband generated image chips:', image_chips.shape)
    Shape of multiband generated image chips: (4198376, 7, 7, 6)

    """
    # Only the pixel array is needed; the dataset handle returned by read() is
    # discarded.
    _, data_arr = read(file)

    return array_to_chips(data_arr, y_size=y_size, x_size=x_size)

def imageChipsFromFile(infile, y_size=5, x_size=5):
    """
    Legacy alias of ``raster_to_chips``.

    Kept for backward compatibility; new code should call
    ``pyrsgis.ml.raster_to_chips``, which takes the same arguments.
    """
    # Route through the maintained snake_case implementation rather than the
    # legacy ``imageChipsFromArray`` path, which failed for multiband rasters.
    return raster_to_chips(infile, y_size=y_size, x_size=x_size)
