from typing import Union, Dict, List, Tuple, Callable, Any
import torch as t
import torch.nn as nn
import numpy as np

from machin.model.nets.base import NeuralNetworkModule
from machin.frame.buffers.buffer import Transition, Buffer
from .base import TorchFramework
from .utils import safe_call


class A2C(TorchFramework):
    """
    A2C framework.
    """

    _is_top = ["actor", "critic"]
    _is_restorable = ["actor", "critic"]

    def __init__(
        self,
        actor: Union[NeuralNetworkModule, nn.Module],
        critic: Union[NeuralNetworkModule, nn.Module],
        optimizer: Callable,
        criterion: Callable,
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple, Tuple] = None,
        lr_scheduler_kwargs: Tuple[Dict, Dict] = None,
        actor_learning_rate: float = 0.001,
        critic_learning_rate: float = 0.001,
        critic_update_times: int = 50,
        entropy_weight: float = None,
        value_weight: float = 0.5,
        gradient_max: float = np.inf,
        gae_lambda: float = 1.0,
        discount: float = 0.99,
        replay_size: int = 500000,
        replay_device: Union[str, t.device] = "cpu",
        replay_buffer: Buffer = None,
        visualize: bool = False,
        visualize_dir: str = "",
        **__
    ):
        """
        Important:
            When given a state and an optional action, the actor must
            return at least two values:

            **1. Action**

              For **continuous environments**, action must be of shape
              ``[batch_size, action_dim]`` and *clamped by action space*.
              For **discrete environments**, action could be of shape
              ``[batch_size, action_dim]`` if it is a one hot vector, or
              ``[batch_size, 1]`` if it is a categorically encoded integer.

            **2. Log likelihood of action (action probability)**

              For either type of environment, log likelihood is of shape
              ``[batch_size, 1]``.

              Action probability must be differentiable, since the gradient
              of the actor is calculated from the gradient of the action
              probability.

            A third value, the entropy, is optional:

            **3. Entropy of action distribution**

              Entropy is usually calculated using dist.entropy(), its shape
              is ``[batch_size, 1]``. You must specify ``entropy_weight``
              to make it effective.

        Hint:
            For continuous environments, actions are not directly output by
            your actor, otherwise it would be rather inconvenient to
            calculate the log probability of the action. Instead, your actor
            network should output the parameters of a certain distribution
            (eg: :class:`~torch.distributions.normal.Normal`)
            and then draw the action from it.

            For discrete environments,
            :class:`~torch.distributions.categorical.Categorical` is
            sufficient, since differentiable ``rsample()`` is not needed.

            This trick is also known as **reparameterization**.

        Hint:
            Actions are from samples during training in the actor critic
            family (A2C, A3C, PPO, TRPO, IMPALA).

            When your actor model is given a batch of actions and states, it
            must evaluate the states, and return the log likelihood of the
            given actions instead of re-sampling actions.

            An example of your actor in continuous environments::

                class ActorNet(nn.Module):
                    def __init__(self):
                        super(ActorNet, self).__init__()
                        self.fc = nn.Linear(3, 100)
                        self.mu_head = nn.Linear(100, 1)
                        self.sigma_head = nn.Linear(100, 1)

                    def forward(self, state, action=None):
                        x = t.relu(self.fc(state))
                        mu = 2.0 * t.tanh(self.mu_head(x))
                        sigma = F.softplus(self.sigma_head(x))
                        dist = Normal(mu, sigma)
                        action = (action
                                  if action is not None
                                  else dist.sample())
                        action_entropy = dist.entropy()
                        action = action.clamp(-2.0, 2.0)
                        action_log_prob = dist.log_prob(action)
                        return action, action_log_prob, action_entropy

        Hint:
            Entropy weight is usually negative, to increase exploration.

            Value weight is usually 0.5; it scales the critic's value loss,
            which changes how fast the critic converges relative to the
            actor.

            Update equation is equivalent to:

            :math:`Loss= w_e * Entropy + w_v * Loss_v + w_a * Loss_a`
            :math:`Loss_a = -log\\_likelihood * advantage`
            :math:`Loss_v = criterion(target\\_bellman\\_value - V(s))`

        Args:
            actor: Actor network module.
            critic: Critic network module.
            optimizer: Optimizer class (or factory) used to build one
                optimizer for ``actor`` and another one for ``critic``.
            criterion: Criterion used to evaluate the value loss.
            lr_scheduler: Learning rate scheduler class (or factory), one
                instance is created per optimizer.
            lr_scheduler_args: Positional arguments of the learning rate
                schedulers, first entry for the actor, second entry for
                the critic.
            lr_scheduler_kwargs: Keyword arguments of the learning rate
                schedulers, first entry for the actor, second entry for
                the critic.
            actor_learning_rate: Learning rate of the actor optimizer,
                not compatible with ``lr_scheduler``.
            critic_learning_rate: Learning rate of the critic optimizer,
                not compatible with ``lr_scheduler``.
            critic_update_times: Times to update your critic model in each
                ``update()`` call.
            entropy_weight: Weight of entropy in your loss function, a
                positive entropy weight will minimize entropy, while a
                negative one will maximize entropy. ``None`` disables the
                entropy term.
            value_weight: Weight of critic value loss.
            gradient_max: Maximum gradient norm used for clipping.
            gae_lambda: :math:`\\lambda` used in generalized advantage
                estimation.
            discount: :math:`\\gamma` used in the bellman function.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer locates on, Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            visualize: Whether visualize the network flow in the first pass.
            visualize_dir: Visualized graph save directory.
        """
        # Scalar hyper-parameters.
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.value_weight = value_weight
        self.entropy_weight = entropy_weight
        self.grad_max = gradient_max
        self.critic_upd_t = critic_update_times

        # Visualization switches.
        self.visualize = visualize
        self.visualize_dir = visualize_dir

        # Networks, value criterion and one optimizer per network.
        self.actor = actor
        self.critic = critic
        self.criterion = criterion
        self.actor_optim = optimizer(
            self.actor.parameters(), lr=actor_learning_rate
        )
        self.critic_optim = optimizer(
            self.critic.parameters(), lr=critic_learning_rate
        )

        # Build a default buffer only when the caller did not supply one.
        if replay_buffer is None:
            replay_buffer = Buffer(replay_size, replay_device)
        self.replay_buffer = replay_buffer

        # Scheduler attributes are only created when a scheduler is given;
        # ``update_lr_scheduler`` relies on their presence/absence.
        if lr_scheduler is not None:
            sch_args = (((), ()) if lr_scheduler_args is None
                        else lr_scheduler_args)
            sch_kwargs = (({}, {}) if lr_scheduler_kwargs is None
                          else lr_scheduler_kwargs)
            self.actor_lr_sch = lr_scheduler(
                self.actor_optim, *sch_args[0], **sch_kwargs[0]
            )
            self.critic_lr_sch = lr_scheduler(
                self.critic_optim, *sch_args[1], **sch_kwargs[1]
            )

        super(A2C, self).__init__()

    def act(self, state: Dict[str, Any], *_, **__):
        """
        Run the actor network on the given state to obtain a policy output.

        Args:
            state: Keyword arguments describing the current state, passed
                to the actor through ``safe_call``.

        Returns:
            Anything produced by actor.
        """
        actor_output = safe_call(self.actor, state)
        return actor_output

    def _eval_act(self, state: Dict[str, Any], action: Dict[str, Any],
                  *_, **__):
        """
        Run the actor network on a state together with an already chosen
        action, so that the actor can evaluate the log-likelihood of that
        action in the given state.

        Args:
            state: Keyword arguments describing the state.
            action: Keyword arguments describing the action to evaluate.

        Returns:
            Anything produced by actor.
        """
        evaluation = safe_call(self.actor, state, action)
        return evaluation

    def _criticize(self, state: Dict[str, Any], *_, **__):
        """
        Run the critic network on the given state and keep only the first
        element of its output.

        Args:
            state: Keyword arguments describing the state.

        Returns:
            Value of shape ``[batch_size, 1]``
        """
        critic_output = safe_call(self.critic, state)
        # Only the first returned element is treated as the state value.
        return critic_output[0]

    def store_transition(self, transition: Union[Transition, Dict]):
        """
        Append a single transition sample to the replay buffer.

        Not suggested, since the transition must already carry the
        "value" and "gae" attributes, which you will have to calculate
        by yourself.

        Args:
            transition: Transition to store.
        """
        needed_attrs = (
            "state", "action", "next_state", "reward", "value",
            "gae", "terminal"
        )
        self.replay_buffer.append(transition, required_attrs=needed_attrs)

    def store_episode(self, episode: List[Union[Transition, Dict]]):
        """
        Add a full episode of transition samples to the replay buffer.

        "value" and "gae" are automatically calculated and written into
        each transition of ``episode`` (the transitions are modified in
        place) before they are appended to the replay buffer.

        "value" is the discounted sum of rewards from each step to the end
        of the given episode, "gae" is the advantage estimated with
        ``gae_lambda``.

        An empty episode is ignored.

        Args:
            episode: Ordered list of transitions, each must provide
                "state", "action", "next_state", "reward" and "terminal".
        """
        if len(episode) == 0:
            # Nothing to store, and episode[-1] below would raise IndexError.
            return

        # The last step has no successor inside the episode, so its value
        # is just its reward.
        episode[-1]["value"] = episode[-1]["reward"]

        # calculate value for each transition, walking backwards
        for i in reversed(range(1, len(episode))):
            episode[i - 1]["value"] = \
                episode[i]["value"] * self.discount + episode[i - 1]["reward"]

        # calculate advantage
        # Critic outputs are immediately converted to python floats with
        # .item(), so no autograd graph is needed for these forward passes.
        with t.no_grad():
            if self.gae_lambda == 1.0:
                # lambda = 1: advantage is the discounted return minus
                # the critic's estimate.
                for trans in episode:
                    trans["gae"] = (trans["value"] -
                                    self._criticize(trans["state"]).item())
            elif self.gae_lambda == 0.0:
                # lambda = 0: advantage is the one-step TD error, the
                # bootstrap term is dropped on terminal transitions.
                for trans in episode:
                    not_terminal = 1 - float(trans["terminal"])
                    trans["gae"] = (
                        trans["reward"] +
                        self.discount * not_terminal
                        * self._criticize(trans["next_state"]).item() -
                        self._criticize(trans["state"]).item()
                    )
            else:
                # General case: accumulate discounted TD errors backwards.
                # The critic value of the following transition's state is
                # used as the bootstrap value; it starts at 0 for the last
                # transition of the episode.
                last_critic_value = 0
                last_gae = 0
                for trans in reversed(episode):
                    not_terminal = 1 - float(trans["terminal"])
                    critic_value = self._criticize(trans["state"]).item()
                    gae_delta = (trans["reward"] +
                                 self.discount * last_critic_value
                                 * not_terminal -
                                 critic_value)
                    last_critic_value = critic_value
                    last_gae = trans["gae"] = (last_gae * self.discount
                                               * not_terminal
                                               * self.gae_lambda +
                                               gae_delta)

        for trans in episode:
            self.replay_buffer.append(trans, required_attrs=(
                "state", "action", "next_state", "reward", "value",
                "gae", "terminal"
            ))

    def update(self,
               update_value=True,
               update_policy=True,
               concatenate_samples=True,
               **__):
        """
        Update network weights using all samples stored in the buffer.
        Buffer will be cleared after update is finished.

        The actor is updated once, the critic is updated
        ``critic_update_times`` times.

        Args:
            update_value: Whether update the critic (value) network.
            update_policy: Whether update the actor network.
            concatenate_samples: Whether concatenate the samples.

        Returns:
            A tuple of two floats: the negated mean actor loss (including
            the entropy term if ``entropy_weight`` is set), and the critic
            value loss averaged over all critic update iterations
            (``0.0`` if ``critic_update_times`` is 0).
        """
        # sample a batch containing everything stored in the buffer;
        # reward, next_state and terminal are sampled but not used here.
        batch_size, (state, action, _reward, _next_state,
                     _terminal, target_value, advantage) = \
            self.replay_buffer.sample_batch(-1,
                                            sample_method="all",
                                            concatenate=concatenate_samples,
                                            sample_attrs=[
                                                "state", "action", "reward",
                                                "next_state", "terminal",
                                                "value", "gae"],
                                            additional_concat_attrs=[
                                                "value", "gae"
                                            ])

        # normalize advantage
        # Tensor.std() applies Bessel's correction by default and returns
        # NaN for a single element, which would turn the whole actor loss
        # (and its gradients) into NaN. A single sample has no spread, so
        # its deviation is treated as 0 instead.
        adv_std = advantage.std() if advantage.numel() > 1 else 0.0
        advantage = (advantage - advantage.mean()) / (adv_std + 1e-6)

        # Re-evaluate the stored actions; the entropy is only required
        # from the actor when an entropy weight is configured.
        if self.entropy_weight is not None:
            _sampled, action_log_prob, new_action_entropy, *_rest = \
                self._eval_act(state, action)
        else:
            _sampled, action_log_prob, *_rest = \
                self._eval_act(state, action)
            new_action_entropy = None

        action_log_prob = action_log_prob.view(batch_size, 1)

        # calculate policy loss
        act_policy_loss = -(action_log_prob *
                            advantage.to(action_log_prob.device))

        if new_action_entropy is not None:
            act_policy_loss += (self.entropy_weight *
                                new_action_entropy.mean())

        act_policy_loss = act_policy_loss.mean()

        if self.visualize:
            self.visualize_model(act_policy_loss, "actor",
                                 self.visualize_dir)

        # Update actor network
        if update_policy:
            self.actor.zero_grad()
            act_policy_loss.backward()
            nn.utils.clip_grad_norm_(
                self.actor.parameters(), self.grad_max
            )
            self.actor_optim.step()

        sum_value_loss = 0.0
        for _ in range(self.critic_upd_t):
            # calculate value loss
            value = self._criticize(state)
            value_loss = (self.criterion(target_value.to(value.device),
                                         value) *
                          self.value_weight)
            sum_value_loss += value_loss.item()

            if self.visualize:
                self.visualize_model(value_loss, "critic",
                                     self.visualize_dir)

            # Update critic network
            if update_value:
                self.critic.zero_grad()
                value_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.critic.parameters(), self.grad_max
                )
                self.critic_optim.step()

        self.replay_buffer.clear()

        # Guard against critic_update_times == 0, which would otherwise
        # raise ZeroDivisionError after the actor has already been updated.
        if self.critic_upd_t > 0:
            mean_value_loss = sum_value_loss / self.critic_upd_t
        else:
            mean_value_loss = 0.0
        return -act_policy_loss.item(), mean_value_loss

    def update_lr_scheduler(self):
        """
        Step the learning rate schedulers of the actor and the critic.

        A scheduler is only stepped when the corresponding attribute
        exists on this instance; the actor's scheduler is stepped first.
        """
        for scheduler_attr in ("actor_lr_sch", "critic_lr_sch"):
            if hasattr(self, scheduler_attr):
                getattr(self, scheduler_attr).step()
