# -*- coding: utf-8 -*-

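""" Helper functions for the fototex package.

Power spectrum density and radial spectrum (r-spectrum) computation for square
windows, either whole or per angular sector. Principal component analysis uses
scikit-learn when available and a NumPy eigendecomposition otherwise. There are
also small utilities: degree to compass-point conversion, axis slicing, chunked
iteration, and standard deviation from running sums.
"""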

import warnings

import numpy as np

from itertools import chain, islice

from sklearn.decomposition import PCA

from fototex import R_SPECTRA_NO_DATA_VALUE
from fototex._numba import sector_average, azimuthal_average, contains
from fototex.exceptions import ImportSklearnWarning

# Optional scikit-learn support. When it cannot be imported, ``pca`` falls back to
# a NumPy eigendecomposition.
# NOTE(review): the unconditional ``from sklearn.decomposition import PCA`` at the
# top of this module raises before this block runs when scikit-learn is missing,
# which defeats this fallback; that import should live only here.
try:
    from sklearn.decomposition import PCA  # noqa: F401  (class used by pca())
except ImportError:
    # ImportError (the parent of ModuleNotFoundError) also covers an installed but
    # broken scikit-learn, e.g. one of its compiled extensions failing to load
    SKLEARN_SUPPORT = False
    warnings.warn("No scikit-learn module found, PCA support will be internal", ImportSklearnWarning)
else:
    SKLEARN_SUPPORT = True


def degrees_to_cardinal(d):
    """ Convert an angle in degrees to one of the 16 compass points

    Each point covers a 22.5° sector centred on its bearing ("N" covers
    [348.75, 360) and [0, 11.25)). Any real angle is accepted: it is first wrapped
    into [0, 360), so negative angles and angles of 360° or more map correctly.

    Thanks to https://gist.github.com/RobertSudwarts/acf8df23a16afdb5837f
    :param d: angle in degrees, increasing clockwise from North (0 -> "N", 90 -> "E")
    :return: compass point abbreviation, e.g. "N", "NNE", "NE"
    """
    dirs = ["N", "NNE", "NE", "ENE", "E", "ESE", "SE", "SSE",
            "S", "SSW", "SW", "WSW", "W", "WNW", "NW", "NNW"]
    # Wrap first and use floor division: int() truncates toward zero, which used to
    # send e.g. -20° to "N" instead of "NNW". The final modulo folds the upper half
    # of the "N" sector (index 16) back to 0.
    ix = int((d % 360 + 11.25) // 22.5)
    return dirs[ix % 16]


def get_slice_along_axis(ndim, axis, _slice):
    """ Build an index tuple that applies ``_slice`` on one axis and ``:`` on all others

    Useful for indexing an n-dimensional numpy array along an axis only known at
    runtime: ``array[get_slice_along_axis(array.ndim, axis, s)]``.

    :param ndim: number of dimensions (length of the returned tuple)
    :param axis: axis receiving ``_slice``; negative values count from the end
    :param _slice: the index to place on ``axis`` (typically a slice)
    :return: tuple usable directly as a numpy index
    """
    index = [slice(None) for _ in range(ndim)]
    index[axis] = _slice  # list assignment: negative axes wrap, out-of-range raises IndexError
    return tuple(index)


def get_power_spectrum_density(window, standardize):
    """ Compute the centred 2D power spectrum density of a window

    The 2D FFT is computed with orthonormal scaling and shifted so the zero
    frequency sits at the centre of the array. The PSD is the squared modulus of
    that spectrum.

    :param window: 2D array of values
    :param standardize: if true, divide the PSD by the variance of ``window``
    :return: PSD array with the same shape as ``window``
    """
    centred_spectrum = np.fft.fftshift(np.fft.fft2(window, norm="ortho"))
    power = np.abs(centred_spectrum) ** 2
    return power / np.var(window) if standardize else power


def pca(data, n_components, sklearn_pca=SKLEARN_SUPPORT):
    """ Principal component analysis

    NaN and infinite values are first replaced with ``np.nan_to_num``. Every column
    (feature) is then centred and scaled to unit variance before the
    decomposition. Constant columns are only centred, so they become all zeros
    instead of NaN. The caller's array is never modified.

    :param data: 2D array-like of shape (n_samples, n_features)
    :param n_components: number of principal components to keep
    :param sklearn_pca: use sklearn.decomposition.PCA class (defaults to whether
        scikit-learn could be imported); otherwise use a NumPy eigendecomposition
    :return: tuple ``(components, projection)``. ``components`` has shape
        (n_features, n_components) with one principal axis per column, sorted by
        decreasing explained variance. ``projection`` is the standardized data
        projected onto those axes, of shape (n_samples, n_components).
    """
    # nan_to_num returns a copy, so the in-place operations below never touch the
    # caller's array; cast to float so they also work on integer input
    data = np.nan_to_num(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)

    # standardize each column; a zero deviation would otherwise fill the column
    # with NaN (0 / 0), so it is divided by 1 instead
    data -= data.mean(axis=0)
    std = data.std(axis=0)
    data /= np.where(std == 0, 1.0, std)

    if sklearn_pca:
        sk_pca = PCA(n_components=n_components)
        sk_pca.fit(data)
        return sk_pca.components_.T, sk_pca.transform(data)

    # covariance of the standardized features (rowvar=False: columns are variables);
    # atleast_2d keeps a single-feature input as a 1x1 matrix instead of a 0-d array
    cov = np.atleast_2d(np.cov(data, rowvar=False))
    # eigh is dedicated to symmetric matrices such as a covariance and always
    # returns real eigenvalues/eigenvectors, whereas the general-purpose eig does
    # not exploit symmetry and may return complex output
    eig_val, eig_vec = np.linalg.eigh(cov)
    # reorder the eigenvectors by decreasing eigenvalue, then keep the first PCs
    order = np.argsort(eig_val)[::-1]
    eig_vec = eig_vec[:, order][:, :n_components]
    # projection of the data in the new space
    return eig_vec, np.dot(data, eig_vec)


def rspectrum(window, radius, window_size, nb_sample, standardize, keep_dc_component, no_data_value):
    """ Compute the radial spectrum (r-spectrum) of a window, or a no-data placeholder

    The r-spectrum is the azimuthally averaged 1D power spectrum. Windows that are
    not ``window_size`` x ``window_size`` (e.g. truncated at a border) or that
    contain ``no_data_value`` are rejected. For those, an array of ``nb_sample``
    values filled with ``R_SPECTRA_NO_DATA_VALUE`` is returned instead.

    :param window: window array
    :param radius: corresponding radius integer
    :param window_size: window typical size
    :param nb_sample: nb of frequencies to sample within window
    :param standardize: standardize r-spectrum by window variance
    :param keep_dc_component: (bool) either keep DC component of the FFT (0 frequency part of the signal) or not. It
    may substantially change the results, so use it carefully.
    :param no_data_value: (int, float) value corresponding to no data
    :return: 1D array holding the first ``nb_sample`` r-spectrum values (starting after the DC term unless
    ``keep_dc_component``), or the no-data placeholder
    """
    is_full_square_window = window.shape[0] == window.shape[1] == window_size
    if not is_full_square_window or contains(window, no_data_value):
        return np.full(nb_sample, R_SPECTRA_NO_DATA_VALUE)

    # The first averaged bin is the DC (zero frequency) term: skip it unless asked to keep it
    first = 0 if keep_dc_component else 1
    psd = get_power_spectrum_density(window, standardize)
    return azimuthal_average(radius, psd)[first:first + nb_sample]


def rspectrum_per_sector(window, radius, sectors, window_size, nb_sample, nb_sectors, standardize, keep_dc_component,
                         no_data_value):
    """ Compute one radial spectrum (r-spectrum) per angular sector, or a no-data placeholder

    The power spectrum density of the window is averaged within each sector,
    i.e. radial spectra in specific directions. Windows that are not
    ``window_size`` x ``window_size`` or that contain ``no_data_value`` are
    rejected. For those, an array of shape (nb_sectors, nb_sample) filled with
    ``R_SPECTRA_NO_DATA_VALUE`` is returned instead.

    :param window: window array
    :param radius:
    :param sectors: result from get_sectors function (divide the circle into sectors according to nb_sectors)
    :param window_size: window typical size
    :param nb_sample: number of frequencies to sample within window
    :param nb_sectors: number of sectors
    :param standardize: standardize r-spectrum by window variance
    :param keep_dc_component:(bool) either keep DC component of the FFT (0 frequency part of the signal) or not. It
    may substantially change the results, so use it carefully.
    :param no_data_value: value corresponding to no data
    :return: 2D array with one row per sector holding the first ``nb_sample`` values (starting after the DC term
    unless ``keep_dc_component``), or the no-data placeholder
    """
    is_full_square_window = window.shape[0] == window.shape[1] == window_size
    if not is_full_square_window or contains(window, no_data_value):
        return np.full((nb_sectors, nb_sample), R_SPECTRA_NO_DATA_VALUE)

    # Column 0 of each sector's spectrum is the DC (zero frequency) term: skip it unless asked to keep it
    first = 0 if keep_dc_component else 1
    psd = get_power_spectrum_density(window, standardize)
    return sector_average(psd, radius, sectors, nb_sectors)[:, first:first + nb_sample]


def split_into_chunks(iterable, size=10):
    """ Lazily split an iterable into consecutive chunks of at most ``size`` items

    Each yielded chunk is itself a lazy iterator drawing from the same underlying
    iterator. Consume a chunk completely before requesting the next one;
    otherwise the leftover items end up in the following chunks. The last chunk
    may be shorter than ``size``.

    :param iterable: any iterable
    :param size: size of each chunk
    :return: generator of chunk iterators
    """
    source = iter(iterable)
    while True:
        try:
            head = next(source)
        except StopIteration:
            return
        yield chain((head,), islice(source, size - 1))


def standard_deviation(nb_data, sum_of_values, sum_of_square_values):
    """ Compute standard deviation from running sums, based on the variance formula

    Uses Var(X) = E[X²] - E[X]², i.e. the population (ddof=0) variance, so the
    result matches ``np.std`` of the values the sums were accumulated over. Works
    element-wise when the sums are numpy arrays.

    :param nb_data: (int) number of data
    :param sum_of_values: sum of the values (scalar or array)
    :param sum_of_square_values: sum of the squared values (same shape as sum_of_values)
    :return: standard deviation (scalar or array)
    """
    mean = sum_of_values / nb_data
    variance = sum_of_square_values / nb_data - mean ** 2
    # Both terms are nearly equal when the spread is small relative to the mean, so
    # round-off can make their difference slightly negative; clamp at 0 rather than
    # letting np.sqrt return NaN (NaN inputs still propagate through np.maximum)
    return np.sqrt(np.maximum(variance, 0))
