# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ========================================================================
#
# ADAPTED FROM Google's disentanglement_lib:
# https://github.com/google-research/disentanglement_lib
#
# (Major) Modifications for pytorch and disent by Nathan Michlo

"""
Visualization module for disentangled representations.
"""

import logging
import numbers

import numpy as np
import torch

from disent.util import to_numpy
from disent.visualize import visualize_util
from disent.visualize.visualize_util import make_animation_grid
from disent.visualize.visualize_util import reconstructions_to_images


log = logging.getLogger(__name__)


# ========================================================================= #
# Visualise varying single factor for model                                 #
# ========================================================================= #


def latent_traversal_1d_multi_dim(
        decoder_fn,
        latent_vector,
        dimensions=None,
        values=None,
        decoder_device=None
):
    """Creates latent traversals for a latent vector along multiple dimensions.

    For every requested dimension a batch of copies of latent_vector is built
    in which only that dimension is overwritten, one copy per entry of values.
    Each batch is passed through decoder_fn and the output is converted with
    reconstructions_to_images.

    Args:
      decoder_fn: Function that computes (fixed size) images from latent
        representations. It is called once per traversed dimension with a
        single torch tensor of shape (len(values), len(latent_vector)) and its
        output is handed to reconstructions_to_images.
      latent_vector: 1d array-like (anything accepted by to_numpy) with the
        base latent vector to be used.
      dimensions: 1d array-like with the indices of the dimensions that should
        be modified. If an integer is passed, the dimensions 0, 1, ...,
        (dimensions - 1) are modified. If None is passed, all dimensions of
        latent_vector are modified.
      values: 1d array-like with the latent space values that should be used
        for modifications. If an integer is passed, a linear grid between -1
        and 1 with that many points is constructed. If None is passed, a
        default grid of 11 points between -1 and 1 is used (the specific
        design of the default is not guaranteed).
      decoder_device: Device forwarded to torch.as_tensor when the traversal
        batch is converted to a tensor for decoder_fn. None leaves the choice
        to torch.

    Returns:
      The result of to_numpy applied to the list that holds one batch of
      images per traversed dimension (in the order of dimensions), where each
      batch holds one image per entry of values.
      NOTE(review): the exact image layout/dtype is whatever
      reconstructions_to_images produces -- confirm against that helper.

    Raises:
      AssertionError: If latent_vector, dimensions or values is not
        1-dimensional, if an integer dimensions is < 1 or larger than
        len(latent_vector), or if an integer values is <= 1.
    """
    # TODO: EXTRACT
    # Make sure the base vector is a numpy array and validate its rank up
    # front: its shape is used below to derive the default for `dimensions`.
    latent_vector = to_numpy(latent_vector)
    assert latent_vector.ndim == 1, "Latent vector needs to be 1-dimensional."

    # Default values
    if dimensions is None:
        dimensions = latent_vector.shape[0]  # Default case, use all available dimensions.
    if values is None:
        values = 11  # Default case, get 11 evenly spaced samples per factor

    # Handle integers instead of arrays
    if isinstance(dimensions, numbers.Integral):
        assert dimensions <= latent_vector.shape[0], "The number of dimensions of latent_vector is less than the number of dimensions requested in the arguments."
        assert dimensions >= 1, "The number of dimensions has to be at least 1."
        dimensions = np.arange(dimensions)
    if isinstance(values, numbers.Integral):
        assert values > 1, "If an int is passed for values, it has to be >1."
        values = np.linspace(-1., 1., num=values)

    # Make sure everything is a numpy array
    dimensions = to_numpy(dimensions)
    values = to_numpy(values)

    assert dimensions.ndim == 1, "Dimensions vector needs to be 1-dimensional."
    assert values.ndim == 1, "Values vector needs to be 1-dimensional."

    # We iteratively generate the image batch for each dimension as different
    # Numpy arrays. We do not preallocate a single final Numpy array as this code
    # is not performance critical and as it reduces code complexity.
    factor_images = []
    for dimension in dimensions:
        # Creates len(values) copies of the latent_vector along the first axis.
        # np.tile allocates a new array, so the caller's vector is not modified.
        latent_traversal_vectors = np.tile(latent_vector, [len(values), 1])
        # Intervenes in the latent space: only the current dimension changes.
        latent_traversal_vectors[:, dimension] = values
        # Generate the batch of images
        latent_traversal_vectors = torch.as_tensor(latent_traversal_vectors, device=decoder_device)
        images = decoder_fn(latent_traversal_vectors)
        images = reconstructions_to_images(images)
        factor_images.append(images)

    return to_numpy(factor_images)


# ========================================================================= #
# Visualise Random Latent Samples                                           #
# ========================================================================= #


def latent_random_samples(decoder_fn, z_size, num_samples=16, decoder_device=None):
    """Decodes latent vectors drawn from a standard normal distribution.

    Args:
      decoder_fn: Function that maps a torch tensor of shape
        (num_samples, z_size) to reconstructions.
      z_size: Number of entries in each sampled latent vector.
      num_samples: Number of latent vectors to sample.
      decoder_device: Device on which torch.randn creates the samples.

    Returns:
      The output of reconstructions_to_images applied to the decoded samples.
    """
    noise_shape = (num_samples, z_size)
    sampled_z = torch.randn(*noise_shape, device=decoder_device)
    return reconstructions_to_images(decoder_fn(sampled_z))


# ========================================================================= #
# Visualise Latent Traversals                                               #
# ========================================================================= #


def latent_traversals(decoder_fn, z_mean, dimensions=None, values=None, decoder_device=None):
    """Creates latent traversals for every latent vector in a batch.

    Args:
      decoder_fn: Function that computes images from a batch of latent
        vectors; forwarded to latent_traversal_1d_multi_dim.
      z_mean: Batch of latent vectors. It is iterated along its first axis and
        each item is used as the base vector of one traversal.
      dimensions: Forwarded unchanged to latent_traversal_1d_multi_dim.
      values: Forwarded unchanged to latent_traversal_1d_multi_dim.
      decoder_device: Forwarded unchanged to latent_traversal_1d_multi_dim.

    Returns:
      The result of to_numpy applied to the list with one traversal grid per
      item of z_mean, in the same order.
    """
    # TODO: add support for cycle methods from below? Is that actually useful?
    # Iterating the batch directly yields the same rows as indexing
    # z_mean[i, :] for each i, without the manual index bookkeeping.
    traversals = [
        latent_traversal_1d_multi_dim(decoder_fn, base_z, dimensions=dimensions, values=values, decoder_device=decoder_device)
        for base_z in z_mean
    ]
    return to_numpy(traversals)


# ========================================================================= #
# Visualise Latent Cycles (for animations)                                  #
# ========================================================================= #


def _z_std_gaussian_cycle(base_z, z_means, z_logvars, z_idx, num_frames):
    """Builds a latent batch that cycles one dimension over a standard Gaussian.

    Args:
      base_z: Latent vector that is copied once per frame.
      z_means: Not used by this mode; accepted so that all cycle modes share
        the same signature.
      z_logvars: Not used by this mode; accepted for the same reason.
      z_idx: Index of the latent dimension that is overwritten.
      num_frames: Number of copies of base_z to create.

    Returns:
      Array holding num_frames copies of base_z along a new first axis, with
      column z_idx replaced by the output of
      visualize_util.cycle_gaussian(base_z[z_idx], num_frames, loc=0, scale=1).
    """
    # np.repeat returns a fresh array, so writing into it leaves base_z intact.
    frames_z = np.repeat(np.expand_dims(base_z, axis=0), repeats=num_frames, axis=0)
    cycled = visualize_util.cycle_gaussian(base_z[z_idx], num_frames, loc=0, scale=1)
    frames_z[:, z_idx] = cycled
    return frames_z


def _z_fitted_gaussian_cycle(base_z, z_means, z_logvars, z_idx, num_frames):
    """Builds a latent batch that cycles one dimension over a fitted Gaussian.

    The Gaussian's location is the mean of z_means[:, z_idx]. Its variance is
    the mean of exp(z_logvars[:, z_idx]) plus the variance of
    z_means[:, z_idx].

    Args:
      base_z: Latent vector that is copied once per frame.
      z_means: 2d array indexed as [sample, dimension].
      z_logvars: 2d array indexed as [sample, dimension]; exponentiated before
        averaging.
      z_idx: Index of the latent dimension that is overwritten.
      num_frames: Number of copies of base_z to create.

    Returns:
      Array holding num_frames copies of base_z along a new first axis, with
      column z_idx replaced by the output of visualize_util.cycle_gaussian.
    """
    # np.repeat returns a fresh array, so writing into it leaves base_z intact.
    frames_z = np.repeat(np.expand_dims(base_z, axis=0), repeats=num_frames, axis=0)
    dim_means = z_means[:, z_idx]
    fitted_loc = np.mean(dim_means)
    dim_logvars = z_logvars[:, z_idx]
    fitted_variance = np.mean(np.exp(dim_logvars)) + np.var(dim_means)
    fitted_scale = np.sqrt(fitted_variance)
    frames_z[:, z_idx] = visualize_util.cycle_gaussian(base_z[z_idx], num_frames, loc=fitted_loc, scale=fitted_scale)
    return frames_z


def _z_fixed_interval_cycle(base_z, z_means, z_logvars, z_idx, num_frames):
    """Builds a latent batch that cycles one dimension over the interval [-2, 2].

    Args:
      base_z: Latent vector that is copied once per frame.
      z_means: Not used by this mode; accepted so that all cycle modes share
        the same signature.
      z_logvars: Not used by this mode; accepted for the same reason.
      z_idx: Index of the latent dimension that is overwritten.
      num_frames: Number of copies of base_z to create.

    Returns:
      Array holding num_frames copies of base_z along a new first axis, with
      column z_idx replaced by the output of
      visualize_util.cycle_interval(base_z[z_idx], num_frames, -2., 2.).
    """
    lower, upper = -2., 2.
    # np.repeat returns a fresh array, so writing into it leaves base_z intact.
    frames_z = np.repeat(np.expand_dims(base_z, axis=0), repeats=num_frames, axis=0)
    cycled = visualize_util.cycle_interval(base_z[z_idx], num_frames, lower, upper)
    frames_z[:, z_idx] = cycled
    return frames_z


def _z_conf_interval_cycle(base_z, z_means, z_logvars, z_idx, num_frames):
    """Builds a latent batch that cycles one dimension over loc +- 2 * scale.

    loc is the mean of z_means[:, z_idx]. scale is the square root of the mean
    of exp(z_logvars[:, z_idx]) plus the variance of z_means[:, z_idx].

    Args:
      base_z: Latent vector that is copied once per frame.
      z_means: 2d array indexed as [sample, dimension].
      z_logvars: 2d array indexed as [sample, dimension]; exponentiated before
        averaging.
      z_idx: Index of the latent dimension that is overwritten.
      num_frames: Number of copies of base_z to create.

    Returns:
      Array holding num_frames copies of base_z along a new first axis, with
      column z_idx replaced by the output of visualize_util.cycle_interval
      over [loc - 2 * scale, loc + 2 * scale].
    """
    # np.repeat returns a fresh array, so writing into it leaves base_z intact.
    frames_z = np.repeat(np.expand_dims(base_z, axis=0), repeats=num_frames, axis=0)
    dim_means = z_means[:, z_idx]
    fitted_loc = np.mean(dim_means)
    dim_logvars = z_logvars[:, z_idx]
    fitted_variance = np.mean(np.exp(dim_logvars)) + np.var(dim_means)
    fitted_scale = np.sqrt(fitted_variance)
    lower = fitted_loc - 2. * fitted_scale
    upper = fitted_loc + 2. * fitted_scale
    frames_z[:, z_idx] = visualize_util.cycle_interval(base_z[z_idx], num_frames, lower, upper)
    return frames_z


def _z_minmax_interval_cycle(base_z, z_means, z_logvars, z_idx, num_frames):
    """Builds a latent batch that cycles one dimension over the observed range.

    The interval runs from the minimum to the maximum of z_means[:, z_idx].

    Args:
      base_z: Latent vector that is copied once per frame.
      z_means: 2d array indexed as [sample, dimension].
      z_logvars: Not used by this mode; accepted so that all cycle modes share
        the same signature.
      z_idx: Index of the latent dimension that is overwritten.
      num_frames: Number of copies of base_z to create.

    Returns:
      Array holding num_frames copies of base_z along a new first axis, with
      column z_idx replaced by the output of visualize_util.cycle_interval
      over [min, max] of z_means[:, z_idx].
    """
    # np.repeat returns a fresh array, so writing into it leaves base_z intact.
    frames_z = np.repeat(np.expand_dims(base_z, axis=0), repeats=num_frames, axis=0)
    lower = np.min(z_means[:, z_idx])
    upper = np.max(z_means[:, z_idx])
    frames_z[:, z_idx] = visualize_util.cycle_interval(base_z[z_idx], num_frames, lower, upper)
    return frames_z


# Dispatch table from mode name to the function that builds the latent batch
# for one animated dimension. Every entry is called as
# fn(base_z, z_means, z_logvars, z_idx, num_frames).
_LATENT_CYCLE_MODES_MAP = {
    'std_gaussian_cycle': _z_std_gaussian_cycle,
    'fitted_gaussian_cycle': _z_fitted_gaussian_cycle,
    'fixed_interval_cycle': _z_fixed_interval_cycle,
    'conf_interval_cycle': _z_conf_interval_cycle,
    'minmax_interval_cycle': _z_minmax_interval_cycle,
}

# Public list of the supported mode names, in the insertion order of the
# dispatch table (iterating a dict yields its keys).
LATENT_CYCLE_MODES = list(_LATENT_CYCLE_MODES_MAP)


def latent_cycle(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_animations=4, num_frames=20, decoder_device=None):
    """Decodes latent cycles for every latent dimension of several base vectors.

    The first num_animations rows of z_means are used as base vectors. For
    each base vector and each latent dimension, the function selected by mode
    builds a batch of num_frames latent vectors, which is decoded and
    converted with reconstructions_to_images.

    Args:
      decoder_func: Function that maps a torch tensor of latent vectors to
        reconstructions.
      z_means: Batch of latent means; converted with to_numpy and indexed as
        [sample, dimension].
      z_logvars: Batch of latent log-variances; converted with to_numpy.
      mode: Name of the cycle mode, one of the keys of
        _LATENT_CYCLE_MODES_MAP.
      num_animations: Maximum number of base vectors taken from z_means.
      num_frames: Number of frames requested from the cycle function.
      decoder_device: Device forwarded to torch.as_tensor for the decoder
        input.

    Returns:
      The result of to_numpy applied to a nested list indexed as
      [animation][latent dimension], holding the converted images.

    Raises:
      AssertionError: If z_means or z_logvars has fewer than two entries.
      KeyError: If mode is not a supported cycle mode.
    """
    assert len(z_means) > 1 and len(z_logvars) > 1, 'not enough samples to average'
    # work on numpy arrays from here on
    z_means = to_numpy(z_means)
    z_logvars = to_numpy(z_logvars)
    # resolve the function that generates the latent batch for one dimension
    if mode not in _LATENT_CYCLE_MODES_MAP:
        raise KeyError(f'Unsupported mode: {repr(mode)} not in {set(_LATENT_CYCLE_MODES_MAP)}')
    make_cycle = _LATENT_CYCLE_MODES_MAP[mode]

    def _decode_cycle(start_z, z_idx):
        # build the latent batch for one dimension, then decode it to images
        cycle_z = make_cycle(start_z, z_means, z_logvars, z_idx, num_frames)
        cycle_z = torch.as_tensor(cycle_z, device=decoder_device)
        return reconstructions_to_images(decoder_func(cycle_z))

    # outer level: one entry per base vector, inner level: one per dimension
    animations = [
        [_decode_cycle(start_z, z_idx) for z_idx in range(z_means.shape[1])]
        for start_z in z_means[:num_animations]
    ]
    return to_numpy(animations)


def latent_cycle_grid_animation(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_frames=21, pad=4, border=True, bg_color=0.5, decoder_device=None):
    """Renders the latent cycle of a single base vector as a grid animation.

    Args:
      decoder_func: Function that maps a torch tensor of latent vectors to
        reconstructions; forwarded to latent_cycle.
      z_means: Forwarded to latent_cycle.
      z_logvars: Forwarded to latent_cycle.
      mode: Cycle mode name forwarded to latent_cycle.
      num_frames: Number of frames forwarded to latent_cycle.
      pad: Forwarded to make_animation_grid.
      border: Forwarded to make_animation_grid.
      bg_color: Forwarded to make_animation_grid.
      decoder_device: Forwarded to latent_cycle.

    Returns:
      4d array of frames whose last axis from make_animation_grid has been
      moved to position 1 and which always has 3 entries along that axis.
      NOTE(review): presumably (frame, channel, height, width) -- confirm the
      axis order against make_animation_grid.

    Raises:
      AssertionError: If the grid does not have 1 or 3 channels.
    """
    # only the first animation is requested and used
    animation = latent_cycle(
        decoder_func, z_means, z_logvars,
        mode=mode, num_animations=1, num_frames=num_frames, decoder_device=decoder_device,
    )
    grid = make_animation_grid(animation[0], pad=pad, border=border, bg_color=bg_color)
    # move the trailing (channel) axis directly behind the leading axis
    frames = np.transpose(grid, (0, 3, 1, 2))
    # greyscale output is expanded to three identical channels
    num_channels = frames.shape[1]
    assert num_channels in {1, 3}, f'Invalid number of image channels: {animation.shape} -> {frames.shape}'
    return np.repeat(frames, 3, axis=1) if num_channels == 1 else frames


# ========================================================================= #
# END                                                                       #
# ========================================================================= #
