#pyrsgis/convert

import os, glob
import numpy as np
import pandas as pd
import csv
from ..raster import read
from ..raster import export
from .. import doc_address


def changeDimension(arr):
    """
    Deprecated alias of pyrsgis.convert.array_to_table

    The pyrsgis.convert.changeDimension() function has moved to
    pyrsgis.convert.array_to_table. Please check the documentation.
    A notice is printed on every call; the conversion itself is unchanged.

    Parameters
    ----------
    arr         : numpy array
                  A single band (2D) or multiband (3D) raster array. For multiband
                  arrays the band index is expected to be the first axis.

    Returns
    -------
    table       : numpy array or None
                  For a 3D input of shape (bands, rows, cols), a 2D array of shape
                  (rows*cols, bands). For a 2D input of shape (rows, cols), a 1D
                  array of length rows*cols. The data type of the input is
                  preserved. None is returned (after printing a message) when the
                  input is neither 2D nor 3D.
    """
    print('The "changeDimension()" function has moved to "array_to_table()" and will be deprecated in future versions. ' +
      'Please check the pyrsgis documentation at %s for more details.' % (doc_address))

    n_dims = len(arr.shape)

    if n_dims == 3:
        n_bands, n_rows, n_cols = arr.shape
        # Flatten every band in row-major order and put the bands along the
        # columns. Reshaping the input (rather than copying into an integer
        # scratch array) keeps float values and the input dtype intact.
        return np.reshape(arr, (n_bands, n_rows * n_cols)).T.copy()

    if n_dims == 2:
        n_rows, n_cols = arr.shape
        return np.reshape(arr, (n_rows * n_cols,))

    print("Inconsistent shape of input array.\n2-d or 3-d array expected.")
    return None

def array_to_table(arr):
    """
    Convert 2D or 3D array to table

    The function converts single band or multiband raster array to a table where
    columns represents the input bands and each row represents a cell.

    Parameters
    ----------
    arr         : numpy array
                  A single band (2D) or multiband (3D) raster array. Please note that for
                  multiband raster arrays, the band index should be in the beginning, similar
                  to the one generated by the pyrsgis.raster.read function.

    Returns
    -------
    table       : numpy array or None
                  For a 3D input of shape (bands, rows, cols), a 2D array of shape
                  (rows*cols, bands). For a 2D input of shape (rows, cols), a 1D
                  array of length rows*cols. Cells are ordered row by row and the
                  data type of the input is preserved. None is returned (after
                  printing a message) when the input is neither 2D nor 3D.

    Examples
    --------
    >>> from pyrsgis import raster, convert
    >>> input_file = r'E:/path_to_your_file/raster_file.tif'
    >>> ds, data_arr = raster.read(input_file)
    >>> data_table = convert.array_to_table(data_arr)

    Now check the shape of the input and reshaped arrays.

    >>> print('Shape of the input array:', data_arr.shape)
    >>> print('Shape of the reshaped array:', data_table.shape)
    Shape of the input array: (6, 800, 400)
    Shape of the reshaped array: (320000, 6)

    Here, the input was a six band multispectral raster image.
    Same method applies for single band rasters also.

    """

    n_dims = len(arr.shape)

    if n_dims == 3:
        n_bands, n_rows, n_cols = arr.shape
        # Flatten every band in row-major order and put the bands along the
        # columns. Reshaping the input (rather than copying into an integer
        # scratch array) keeps float values and the input dtype intact.
        return np.reshape(arr, (n_bands, n_rows * n_cols)).T.copy()

    if n_dims == 2:
        n_rows, n_cols = arr.shape
        return np.reshape(arr, (n_rows * n_cols,))

    print("Inconsistent shape of input array.\n2-d or 3-d array expected.")
    return None

        
def table_to_array(table, n_rows=None, n_cols=None):
    """
    Convert tabular array to 2D or 3D array

    The function converts a table where columns represents the input bands
    and each row represents a cell to a single band or multiband raster array.

    Parameters
    ----------
    table       : numpy array
                  A 2D array where rows represent cells of to be generated
                  raster array and each column represents band. This is similar
                  to the one generated by the pyrsgis.convert.array_to_table function.
                  A 1D array (one value per cell) is also accepted.

    n_rows      : integer
                  Number of rows of the output raster array. Required.

    n_cols      : integer
                  Number of columns of the output raster array. Required.

    Returns
    -------
    out_arr     : numpy array or None
                  An array of shape (bands, n_rows, n_cols) when the table has more
                  than one column. A 2D array of shape (n_rows, n_cols) when the
                  table is 1D or has a single column. The data type of the table
                  is preserved. None is returned (after printing a message) when
                  the table has more than two dimensions.

    Raises
    ------
    TypeError
        If ``n_rows`` or ``n_cols`` is not provided.

    Examples
    --------
    >>> from pyrsgis import raster, convert
    >>> input_file = r'E:/path_to_your_file/raster_file.tif'
    >>> ds, data_arr = raster.read(input_file)
    >>> data_table = convert.array_to_table(data_arr)
    >>> print('Shape of the input array:', data_arr.shape)
    >>> print('Shape of the reshaped array:', data_table.shape)
    Shape of the input array: (6, 800, 400)
    Shape of the reshaped array: (320000, 6)

    ...some analysis/processing that you may want to do and generate more columns,
    say two more columns. Then:

    >>> new_data_arr = convert.table_to_array(data_table, n_rows=ds.RasterYSize, n_cols=ds.RasterXSize)
    >>> print('Shape of the array with newly added bands:', new_data_arr.shape)
    Shape of the array with newly added bands: (8, 800, 400)

    If you want to reshape only the new band(s), then:

    >>> new_data_arr = convert.table_to_array(data_table[:, -2:], n_rows=ds.RasterYSize, n_cols=ds.RasterXSize)
    >>> print('Shape of the array with newly added bands:', new_data_arr.shape)
    Shape of the array with newly added bands: (2, 800, 400)

    """

    n_dims = len(table.shape)

    if n_dims > 2:
        print('A three dimensional array was provided. Please provied a 1D or 2D array. ' + 
              'Please check the pyrsgis documentation at %s' % (doc_address))
        return None

    # Fail early with a readable message instead of an opaque numpy error.
    if n_rows is None or n_cols is None:
        raise TypeError('Both "n_rows" and "n_cols" are required to reshape the table.')

    if n_dims == 2 and table.shape[1] > 1:
        n_bands = table.shape[1]
        # After transposing, each row holds one band; a C-ordered copy can then
        # be split into (n_rows, n_cols) images without changing cell order.
        return table.T.copy().reshape((n_bands, n_rows, n_cols))

    # 1D tables and single-column 2D tables both describe a single band.
    return np.reshape(table, (n_rows, n_cols))


def _band_columns(header, ds, arr):
    """Return an ordered {column name: flattened band} mapping for one raster."""
    n_bands = ds.RasterCount
    if n_bands > 1:
        return {'%s@%d' % (header, n + 1): np.ravel(arr[n, :, :]) for n in range(n_bands)}
    return {'%s@%d' % (header, 1): np.ravel(arr)}


def raster_to_csv(path, filename='pyrsgis_rastertocsv.csv', negative=True, remove=None, badrows=True):
    """
    Convert raster to a tabular CSV file

    This function converts a single or multiband raster or rasters present in a
    given directory to a CSV file. Each row in the output CSV file represents a
    cell and columns represent band(s) of the input raster(s). Columns are named
    ``<header>@<band number>``; the header is the file name without extension
    when a single file is passed, and the file name including extension when a
    directory is passed.

    Parameters
    ----------
    path       : string
                 Path to a file or a directory containing raster file(s). A path
                 ending in 'tif' or '.tiff' is read as a single raster; anything
                 else is treated as a directory whose '*.tif' files are converted.
                 
    filename   : string
                 Output CSV file name, with or without path. When a directory is
                 passed as input and this is a relative path, the CSV is written
                 relative to that directory. The working directory of the process
                 is not changed.
                 
    negative   : boolean
                 Whether to retain negative values or not. If False, all negative
                 values will be forced to zero in the output CSV. This maybe useful in
                 some cases as NoData cells in raster files are often negative.

    remove     : list
                 A list of values that you want to remove from the exported table. If a list
                 is passed, all the values of the list will be converted to zero in the raster
                 before transforming to table. Please note that in the backend, this step
                 happens before bad rows removal. Defaults to None (nothing is removed).

    badrows    : boolean
                 Whether to retain rows in the CSV where all cells have zero value.
                 This can be helpful since raster layers masked using a non-rectangular
                 polygon may have unnecessary NoData cells. In such cases, if all the bands
                 have a cell value of zero and are not relevant, this parameter can help in
                 reducing the size of the data. Please note that cells converted to zero by
                 passing the 'negative' and 'remove' arguments will also be considered as bad cells.

    Raises
    ------
    FileNotFoundError
        If ``path`` is treated as a directory but does not exist.
    NotADirectoryError
        If ``path`` is treated as a directory but is not one.

    Examples
    --------
    >>> from pyrsgis import convert

    If you want to convert a single raster file (single or multiple bands):
    
    >>> input_file = r'E:/path_to_your_file/raster_file.tif'
    >>> output_file = r'E:/path_to_your_file/tabular_file.csv'
    >>> convert.raster_to_csv(input_file, filename=output_file)

    If you want to convert all files in a directory, please ensure that all
    rasters in the directory have the same extent, cell size and geometry.
    The files in the directory can be a mix of single and multiband rasters.

    >>> input_dir = r'E:/path_to_your_file/'
    >>> output_file = r'E:/path_to_your_file/tabular_file.csv'
    >>> convert.raster_to_csv(input_dir, filename=output_file)

    If you want to remove negative values, simply pass the 'negative' argument to False:

    >>> convert.raster_to_csv(input_dir, filename=output_file, negative=False)

    If you want to remove specific values, use this:

    >>> convert.raster_to_csv(input_dir, filename=output_file, remove=[10, 54, 127])

    If you want to remove bad rows, use the following line:

    >>> convert.raster_to_csv(input_dir, filename=output_file, badrows=False)

    """

    # Collect all columns first and build the DataFrame once; growing a
    # DataFrame one column at a time fragments it for rasters with many bands.
    columns = {}
    extension = os.path.splitext(path)[-1].lower()

    # If an input file is provided
    if extension.endswith('tif') or extension == '.tiff':
        ds, arr = read(path)
        header = os.path.splitext(os.path.basename(path))[0]
        columns.update(_band_columns(header, ds, arr))

    # If a directory is provided
    else:
        if not os.path.exists(path):
            raise FileNotFoundError('No such file or directory: %r' % (path,))
        if not os.path.isdir(path):
            raise NotADirectoryError('Not a GeoTIFF file or a directory: %r' % (path,))

        # Glob with the directory prefix instead of changing the working
        # directory, so the caller's process state is left untouched.
        pattern = os.path.join(glob.escape(os.fspath(path)), '*.tif')
        for file_path in glob.glob(pattern):
            file_name = os.path.basename(file_path)
            print('Converting %s..' % (file_name))

            ds, arr = read(file_path)
            columns.update(_band_columns(file_name, ds, arr))

        # A relative output name has always been resolved against the input
        # directory; os.path.join leaves absolute names untouched.
        filename = os.path.join(path, filename)

    data_df = pd.DataFrame(columns)

    # Based on passed arguments, check for negatives and values to be removed
    if not negative:
        data_df[data_df < 0] = 0

    if remove is not None:
        for value in remove:
            data_df[data_df == value] = 0

    # Drop rows in which every band is zero (row-wise test, no transpose needed)
    if not badrows:
        data_df = data_df[(data_df != 0).any(axis=1)]

    # export the file
    data_df.to_csv(filename, index=False)


def csv_to_raster(csvfile, ref_raster, cols=None, stacked=True, filename=None,
                  dtype='default', compress=None, nodata=-9999):
    """
    Convert a CSV file to raster

    Every selected column of the CSV file is reshaped to the dimensions of the
    reference raster (cells ordered row by row) and exported as a GeoTIFF band.

    Parameters
    ----------
    csvfile       : string
                    CSV file name. Please provide full path if file is not located
                    in the working directory.

    ref_raster    : string
                    A reference raster file for target cell size, extent, projection, etc.

    cols          : list
                    The list of column names of the CSV files that should be exported. Passing
                    None or a blank list will export all the columns.

    stacked       : boolean
                    Whether to stack all bands in one file or export them as separate files.

    filename      : string
                    The name of the output GeoTIFF file. If not provided, the name of the
                    CSV file is used with its '.csv' extension replaced by '.tif'. A '.tif'
                    extension is appended when the name does not already end in '.tif' or
                    '.tiff' (case-insensitive). Please note that if the 'stacked'
                    argument is set to False, the column name will be added towards the
                    end of the output file name.

    dtype         : string
                    The data type of the output raster. This is same as the options in the
                    pyrsgis.raster.export module. Options are: 'byte', 'cfloat32',
                    'cfloat64', 'cint16', 'cint32', 'float', 'float32', 'float64', 'int',
                    'int16', 'int32', 'uint8', 'uint16', 'uint32'.

    compress      : string
                    Compression type of the raster. This is same as the pyrsgis.raster.export
                    function. Options are 'LZW', 'DEFLATE' and other options that GDAL offers.

    nodata        : signed number
                    Value to treat as NoData in the output raster.

    Raises
    ------
    ValueError
        If the number of rows in the CSV file does not match the number of
        cells of the reference raster.

    Examples
    --------
    Let's assume that you convert a GeoTIFF file to CSV and perform some statistical analysis.
    
    >>> from pyrsgis import convert
    >>> input_file = r'E:/path_to_your_file/raster_file.tif'
    >>> out_csvfile = input_file.replace('.tif', '.csv')
    >>> convert.raster_to_csv(input_file, filename=out_csvfile, negative=False)

    ...create new column(s) (eg. clustering classes, predictions from a stats/ML model). And then
    convert the CSV to TIF file.

    >>> new_csvfile = r'E:/path_to_your_file/predicted_file.csv'
    >>> out_tiffile = new_csvfile.replace('.csv', '.tif')
    >>> convert.csv_to_raster(new_csvfile, ref_raster=input_file, filename=out_tiffile, compress='DEFLATE')

    This will export a GeoTIFF file. If there are multiple columns in the CSV file, the arrays
    will be stacked and exported as multispectral file. One can explicitly select the columns to
    be exported but you should know the name of the columns beforehand.

    >>> convert.csv_to_raster(new_csvfile, ref_raster=input_file, filename=out_tiffile,
                              cols=['Blue', 'Green', 'KMeans', 'RF_Class'], compress='DEFLATE')

    If you want to export each of the columns as separate bands, set the ``stacked`` parameter to
    ``False``.

    >>> convert.csv_to_raster(new_csvfile, ref_raster=input_file, filename=out_tiffile,
                              cols=['Blue', 'Green', 'KMeans', 'RF_Class'], stacked=False, compress='DEFLATE')

    """

    if filename is None:
        # Only swap the trailing extension; a plain str.replace would also
        # rewrite '.csv' occurring in directory names.
        csv_root, csv_ext = os.path.splitext(csvfile)
        filename = (csv_root if csv_ext.lower() == '.csv' else str(csvfile)) + '.tif'

    # add extension in the filename if missing
    if os.path.splitext(filename)[-1].lower() not in ('.tif', '.tiff'):
        filename = filename + '.tif'

    ds, _ = read(ref_raster, bands=1)
    _ = None  # the pixel data of the reference raster is not needed
    n_rows, n_cols = ds.RasterYSize, ds.RasterXSize

    data_df = pd.read_csv(csvfile)

    if cols is None or len(cols) == 0:
        cols = data_df.columns

    if len(data_df) != n_rows * n_cols:
        raise ValueError('The CSV file has %d rows but the reference raster has %d cells (%d x %d).'
                         % (len(data_df), n_rows * n_cols, n_rows, n_cols))

    out_arr = np.zeros((len(cols), n_rows, n_cols))
    for n, col in enumerate(cols):
        out_arr[n, :, :] = np.reshape(data_df[col].values, (n_rows, n_cols))
    data_df = None  # release the table before exporting

    if stacked:
        export(out_arr, ds, filename, dtype=dtype, compress=compress, nodata=nodata)
    else:
        # Insert the column name just before the extension of the file name.
        out_root, out_ext = os.path.splitext(filename)
        for n, col in enumerate(cols):
            export(out_arr[n, :, :], ds, '%s_%s%s' % (out_root, col, out_ext),
                   dtype=dtype, compress=compress, nodata=nodata)

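# NOTE: the pandas_to_raster function below is an unfinished draft that is kept
# disabled inside a module-level string literal. It appears to reference names
# that are never defined (e.g. 'df', 'x_id'); review before enabling it.
"""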
def pandas_to_raster(data_df, x_col, y_col, ref_raster, filename='pyrsgis_pandastoraster.tif',
                     columns=None, x_range=None, y_range=None, dtype='int', compress='default',
                     nodata=-9999):
    # get minimum and maximum value for x
    if x_range == None:
        x_min = data_df[x_col].min()
        x_max = data_df[x_col].max()
    else:
        try:
            x_min, x_max = x_range
        except:
            print('Please provide a list containing range for "x_range" parameter.')
                
    if y_range == None:
        y_min = data_df[y_col].min()
        y_max = data_df[y_col].max()
    else:
        try:
            y_min, y_max = y_range
        except:
            print('Please provide a list containing range for "y_range" parameter.')

    # normalise and scale the x and y columns
    data_df[x_col] = data_df[x_col] - x_min
    data_df[y_col] = data_df[y_col] - y_min

    # generate raster to export
    ds, _ = raster.read(ref_raster)
    _ = None
    data_arr = np.zeros((data_df.shape[1] - 2, ds.RasterXSize, ds.RasterYSize))

    if columns == None:
        columns = list(df.keys())
        for col in [x_col, y_col]:
            columns.remove(col)

    for x_idx in data_df[x_col].values:
        for y_idx in data_df[y_col].values:
            for n, item in enumerate(columns):
                data_arr[n, x_id, y_idx] = data_df[item]

    # export the raster
    raster.export(data_arr, ds, filename, dtype=dtype, compress=compress, nodata=nodata)
        
"""
