# -*- coding: utf-8 -*-
"""
Created on Tue Feb 16 16:40:16 2021

@author: chris.kerklaan


ThreediRasterGroup class checks for 3Di properties, alignment etc.

Input:
    1. Raster dictionary with standardized names
    2. Optional panden

Checks:
    1. Check alignment --> already included in RasterGroup
    2. Check 3Di properties  --> Done!

Functions
    1. Convert input based on csv conversion tables
    2. Create an interception raster based on buildings (panden)

TODO:
    1. Correct function
    2. Write function

Notes:
    1. Loading rasters into memory speeds up raster analysis but is costly
       in memory. Not all rasters are loaded into memory; e.g., it is not
       useful to load a dem into memory since it is not used in conversion.
    2. Maximum memory is three rasters with 500.000.000 pixels each.
    3.

Ideas:
    1. Rasterfixer




"""
# First-party imports
import csv
import shutil
import pathlib
import logging
from pathlib import Path

# Third-party imports
import numpy as np
from osgeo import gdal

# local imports
from threedi_raster_edits.gis.rastergroup import RasterGroup
from threedi_raster_edits.gis.raster import Raster
from threedi_raster_edits.gis.vector import Vector
from threedi_raster_edits.utils.project import Progress

# GLOBALS
# Logger
logger = logging.getLogger(__name__)

# CSV path: conversion tables are read from the "data" folder next to this
# module.
FILE_PATH = str(pathlib.Path(__file__).parent.absolute()) + "/data/"
CSV_LANDUSE_PATH = FILE_PATH + "Conversietabel_landgebruik_2020.csv"
CSV_SOIL_PATH = FILE_PATH + "Conversietabel_bodem.csv"

# output paths: standardized raster name -> file name relative to the output
# folder (default of ThreediRasterGroup.write).
OUTPUT_NAMES = {
    "friction": "friction.tif",
    "dem": "dem.tif",
    "infiltration": "infiltration.tif",
    "interception": "interception.tif",
    # Previously misspelled "intial_waterlevel", which matched no raster
    # name, so the initial waterlevel raster was never written.
    "initial_waterlevel": "initial_waterlevel.tif",
}
# 3Di model-settings raster keys -> standardized raster names
THREEDI_RASTER_NAMES = {
    "dem_file": "dem",
    "interception_file": "interception",
    "frict_coef_file": "friction",
    "infiltration_file": "infiltration",
    "initial_waterlevel_file": "initial_waterlevel",
}


class ThreediRasterGroup(RasterGroup):
    """RasterGroup that generates, checks, corrects and writes 3Di rasters.

    The group always holds a ``dem`` raster; the other rasters are optional
    and are stored under standardized names (``landuse``, ``soil``,
    ``interception``, ``friction``, ``infiltration``, ``initial_waterlevel``)
    that are accessed as attributes, e.g. ``group.dem``.
    """

    def __init__(
        self,
        dem_file: str,
        landuse_file: str = None,
        soil_file: str = None,
        interception_file: str = None,
        frict_coef_file: str = None,
        infiltration_file: str = None,
        initial_waterlevel_file: str = None,
        buildings: Vector = None,
        nodata_value=-9999,
        data_type=gdal.GDT_Float32,
        np_data_type="f4",
    ):
        """An object to generate, check and correct threedi rasters.

        params:
            dem_file: path of the dem raster (required).
            landuse_file: path of a landuse raster; loaded into memory.
            soil_file: path of a soil raster; loaded into memory.
            interception_file: path of an interception raster.
            frict_coef_file: path of a friction raster.
            infiltration_file: path of an infiltration raster.
            initial_waterlevel_file: path of an initial waterlevel raster.
            buildings: building polygons, used by
                ``generate_building_interception``.
            nodata_value: target nodata value of the rasters.
            data_type: target gdal data type of the rasters.
            np_data_type: numpy dtype of generated arrays.
        """
        rasters = [Raster(dem_file, name="dem")]
        self.original_names = {"dem": Path(dem_file).stem}

        # (standardized name, path, log message, load into memory)
        # Only landuse and soil are loaded into memory: they are the inputs
        # of the conversions (classify), the other rasters are not.
        optional_rasters = (
            ("landuse", landuse_file, "Loading landuse", True),
            ("soil", soil_file, "Loading soil", True),
            ("interception", interception_file, "Loading interception", False),
            ("friction", frict_coef_file, "Loading friction", False),
            ("infiltration", infiltration_file, "Loading infiltration", False),
            (
                "initial_waterlevel",
                initial_waterlevel_file,
                "Loading initial waterlevel",
                False,
            ),
        )
        for name, path, message, in_memory in optional_rasters:
            if not path:
                continue
            logger.info(message)
            raster = Raster(path, name=name)
            if in_memory:
                raster.load_to_memory()
            rasters.append(raster)
            self.original_names[name] = Path(path).stem

        RasterGroup.__init__(self, rasters)

        if buildings:
            logger.debug("Setting buildings")
            self.buildings = buildings

        self.epsg = self.dem.epsg
        self.data_type = data_type
        # generate_infiltration reads ``np_data_type``; the original only set
        # the misspelled ``no_data_type``, which is kept as an alias.
        self.np_data_type = np_data_type
        self.no_data_type = np_data_type
        self.nodata_value = nodata_value

        self.retrieve_soil_conversion_table = retrieve_soil_conversion_table
        self.retrieve_landuse_conversion_table = retrieve_landuse_conversion_table

    def add(self, name, path):
        """Add the raster at ``path`` to the group.

        params:
            name: a 3Di model-settings key (e.g. "frict_coef_file") or the
                standardized raster name it maps to (e.g. "friction").
            path: path of the raster file.
        raises:
            ValueError: if ``name`` is neither.
        """
        if name in THREEDI_RASTER_NAMES:
            name = THREEDI_RASTER_NAMES[name]
        elif name not in THREEDI_RASTER_NAMES.values():
            raise ValueError("name should be object or threedi compatibale")
        # Name the raster like __init__ does, so raster.name matches its key.
        self[name] = Raster(path, name=name)

    def check_table(self, table="soil"):
        """Raise AttributeError when the requested conversion table is missing.

        params:
            table: "soil" or "landuse"; other values are not checked.
        """
        logger.info("Checking tables")
        if table == "soil":
            if not hasattr(self, "ct_soil"):
                raise AttributeError(
                    """
                                     Please load soil csv using
                                     'load_soil_conversion_table'"""
                )
        elif table == "landuse":
            if not hasattr(self, "ct_lu"):
                raise AttributeError(
                    """
                                     Please load landuse csv using
                                     'load_landuse_conversion_table'"""
                )

    def check_properties(
        self, min_allow=-1000, max_allow=1000, max_pixel_allow=1000000000
    ):
        """
        Checks the 3Di properties of the rasters in this object against the
        group's nodata value, epsg and data type.
        params:
            min_allow: exclusive lower bound of raster values
            max_allow: exclusive upper bound of raster values
            max_pixel_allow: maximum total number of pixels of all rasters
        returns:
            the dictionary produced by the module-level check_properties
        """
        logger.info("Checking properties")
        return check_properties(
            self.rasters,
            # self.nodata is never set on this class; the target nodata
            # value (also used by correct()) is self.nodata_value.
            nodata=self.nodata_value,
            projection=self.epsg,
            data_type=self.data_type,
            min_allow=min_allow,
            max_allow=max_allow,
            max_pixel_allow=max_pixel_allow,
        )

    def null_raster(self):
        """Return a shell copy of the dem filled with zeros.

        Cells where ``self.dem.mask`` is False are set to NaN (presumably the
        dem's nodata cells -- confirm against Raster.mask).
        """
        logger.info("Creating null raster")
        copy = self.dem.copy(shell=True)
        null_array = np.zeros((int(copy.rows), int(copy.columns)))
        null_array[~self.dem.mask] = np.nan
        copy.array = null_array
        return copy

    def load_soil_conversion_table(self, csv_soil_path=CSV_SOIL_PATH):
        """Load the soil conversion table into ``ct_soil`` / ``ct_soil_info``."""
        logger.info("Loading soil conversion table")
        self.ct_soil, self.ct_soil_info = load_csv_file(csv_soil_path, "soil")

    def load_landuse_conversion_table(self, csv_lu_path=CSV_LANDUSE_PATH):
        """Load the landuse conversion table into ``ct_lu`` / ``ct_lu_info``."""
        logger.info("Loading landuse conversion table")
        self.ct_lu, self.ct_lu_info = load_csv_file(csv_lu_path, "landuse")

    def generate_friction(self):
        """Classify landuse into a friction raster (``self.friction``)."""
        logger.info("Generating friction")
        self.check_table("landuse")
        self.friction = classify(self.landuse, "Friction", self.ct_lu)
        self.rasters.append(self.friction)

    def generate_permeability(self):
        """Classify landuse into a permeability raster (``self.permeability``)."""
        logger.info("Generating permeability")
        self.check_table("landuse")
        self.permeability = classify(self.landuse, "Permeability", self.ct_lu)
        self.rasters.append(self.permeability)

    def generate_interception(self):
        """Classify landuse into an interception raster (``self.interception``)."""
        logger.info("Generating interception")
        self.check_table("landuse")
        self.interception = classify(self.landuse, "Interception", self.ct_lu)
        self.rasters.append(self.interception)

    def generate_crop_type(self):
        """Classify landuse into a crop type raster (``self.crop_type``)."""
        logger.info("Generating crop type")
        self.check_table("landuse")
        self.crop_type = classify(self.landuse, "Crop_type", self.ct_lu)
        self.rasters.append(self.crop_type)

    def generate_max_infiltration(self):
        """Classify soil into a max infiltration rate raster."""
        logger.info("Generating max infiltration")
        self.check_table("soil")
        self.max_infiltration = classify(
            self.soil, "Max_infiltration_rate", self.ct_soil
        )
        self.rasters.append(self.max_infiltration)

    def generate_infiltration(self):
        """Generate infiltration as permeability * max infiltration rate.

        Cells where either input equals ``self.nodata_value`` get
        ``self.nodata_value``. The result is cast to ``self.np_data_type``.
        """
        logger.info("Generating infiltration")
        self.generate_permeability()
        self.generate_max_infiltration()

        output = self.dem.copy(shell=True)
        pbar = Progress(len(output), "Generating infiltration")
        nodata_value = self.nodata_value
        for perm_tile, inf_tile in zip(self.permeability, self.max_infiltration):
            perm_array = perm_tile.array
            inf_array = inf_tile.array

            infiltration_array = np.where(
                np.logical_and(
                    perm_array != nodata_value,
                    inf_array != nodata_value,
                ),
                perm_array * inf_array,
                nodata_value,
            ).astype(self.np_data_type)
            output.array = infiltration_array, *perm_tile.location
            pbar.update(quiet=False)

        self.infiltration = output
        self.infiltration.name = "Infiltration"
        self.rasters.append(self.infiltration)

    def generate_hydraulic_conductivity(self):
        """Classify soil into a hydraulic conductivity raster."""
        logger.info("Generating hydraulic conductivity")
        self.check_table("soil")
        self.hydraulic_conductivity = classify(
            self.soil, "Hydraulic_conductivity", self.ct_soil
        )
        self.rasters.append(self.hydraulic_conductivity)

    def generate_building_interception(self, value):
        """Burn ``value`` into buildings; other cells become 0 (dem extent).

        Requires ``buildings`` to have been passed to the constructor.
        """
        logger.info("Generating building interception")
        null = self.null_raster()
        value_raster = null.push_vector(self.buildings, value=value)
        self.interception = value_raster.align(
            self.dem, nodata_align=True, fill_value=0
        )
        self.interception.name = "Interception"
        self.rasters.append(self.interception)

    def correct(self, nodata=True, projection=True, data_type=True):
        """corrects the rasters based on:
        - self.nodata_value
        - self.epsg
        - self.data_type
        - self.dem data/nodata
        Hence, it corrects the nodata_value, epsg, data_type and data/nodata.

        params:
            nodata: correct nodata values when True
            projection: correct projections when True
            data_type: correct data types when True
        All non-dem rasters are aligned to the dem afterwards.
        """
        logger.info("correcting properties of threedigroup rasters")
        checks = self.check_properties()

        nodata_value = self.nodata_value
        spatial_reference = self.epsg
        # Own name: assigning to ``data_type`` overwrote the flag argument,
        # making ``correct(data_type=False)`` still convert data types.
        target_data_type = self.data_type

        for error in checks["errors"]:
            # Group-level errors (e.g. "maximum_allowed_pixels") are
            # (type, message) pairs without a raster; nothing to correct.
            if len(error) != 3:
                continue
            error_type, raster_name, _ = error
            if error_type == "nodata_value" and nodata:
                logger.debug(f"replacing nodata of {raster_name} to {nodata_value}")
                raster = getattr(self, raster_name)
                raster.replace_nodata(nodata_value)
                setattr(self, raster_name, raster)

            elif error_type == "projection" and projection:
                logger.debug(
                    f"replacing projection of {raster_name} to {spatial_reference}"
                )
                raster = getattr(self, raster_name)
                raster.spatial_reference = spatial_reference
                setattr(self, raster_name, raster)

            elif error_type == "data_type" and data_type:
                logger.debug(
                    f"replacing data type of {raster_name} to {target_data_type}"
                )
                raster = getattr(self, raster_name)
                raster.change_data_type(target_data_type)
                setattr(self, raster_name, raster)

        # Snapshot the rasters: attributes are reassigned while looping.
        for raster in list(self):
            if raster.name == "dem":
                continue
            # The original used the stale ``raster_name`` of the loop above.
            logger.debug(f"Aligning {raster.name} to dem if needed")
            current = getattr(self, raster.name, raster)
            aligned = current.align(self.dem, idw=True, nodata_align=True)
            setattr(self, raster.name, aligned)

    def write(self, folder_path, output_names: dict = OUTPUT_NAMES):
        """
        writes raster if present in output names and in object
        creates the output folder and any sub folders of the output names
         params:
            folder_path: path in which the rasters are written
            output_names: e.g. {
            "friction": "rasters/friction.tif",
            "dem":"rasters/dem.tif",
            "infiltration": "rasters/infiltation.tif",
            "interception": "rasters/interception.tif",
            "initial_waterlevel": "rasters/initial_waterlevel.tif"
            }
            Keys may also be 3Di model-settings keys (e.g. "frict_coef_file").
        """
        folder = str(folder_path)
        pathlib.Path(folder).mkdir(parents=True, exist_ok=True)

        for table_name, table_path in output_names.items():
            # Accept the misspelled key of older OUTPUT_NAMES versions.
            if table_name == "intial_waterlevel":
                table_name = "initial_waterlevel"

            if table_name in THREEDI_RASTER_NAMES:
                table_name = THREEDI_RASTER_NAMES[table_name]

            if not hasattr(self, table_name):
                continue

            out_path = folder + "/" + table_path
            # Every output may have its own sub folder; creating them per
            # file also avoids an IndexError for empty output_names.
            pathlib.Path(out_path).parent.mkdir(parents=True, exist_ok=True)
            getattr(self, table_name).write(out_path)


def retrieve_soil_conversion_table(output_path):
    """Copy the bundled soil conversion csv to ``output_path``.

    params:
        output_path: destination file path of the copy.
    """
    logger.info("retrieving soil conversion table")
    source = CSV_SOIL_PATH
    shutil.copyfile(src=source, dst=output_path)


def retrieve_landuse_conversion_table(output_path):
    """Copy the bundled landuse conversion csv to ``output_path``.

    params:
        output_path: destination file path of the copy.
    """
    logger.info("retrieving landuse conversion table")
    source = CSV_LANDUSE_PATH
    shutil.copyfile(src=source, dst=output_path)


def check_properties(
    raster_list,
    nodata=-9999,
    projection=28992,
    max_pixel_allow=1000000000,
    data_type=gdal.GDT_Float32,
    unit="metre",
    min_allow=-1000,
    max_allow=1000,
):
    """Check rasters against the 3Di raster requirements.

    params:
        raster_list: iterable of Raster objects.
        nodata: required nodata value.
        projection: required EPSG code.
        max_pixel_allow: maximum total number of pixels of all rasters.
        data_type: required gdal data type.
        unit: required unit of the spatial reference.
        min_allow, max_allow: exclusive bounds for the raster values
            (NaN values are ignored).
    returns:
        dict with per-raster values under "nodata", "unit", "projection",
        "data_type" (gdal type name of the raster), "square_pixels" and
        "min_max", the summed "total_pixels", and "errors": a list of
        (error_type, raster_name, message) tuples plus, when the pixel limit
        is exceeded, a ("maximum_allowed_pixels", message) pair.
    """
    output = {
        "nodata": {},
        "unit": {},
        "projection": {},
        "data_type": {},
        "resolution": {},  # kept for backwards compatibility; never filled
        "square_pixels": {},
        "min_max": {},
        "total_pixels": 0,
        "errors": [],
    }

    def report(error):
        # Every error is logged and collected in the same way.
        logger.debug(error)
        output["errors"].append(error)

    total_pixels = 0
    for raster in raster_list:
        name = raster.name
        reference = raster.spatial_reference

        # nodata value check
        if raster.nodata_value != nodata:
            report(
                (
                    "nodata_value",
                    name,
                    f"has a nodata value of {raster.nodata_value}",
                )
            )
        output["nodata"][name] = raster.nodata_value

        # unit check
        if reference.unit != unit:
            report(("unit", name, f"has not unit {unit}"))
        output["unit"][name] = reference.unit

        # projection check
        if reference.epsg != projection:
            report(("projection", name, f"has not epsg {projection}"))
        output["projection"][name] = reference.epsg

        # data type check
        if raster.data_type != data_type:
            report(
                (
                    "data_type",
                    name,
                    f"is not a {gdal.GetDataTypeName(data_type)}",
                )
            )
        # Record the raster's own type; the original stored the *required*
        # type, which made this entry identical for every raster.
        output["data_type"][name] = gdal.GetDataTypeName(raster.data_type)

        # square pixel check
        if abs(raster.resolution["width"]) != abs(raster.resolution["height"]):
            report(("width/height", name, "has not a square pixel"))
        output["square_pixels"][name] = raster.resolution

        # extreme value check
        array = raster.array
        _max, _min = np.nanmax(array), np.nanmin(array)
        if not (min_allow < _max < max_allow and min_allow < _min < max_allow):
            report(
                (
                    "extreme_values",
                    name,
                    f"has extreme values < {min_allow}, > {max_allow}",
                )
            )
        output["min_max"][name] = {"min": _min, "max": _max}

        total_pixels += raster.pixels

    # max pixel allowed check
    if total_pixels > max_pixel_allow:
        msg = f"Rasters combined pixels are larger than {max_pixel_allow}"
        logger.debug(msg)
        output["errors"].append(("maximum_allowed_pixels", msg))

    output["total_pixels"] = total_pixels

    if not output["errors"]:
        logger.debug("ThreediRasterGroup - Check properties found no problems")
    return output


def load_csv_file(csv_path, csv_type="landuse"):
    """Read a semicolon separated conversion table.

    Layout of the file: row 0 holds the column headers. The next rows hold
    per-column metadata (landuse: description, unit, range, type; soil:
    description, source, unit, range, type). All remaining rows are data.
    Data values are cast according to the "type" metadata (Integer -> int,
    Double/Real -> float, anything else is kept as a string). Empty cells
    become None.

    params:
        csv_path: path of the csv file.
        csv_type: "landuse" or "soil".
    returns:
        (csv_data, csv_info): csv_data maps header -> list of column values,
        csv_info maps header -> dict of metadata fields.
    raises:
        ValueError: if csv_type is not "landuse" or "soil" (previously an
            UnboundLocalError).
    """
    metadata_fields = {
        "landuse": ("description", "unit", "range", "type"),
        "soil": ("description", "source", "unit", "range", "type"),
    }
    if csv_type not in metadata_fields:
        raise ValueError(
            f"csv_type should be 'landuse' or 'soil', got {csv_type!r}"
        )
    # row index -> metadata field; metadata rows directly follow the header
    csv_structure = dict(enumerate(metadata_fields[csv_type], start=1))
    casts = {"Integer": int, "Double": float, "Real": float, "String": str}

    csv_data = {}
    csv_info = {}
    # newline="" as required by the csv module for correct quoting handling
    with open(csv_path, newline="") as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=";")
        for i, line in enumerate(csv_reader):

            # headers
            if i == 0:
                headers = line
                for column_value in line:
                    csv_data[column_value] = []
                    csv_info[column_value] = {}

            # units, descriptions, ranges etc.
            elif i in csv_structure:
                field = csv_structure[i]
                for column_index, column_value in enumerate(line):
                    csv_info[headers[column_index]][field] = column_value

            # csv data
            else:
                for column_index, column_value in enumerate(line):
                    column = headers[column_index]
                    if column_value == "":
                        csv_data[column].append(None)
                        continue
                    cast = casts.get(csv_info[column]["type"])
                    if cast is not None:
                        column_value = cast(column_value)
                    csv_data[column].append(column_value)

    return csv_data, csv_info


def classify(raster: Raster, table: str, ct: dict):
    """Classify ``raster`` tile by tile using a conversion table.

    params:
        raster: template raster holding the codes (e.g. landuse or soil).
        table: conversion-table column to map the codes to.
        ct: conversion table as returned by load_csv_file; must contain a
            "Code" column and the ``table`` column.
    returns:
        A shell copy of ``raster`` named ``table`` holding the classified
        values. Tiles containing only NaN are copied unchanged.
    """
    ct_codes = ct["Code"]
    ct_table = ct[table]
    pbar = Progress(len(raster), "Classifying {}".format(table))

    output = raster.copy(shell=True)
    for tile in raster:
        array = tile.array

        output_array = classify_array(array, ct_table, ct_codes)

        # classify_array returns None for all-NaN tiles. The original test
        # ``type(output_array) == None`` was always False, so such tiles
        # were written as None instead of being kept.
        if output_array is None:
            output_array = array
        output.array = output_array, *tile.location

        # Release the tile arrays before reading the next tile.
        del array
        del output_array

        pbar.update(False)
    output.name = table
    return output


def classify_array(array, ct_table, ct_codes):
    """Map every non-NaN code in ``array`` to its conversion-table value.

    params:
        array: numpy array of codes; NaN cells are left as NaN.
        ct_table: conversion values, aligned with ``ct_codes``.
        ct_codes: codes of the conversion table; the first occurrence of a
            duplicated code wins.
    returns:
        A new floating point array with the classified values, or None when
        ``array`` contains only NaN. ``array`` itself is not modified.
    raises:
        ValueError: if a code in ``array`` is missing from ``ct_codes``.
    """
    if np.isnan(array).all():
        return None

    # Build the lookup once instead of list.index per code; setdefault keeps
    # the first occurrence, as list.index did.
    lookup = {}
    for code, value in zip(ct_codes, ct_table):
        lookup.setdefault(code, value)

    # Promote integer input to float so fractional conversion values are not
    # truncated; float inputs keep their own dtype.
    output_array = np.array(array, dtype=np.result_type(array.dtype, np.float32))
    for code in np.unique(array[~np.isnan(array)]):
        try:
            value = lookup[code]
        except KeyError:
            raise ValueError(
                f"code {code} is not present in the conversion table"
            ) from None
        output_array[array == code] = value

    return output_array


# if NUMBA_EXISTS:

#     @njit()
#     def nb_classify_array(array, ct_table, ct_codes):
#         """ about 2 times faster"""
#         if np.isnan(array).all():
#             return
#         m, n = array.shape
#         output_array = np.copy(array)
#         for i in range(m):
#             for j in range(n):
#                 value = array[i, j]

#                 if np.isnan(value):
#                     continue

#                 output_array[i, j] = ct_table[ct_codes.index(value)]

#         return output_array
