import warnings
from typing import Optional

import numpy as np
import pandas as pd
from anndata import AnnData

from .. import logging as logg
from .._settings import settings, Verbosity
from .._utils import sanitize_anndata
from .._compat import Literal
from ._utils import _get_mean_var
from ._distributed import materialize_as_ndarray
from ._simple import filter_features


def _highly_variable_features_single_batch(
    adata: AnnData,
    min_disp: Optional[float] = None,
    max_disp: Optional[float] = None,
    min_mean: Optional[float] = None,
    max_mean: Optional[float] = None,
    n_top_features: Optional[int] = None,
    n_bins: int = 20,
    flavor: Literal['seurat', 'cell_ranger'] = 'seurat',
) -> pd.DataFrame:
    """\
    Compute per-feature dispersions and flag highly variable features.

    See `highly_variable_features` for the meaning of the parameters. When the
    seurat flavor has to undo a non-natural log base, the data matrix is copied
    before being rescaled, so `adata.X` is not modified by that step.

    Returns
    -------
    A DataFrame with one row per feature of `adata` (same order as
    `adata.var_names`, default integer index) and the columns `means`,
    `dispersions`, `mean_bin`, `dispersions_norm` and `highly_variable`.

    Raises
    ------
    ValueError
        If `flavor` is neither ``'seurat'`` nor ``'cell_ranger'``.
    """

    if min_disp is None: min_disp = 0.5
    if min_mean is None: min_mean = 0.0125
    if max_mean is None: max_mean = 3
    if max_disp is None: max_disp = np.inf

    X = adata.X
    if flavor == 'seurat':
        if 'log1p' in adata.uns_keys() and adata.uns['log1p']['base'] is not None:
            # Copy first: scaling `adata.X` in place would silently corrupt the
            # caller's data matrix.
            X = X.copy()
            X *= np.log(adata.uns['log1p']['base'])
        # Undo the log1p transform so mean/variance are taken on the linear scale.
        X = np.expm1(X)

    mean, var = materialize_as_ndarray(_get_mean_var(X))
    # now actually compute the dispersion
    mean[mean == 0] = 1e-12  # set entries equal to zero to small value
    dispersion = var / mean
    if flavor == 'seurat':  # logarithmized mean as in Seurat
        dispersion[dispersion == 0] = np.nan
        dispersion = np.log(dispersion)
        mean = np.log1p(mean)
    # all of the following quantities are "per-feature" here
    df = pd.DataFrame()
    df['means'] = mean
    df['dispersions'] = dispersion
    if flavor == 'seurat':
        df['mean_bin'] = pd.cut(df['means'], bins=n_bins)
        disp_grouped = df.groupby('mean_bin')['dispersions']
        disp_mean_bin = disp_grouped.mean()
        disp_std_bin = disp_grouped.std(ddof=1)
        # retrieve those features that have nan std, these are the ones where
        # only a single feature fell in the bin and implicitly set them to have
        # a normalized disperion of 1
        one_feature_per_bin = disp_std_bin.isnull()
        feature_indices = np.where(one_feature_per_bin[df['mean_bin'].values])[0].tolist()
        if len(feature_indices) > 0:
            logg.debug(
                f'Feature indices {feature_indices} fell into a single bin: their '
                'normalized dispersion was set to 1.\n    '
                'Decreasing `n_bins` will likely avoid this effect.'
            )
        # Circumvent pandas 0.23 bug. Both sides of the assignment have dtype==float32,
        # but there’s still a dtype error without “.value”.
        disp_std_bin[one_feature_per_bin.values] = disp_mean_bin[one_feature_per_bin.values].values
        disp_mean_bin[one_feature_per_bin.values] = 0
        # actually do the normalization
        df['dispersions_norm'] = (
            (
                df['dispersions'].values  # use values here as index differs
                - disp_mean_bin[df['mean_bin'].values].values
            ) / disp_std_bin[df['mean_bin'].values].values
        )
    elif flavor == 'cell_ranger':
        from statsmodels import robust
        df['mean_bin'] = pd.cut(df['means'], np.r_[
            -np.inf,
            np.percentile(df['means'], np.arange(10, 105, 5)),
            np.inf
        ])
        disp_grouped = df.groupby('mean_bin')['dispersions']
        disp_median_bin = disp_grouped.median()
        # the next line raises the warning: "Mean of empty slice"
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            disp_mad_bin = disp_grouped.apply(robust.mad)
            df['dispersions_norm'] = (df['dispersions'].values
                - disp_median_bin[df['mean_bin'].values].values
                ) / disp_mad_bin[df['mean_bin'].values].values
    else:
        raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"')

    # float32 copy aligned with `df`; both branches below read from it.
    dispersion_norm = df['dispersions_norm'].values.astype('float32')
    if n_top_features is not None:
        # Descending copy of the *defined* normalized dispersions (boolean
        # indexing copies, so `dispersion_norm` keeps its row alignment).
        sorted_disp = dispersion_norm[~np.isnan(dispersion_norm)]
        sorted_disp[::-1].sort()  # interestingly, np.argpartition is slightly slower
        if n_top_features > adata.n_vars:
            logg.info('`n_top_features` > `adata.n_var`, returning all features.')
            n_top_features = adata.n_vars
        if n_top_features > sorted_disp.size:
            # NaN dispersions were dropped above, so fewer candidates than
            # requested may remain; indexing past the end would raise.
            warnings.warn(
                '`n_top_features` > number of features with a defined normalized '
                'dispersion; using the smallest defined normalized dispersion as cutoff.',
                UserWarning,
            )
            n_top_features = sorted_disp.size
        if n_top_features <= 0:
            # Nothing to select. Without this guard `sorted_disp[-1]` would pick
            # the smallest value (or fail on an empty array) and select everything.
            feature_subset = np.zeros(dispersion_norm.shape[0], dtype=bool)
        else:
            disp_cut_off = sorted_disp[n_top_features - 1]
            # Compare in float32, the dtype the cutoff was taken from, so the
            # feature defining the cutoff always passes `>=`. NaN is treated as 0
            # (Seurat-like).
            feature_subset = np.nan_to_num(dispersion_norm) >= disp_cut_off
            logg.debug(
                f'the {n_top_features} top features correspond to a '
                f'normalized dispersion cutoff of {disp_cut_off}'
            )
    else:
        dispersion_norm[np.isnan(dispersion_norm)] = 0  # similar to Seurat
        feature_subset = np.logical_and.reduce((
            mean > min_mean, mean < max_mean,
            dispersion_norm > min_disp,
            dispersion_norm < max_disp,
        ))

    df['highly_variable'] = feature_subset
    return df


def highly_variable_features(
    adata: AnnData,
    min_disp: Optional[float] = None,
    max_disp: Optional[float] = None,
    min_mean: Optional[float] = None,
    max_mean: Optional[float] = None,
    n_top_features: Optional[int] = None,
    n_bins: int = 20,
    flavor: Literal['seurat', 'cell_ranger'] = 'seurat',
    subset: bool = False,
    inplace: bool = True,
    batch_key: Optional[str] = None,
) -> Optional[pd.DataFrame]:
    """\
    Annotate highly variable features [Satija15]_ [Zheng17]_.

    Expects logarithmized data.
    Depending on `flavor`, this reproduces the R-implementations of Seurat
    [Satija15]_ and Cell Ranger [Zheng17]_.

    The normalized dispersion is obtained by scaling with the mean and standard
    deviation of the dispersions for features falling into a given bin for mean
    expression of features. This means that for each bin of mean expression, highly
    variable features are selected.

    Parameters
    ----------
    adata
        The annotated data matrix of shape `n_obs` × `n_vars`. Rows correspond
        to subjects and columns to features.
    min_mean
        Exclusive lower cutoff on the mean. Default is 0.0125.
        Ignored if `n_top_features` is not `None`.
    max_mean
        Exclusive upper cutoff on the mean. Default is 3.
        Ignored if `n_top_features` is not `None`.
    min_disp
        Exclusive lower cutoff on the normalized dispersion. Default is 0.5.
        Ignored if `n_top_features` is not `None`.
    max_disp
        Exclusive upper cutoff on the normalized dispersion. Default is `np.inf`.
        Ignored if `n_top_features` is not `None`.
    n_top_features
        Number of highly-variable features to keep. If given, all mean and
        dispersion cutoffs above are ignored.
    n_bins
        Number of bins for binning the mean feature expression. Normalization is
        done with respect to each bin. If just a single feature falls into a bin,
        the normalized dispersion is artificially set to 1. You'll be informed
        about this if you set `settings.verbosity = 4`.
    flavor
        Choose the flavor for computing normalized dispersion. In their default
        workflows, Seurat passes the cutoffs whereas Cell Ranger passes
        `n_top_features`.
    subset
        Inplace subset to highly-variable features if `True` otherwise merely indicate
        highly variable features.
    inplace
        Whether to place calculated metrics in `.var` or return them.
    batch_key
        If specified, highly-variable features are selected within each batch separately and merged.
        This simple process avoids the selection of batch-specific features and acts as a
        lightweight batch correction method.

    Returns
    -------
    If `inplace` or `subset` is `True`, returns `None` and updates `adata.var`
    with the fields below. Otherwise returns a :class:`~pandas.DataFrame`
    containing them.

    highly_variable : bool
        boolean indicator of highly-variable features
    means : float
        means per feature
    dispersions : float
        dispersions per feature
    dispersions_norm : float
        normalized dispersions per feature
    highly_variable_nbatches : int
        Only if `batch_key` is given: in how many batches the feature was
        detected as highly variable.
    highly_variable_intersection : bool
        Only if `batch_key` is given: whether the feature is highly variable
        in all batches.

    Raises
    ------
    ValueError
        If `adata` is not an :class:`~anndata.AnnData`.

    Notes
    -----
    This function replaces :func:`~quanp.pp.filter_features_dispersion`.
    """

    if n_top_features is not None and not all(m is None for m in [
        min_disp, max_disp, min_mean, max_mean
    ]):
        logg.info('If you pass `n_top_features`, all cutoffs are ignored.')

    if min_disp is None: min_disp = 0.5
    if min_mean is None: min_mean = 0.0125
    if max_mean is None: max_mean = 3
    if max_disp is None: max_disp = np.inf

    start = logg.info('extracting highly variable features')

    if not isinstance(adata, AnnData):
        raise ValueError(
            '`pp.highly_variable_features` expects an `AnnData` argument, '
            'pass `inplace=False` if you want to return a `pd.DataFrame`.')

    if batch_key is None:
        df = _highly_variable_features_single_batch(
            adata,
            min_disp=min_disp, max_disp=max_disp,
            min_mean=min_mean, max_mean=max_mean,
            n_top_features=n_top_features,
            n_bins=n_bins,
            flavor=flavor,
        )
    else:
        sanitize_anndata(adata)
        batches = adata.obs[batch_key].cat.categories
        per_batch = []
        feature_list = adata.var_names
        for batch in batches:
            adata_subset = adata[adata.obs[batch_key] == batch]

            # Drop features with no observations in this batch.
            with settings.verbosity.override(Verbosity.error):
                filt = filter_features(adata_subset, min_subjects=1, inplace=False)[0]

            adata_subset = adata_subset[:, filt]

            hvg = _highly_variable_features_single_batch(
                adata_subset,
                min_disp=min_disp, max_disp=max_disp,
                min_mean=min_mean, max_mean=max_mean,
                n_top_features=n_top_features,
                n_bins=n_bins,
                flavor=flavor,
            )

            # Add 0 values for features that were filtered out
            missing_hvg = pd.DataFrame(
                np.zeros((np.sum(~filt), len(hvg.columns))),
                columns=hvg.columns,
            )
            missing_hvg['highly_variable'] = missing_hvg['highly_variable'].astype(bool)
            missing_hvg['feature'] = feature_list[~filt]
            hvg['feature'] = adata_subset.var_names.values
            # `DataFrame.append` was removed in pandas 2.0.
            hvg = pd.concat([hvg, missing_hvg], ignore_index=True)

            # Restore the original feature order: kept features come first in
            # `hvg`, followed by the filtered-out ones.
            idxs = np.concatenate((np.where(filt)[0], np.where(~filt)[0]))
            hvg = hvg.loc[np.argsort(idxs)]

            per_batch.append(hvg)

        df = pd.concat(per_batch, axis=0)
        df['highly_variable'] = df['highly_variable'].astype(int)
        # 'mean'/'sum' skip NaN, matching np.nanmean/np.nansum without the
        # pandas deprecation warning for passing numpy callables.
        df = df.groupby('feature').agg(dict(
            means='mean',
            dispersions='mean',
            dispersions_norm='mean',
            highly_variable='sum',
        ))
        df = df.rename(columns=dict(highly_variable='highly_variable_nbatches'))
        df['highly_variable_intersection'] = df['highly_variable_nbatches'] == len(batches)

        if n_top_features is not None:
            # sort features by how often they selected as hvg within each batch and
            # break ties with normalized dispersion across batches
            df = df.sort_values(
                ['highly_variable_nbatches', 'dispersions_norm'],
                ascending=False, na_position='last',
            )
            df['highly_variable'] = False
            # The index holds feature names, so "first n rows" must be
            # positional; `.loc[:n]` on a string index raises in pandas >= 2.
            df.iloc[:n_top_features, df.columns.get_loc('highly_variable')] = True
            df = df.loc[adata.var_names]
        else:
            df = df.loc[adata.var_names]
            # Treat NaN as 0 for the cutoffs (similar to Seurat) on an explicit
            # copy, leaving the reported `dispersions_norm` untouched.
            dispersion_norm = df['dispersions_norm'].fillna(0).values
            feature_subset = np.logical_and.reduce((
                df['means'].values > min_mean, df['means'].values < max_mean,
                dispersion_norm > min_disp,
                dispersion_norm < max_disp,
            ))
            df['highly_variable'] = feature_subset

    logg.info('    finished', time=start)

    if inplace or subset:
        logg.hint(
            'added\n'
            '    \'highly_variable\', boolean vector (adata.var)\n'
            '    \'means\', float vector (adata.var)\n'
            '    \'dispersions\', float vector (adata.var)\n'
            '    \'dispersions_norm\', float vector (adata.var)'
        )
        adata.var['highly_variable'] = df['highly_variable'].values
        adata.var['means'] = df['means'].values
        adata.var['dispersions'] = df['dispersions'].values
        adata.var['dispersions_norm'] = df['dispersions_norm'].values.astype('float32', copy=False)
        if batch_key is not None:
            adata.var['highly_variable_nbatches'] = df['highly_variable_nbatches'].values
            adata.var['highly_variable_intersection'] = df['highly_variable_intersection'].values
        if subset:
            adata._inplace_subset_var(df['highly_variable'].values)
    else:
        return df
