# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_geo.ipynb (unless otherwise specified).

__all__ = ['Region', 'open_shp', 'open_tif', 'bounds_from_shapefile', 'size_from_bounds', 'size_resolution_assert',
           'rasterize', 'downsample', 'is_intersection', 'polygon_from_bounds', 'crop', 'bounds_from_coords']

# Cell
import geopandas as gp
import numpy as np
import matplotlib.pyplot as plt
import json
import numpy as np
import pandas as pd
import shapely
import rasterio
import rasterio.features
from rasterio import features
from rasterio.mask import mask
from rasterio.merge import merge
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
import warnings
import pyresample.geometry as prgeo
from fastcore.test import *

from .core import *

# Cell
class Region():
    """Defines a geographical region with a name, a bounding box and the pixel size"""
    def __init__(self, name:str, bbox:list, pixel_size:float, epsg:int=4326,
                 shape=None):
        self.name       = name
        self.bbox       = rasterio.coords.BoundingBox(*bbox) # left, bottom, right, top
        self.pixel_size = pixel_size
        self.epsg       = epsg
        self._shape     = shape # explicit (height, width); None means derive it on demand

    @property
    def width(self):
        "Width of the region (second element of `shape`)"
        return self.shape[1]

    @property
    def height(self):
        "Height of the region (first element of `shape`)"
        return self.shape[0]

    @property
    def transform(self):
        "Rasterio Affine transform of the region"
        # Resolve the shape once instead of once per dimension
        shape = self.shape
        return rasterio.transform.from_bounds(*self.bbox, shape[1], shape[0])

    @property
    def crs(self):
        "Rasterio CRS built from the region EPSG code"
        return CRS.from_epsg(self.epsg)

    @property
    def shape(self):
        "Shape of the region (height, width)"
        # An explicitly provided shape takes precedence, so the pyresample
        # area definition is only built when it is actually needed.
        if self._shape is not None:
            return self._shape
        crs = prgeo.CRS(f'EPSG:{self.epsg}')
        area_def = prgeo.create_area_def(
            crs.name, crs.to_dict(), area_extent=self.bbox, resolution=self.pixel_size)
        return area_def.shape

    def coords(self, offset='ul'):
        """Computes the coordinate arrays of the region grid.

        Returns a tuple `(xs, ys)`: the x coordinate of every column (taken along
        row 0) and the y coordinate of every row (taken along column 0).
        `offset` is forwarded to `rasterio.transform.xy`.
        """
        rxy = rasterio.transform.xy
        # The transform and shape are derived properties; compute each once
        tfm = self.transform
        ys, xs = map(range, self.shape)
        x_coords = np.array(rxy(tfm, [0]*len(xs), xs, offset=offset)[0])
        y_coords = np.array(rxy(tfm, ys, [0]*len(ys), offset=offset)[1])
        return x_coords, y_coords

    @classmethod
    def load(cls, file):
        """Loads region information from json file.

        `name`, `bbox` and `pixel_size` are required keys. `epsg` and `_shape`
        (both written by `export`) are optional, so files that only hold the
        three required keys still load with the constructor defaults.
        """
        with open(file, 'r') as f:
            args = json.load(f)
        shape = args.get('_shape')
        # JSON has no tuple type; restore the shape as a tuple
        if shape is not None: shape = tuple(shape)
        return cls(args['name'], args['bbox'], args['pixel_size'],
                   epsg=args.get('epsg', 4326), shape=shape)

    def new(self, name:str=None, bbox:list=None, pixel_size:float=None, epsg:int=None,
            shape:tuple=None):
        """Create new region with updated parameters.

        Any of `name`, `bbox`, `pixel_size` and `epsg` left as None is copied from
        this region. `shape` is not inherited: when None, the new region derives
        its shape from its bounding box and pixel size.
        """
        if name is None: name = self.name
        if bbox is None: bbox = list(self.bbox)
        if pixel_size is None: pixel_size = self.pixel_size
        if epsg is None: epsg = self.epsg
        return Region(name, bbox, pixel_size, epsg=epsg, shape=shape)

    def export(self, file):
        """Exports region information (the instance attributes) to json file"""
        dict2json(self.__dict__, file)

    def __repr__(self):
        # One `attribute: value` line per instance attribute
        return '\n'.join([f'{i}: {o}' for i, o in self.__dict__.items()]) + '\n'

# Cell
def open_shp(file):
    "Read the shapefile at `file` with geopandas and return the result"
    gdf = gp.read_file(file)
    return gdf

def open_tif(file):
    "Open the tiff at `file` with rasterio; the opened dataset is returned as is (not closed here)"
    dataset = rasterio.open(file)
    return dataset

def bounds_from_shapefile(shapefile):
    "Computes the overall (minx, miny, maxx, maxy) bounding box from the per-row bounds of a shapefile"
    per_row = shapefile.bounds
    left, bottom = per_row.minx.min(), per_row.miny.min()
    right, top = per_row.maxx.max(), per_row.maxy.max()
    return left, bottom, right, top

def size_from_bounds(bounds, resolution):
    """Computes width and height from bounds for a given pixel resolution.

    `bounds` is indexed as (left, bottom, right, top) in degrees. `resolution` is
    the pixel size in the unit of the 111100-per-degree factor (approximately
    metres per degree). Returns `(width, height)` as integers, rounded up.
    """
    left, bottom, right, top = bounds[0], bounds[1], bounds[2], bounds[3]
    pixels_per_degree = 111100/resolution
    # The ground length of a degree of longitude shrinks with cos(latitude),
    # so the east-west extent is scaled by the mean latitude of the box.
    mlat = np.mean([top, bottom])
    width = np.ceil((right-left)*pixels_per_degree*np.cos(np.deg2rad(mlat))).astype(int)
    height = np.ceil((top-bottom)*pixels_per_degree).astype(int)
    return width, height

def size_resolution_assert(size, resolution):
    "Raise if neither size nor resolution is given; warn that resolution is ignored if both are"
    has_size = size is not None
    has_resolution = resolution is not None
    if not (has_size or has_resolution):
        raise Exception('You must define either size or resolution')
    if has_size and has_resolution:
        warnings.warn('resolution not used, computed based on size and bounds')

def rasterize(x, value_key=None, region=None, merge_alg='replace'):
    """Rasterize the geometries of `x` onto the grid of `region`.

    Each geometry is burned with the value from column `value_key`, or with 1
    when `value_key` is None. `merge_alg` may be 'replace' or 'add' (mapped to the
    matching `rasterio.enums.MergeAlg` member); any other value is passed through.
    """
    if merge_alg in ('replace', 'add'):
        merge_alg = getattr(rasterio.enums.MergeAlg, merge_alg)
    if value_key is None:
        burn_values = [1]*len(x)
    else:
        burn_values = x[value_key]
    # (geometry, value) pairs consumed lazily by rasterio
    pairs = zip(x.geometry, burn_values)
    return rasterio.features.rasterize(
        pairs,
        out_shape=region.shape,
        transform=region.transform,
        merge_alg=merge_alg,
    )

def downsample(x, src_tfm=None, dst_tfm=None, dst_shape=None,
               src_crs={'init': 'EPSG:4326'}, dst_crs={'init': 'EPSG:4326'},
               resampling='average'):
    """Downsample a numpy array x by reprojecting it onto a destination grid.

    The result is written into a new zero-filled array of shape `dst_shape`.
    `resampling` may be 'average', 'bilinear' or 'nearest' (mapped to the matching
    `rasterio.warp.Resampling` member); any other value is passed through.
    """
    if resampling in ('average', 'bilinear', 'nearest'):
        resampling = getattr(rasterio.warp.Resampling, resampling)
    destination = np.zeros(dst_shape)
    rasterio.warp.reproject(
        x, destination,
        src_transform=src_tfm, dst_transform=dst_tfm,
        src_crs=src_crs, dst_crs=dst_crs,
        resampling=resampling,
    )
    return destination

def is_intersection(gdf1, gdf2):
    "Return True when the intersection overlay of two geo pandas dataframes has at least one row"
    overlap = gp.overlay(gdf1, gdf2, how='intersection')
    n_rows = len(overlap)
    return n_rows > 0

def polygon_from_bounds(bounds, to_GeoDataFrame=False, crs={'init': 'EPSG:4326'}):
    """Create a rectangular polygon from bounds indexed as (left, bottom, right, top).

    With `to_GeoDataFrame=True` the polygon is wrapped in a single-row
    GeoDataFrame using `crs`; otherwise the shapely polygon itself is returned.
    """
    left, bottom, right, top = bounds[0], bounds[1], bounds[2], bounds[3]
    # Corner order: lower-left, lower-right, upper-right, upper-left
    corners = [(left, bottom), (right, bottom), (right, top), (left, top)]
    polygon = shapely.geometry.Polygon(corners)
    if not to_GeoDataFrame:
        return polygon
    return gp.GeoDataFrame(crs=crs, geometry=[polygon])

def crop(x, bounds=None, shape=None, crop=True):
    """
    Crop rasterio dataset for a region defined by bounds.
        x is a dataset or a list of datasets (rasterio.open).
        If list (with more than one dataset) then merge with bounds is used.
        else mask is used to crop given bounds or any given shape;
        bounds, when given, takes precedence over shape.
        A one-element list is treated as a single dataset.
    Returns the cropped array (with singleton dimensions squeezed out)
    and the transform reported by merge/mask.
    """
    # Check the type before calling len(): a single dataset object is not
    # guaranteed to support len(), which would raise before reaching mask.
    if isinstance(x, list) and len(x) == 1:
        x = x[0]
    if isinstance(x, list):
        out, transform = merge(x, bounds)
    else:
        if bounds is not None: shape = polygon_from_bounds(bounds)
        out, transform = mask(x, shapes=[shape], crop=crop)
    return out.squeeze(), transform

def bounds_from_coords(lon, lat):
    "Compute bounds (lon min, lat min, lon max, lat max) from lon lat coords"
    west, east = lon.min(), lon.max()
    south, north = lat.min(), lat.max()
    return west, south, east, north