from typing import Union, Dict, List, Tuple, Callable, Any
from copy import deepcopy
import torch as t
import torch.nn as nn
import numpy as np
from machin.model.nets.base import NeuralNetworkModule
from machin.frame.buffers.buffer import Buffer
from machin.frame.transition import Transition
from .base import TorchFramework, Config
from .utils import (
safe_call,
assert_and_get_valid_models,
assert_and_get_valid_optimizer,
assert_and_get_valid_criterion,
assert_and_get_valid_lr_scheduler,
)
class A2C(TorchFramework):
"""
A2C framework.
"""
_is_top = ["actor", "critic"]
_is_restorable = ["actor", "critic"]
def __init__(
self,
actor: Union[NeuralNetworkModule, nn.Module],
critic: Union[NeuralNetworkModule, nn.Module],
optimizer: Callable,
criterion: Callable,
*_,
lr_scheduler: Callable = None,
lr_scheduler_args: Tuple[Tuple, Tuple] = None,
lr_scheduler_kwargs: Tuple[Dict, Dict] = None,
batch_size: int = 100,
actor_update_times: int = 5,
critic_update_times: int = 10,
actor_learning_rate: float = 0.001,
critic_learning_rate: float = 0.001,
entropy_weight: float = None,
value_weight: float = 0.5,
gradient_max: float = np.inf,
gae_lambda: float = 1.0,
discount: float = 0.99,
normalize_advantage: bool = True,
replay_size: int = 500000,
replay_device: Union[str, t.device] = "cpu",
replay_buffer: Buffer = None,
visualize: bool = False,
visualize_dir: str = "",
**__
):
"""
Important:
When given a state, and an optional action, actor must
at least return two values:
**1. Action**
For **continuous environments**, action must be of shape
``[batch_size, action_dim]`` and *clamped to the action space*.
For **discrete environments**, action may be of shape
``[batch_size, action_dim]`` if it is a one-hot vector, or
``[batch_size, 1]`` or ``[batch_size]`` if it is a categorically
encoded integer.
When the given action is not None, actor must return the given
action.
**2. Log likelihood of action (action probability)**
For either type of environment, log likelihood is of shape
``[batch_size, 1]`` or ``[batch_size]``.
Action probability must be differentiable; the gradient of the actor
is calculated from the gradient of the action probability.
When the given action is not None, actor must return the log
likelihood of the given action.
The third returned value, entropy, is optional:
**3. Entropy of action distribution**
Entropy is usually calculated using ``dist.entropy()``; its shape
is ``[batch_size, 1]`` or ``[batch_size]``. You must specify
``entropy_weight`` to make it effective.
Hint:
For continuous environments, actions are not directly output by
your actor, since that would make it rather inconvenient to calculate
the log probability of the action. Instead, your actor network should
output parameters of a certain distribution
(e.g. :class:`~torch.distributions.normal.Normal`)
and then draw the action from it.
For discrete environments,
:class:`~torch.distributions.categorical.Categorical` is sufficient,
since differentiable ``rsample()`` is not needed.
This trick is also known as **reparameterization**.
Hint:
Actions are sampled during training in the actor-critic
family (A2C, A3C, PPO, TRPO, IMPALA).
When your actor model is given a batch of actions and states, it
must evaluate the states, and return the log likelihood of the
given actions instead of re-sampling actions.
An example of your actor in continuous environments::

    # assumes: import torch.nn.functional as F
    #          from torch.distributions import Normal
    class ActorNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 100)
            self.mu_head = nn.Linear(100, 1)
            self.sigma_head = nn.Linear(100, 1)

        def forward(self, state, action=None):
            x = t.relu(self.fc(state))
            mu = 2.0 * t.tanh(self.mu_head(x))
            sigma = F.softplus(self.sigma_head(x))
            dist = Normal(mu, sigma)
            action = action if action is not None else dist.sample()
            action_entropy = dist.entropy()
            action = action.clamp(-2.0, 2.0)
            # Since we are representing a multivariate gaussian
            # distribution in terms of independent univariate gaussians:
            action_log_prob = dist.log_prob(action).sum(dim=1, keepdim=True)
            return action, action_log_prob, action_entropy
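
A corresponding sketch for **discrete environments** (a minimal example;
the 3-dimensional observation and 2 discrete actions are illustrative
assumptions, and ``Categorical`` comes from ``torch.distributions``)::

    class DiscreteActorNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 100)
            self.head = nn.Linear(100, 2)

        def forward(self, state, action=None):
            x = t.relu(self.fc(state))
            probs = t.softmax(self.head(x), dim=1)
            dist = Categorical(probs=probs)
            # Re-use the given action during evaluation, sample otherwise.
            action = action if action is not None else dist.sample()
            action_entropy = dist.entropy()
            action_log_prob = dist.log_prob(action.flatten())
            return action, action_log_prob, action_entropy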
Hint:
Entropy weight is usually negative, to increase exploration.
Value weight is usually 0.5, which scales down the value loss so that
the critic network is updated more gently than the actor network.
Update equation is equivalent to:
:math:`Loss= w_e * Entropy + w_v * Loss_v + w_a * Loss_a`
:math:`Loss_a = -log\\_likelihood * advantage`
:math:`Loss_v = criterion(target\\_bellman\\_value, V(s))`
Args:
actor: Actor network module.
critic: Critic network module.
optimizer: Optimizer used to optimize ``actor`` and ``critic``.
criterion: Criterion used to evaluate the value loss.
lr_scheduler: Learning rate scheduler of ``optimizer``.
lr_scheduler_args: Arguments of the learning rate scheduler.
lr_scheduler_kwargs: Keyword arguments of the learning
rate scheduler.
batch_size: Batch size used during training.
actor_update_times: Times to update actor in ``update()``.
critic_update_times: Times to update critic in ``update()``.
actor_learning_rate: Learning rate of the actor optimizer,
not compatible with ``lr_scheduler``.
critic_learning_rate: Learning rate of the critic optimizer,
not compatible with ``lr_scheduler``.
entropy_weight: Weight of entropy in your loss function; a positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight: Weight of critic value loss.
gradient_max: Maximum gradient norm, used for gradient clipping.
gae_lambda: :math:`\\lambda` used in generalized advantage
estimation.
discount: :math:`\\gamma` used in the bellman function.
normalize_advantage: Whether to normalize sampled advantage values in
the batch.
replay_size: Replay buffer size. Not compatible with
``replay_buffer``.
replay_device: Device on which the replay buffer is located. Not
compatible with ``replay_buffer``.
replay_buffer: Custom replay buffer.
visualize: Whether to visualize the network flow in the first pass.
visualize_dir: Visualized graph save directory.
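
Example of constructing the framework (a minimal sketch; ``ActorNet``
and ``CriticNet`` stand for your own modules following the contract
described above)::

    import torch as t
    import torch.nn as nn

    a2c = A2C(
        ActorNet(), CriticNet(),
        t.optim.Adam,
        nn.MSELoss(reduction="sum"),
        batch_size=50,
    )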
"""
self.batch_size = batch_size
self.actor_update_times = actor_update_times
self.critic_update_times = critic_update_times
self.discount = discount
self.value_weight = value_weight
self.entropy_weight = entropy_weight
self.gradient_max = gradient_max
self.gae_lambda = gae_lambda
self.normalize_advantage = normalize_advantage
self.visualize = visualize
self.visualize_dir = visualize_dir
self.actor = actor
self.critic = critic
self.actor_optim = optimizer(self.actor.parameters(), lr=actor_learning_rate)
self.critic_optim = optimizer(self.critic.parameters(), lr=critic_learning_rate)
self.replay_buffer = (
Buffer(replay_size, replay_device)
if replay_buffer is None
else replay_buffer
)
if lr_scheduler is not None:
if lr_scheduler_args is None:
lr_scheduler_args = ((), ())
if lr_scheduler_kwargs is None:
lr_scheduler_kwargs = ({}, {})
self.actor_lr_sch = lr_scheduler(
self.actor_optim, *lr_scheduler_args[0], **lr_scheduler_kwargs[0],
)
self.critic_lr_sch = lr_scheduler(
self.critic_optim, *lr_scheduler_args[1], **lr_scheduler_kwargs[1]
)
self.criterion = criterion
super().__init__()
@property
def optimizers(self):
return [self.actor_optim, self.critic_optim]
@optimizers.setter
def optimizers(self, optimizers):
self.actor_optim, self.critic_optim = optimizers
@property
def lr_schedulers(self):
if hasattr(self, "actor_lr_sch") and hasattr(self, "critic_lr_sch"):
return [self.actor_lr_sch, self.critic_lr_sch]
return []
def act(self, state: Dict[str, Any], *_, **__):
"""
Use the actor network to produce a policy (action) for the current state.
Returns:
Anything produced by actor.
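
Example (a minimal sketch, assuming your actor accepts a keyword
argument named ``state``)::

    action = a2c.act({"state": some_state})[0]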
"""
# No need to safe_return because the number of
# returned values is always more than one
return safe_call(self.actor, state)
def _eval_act(self, state: Dict[str, Any], action: Dict[str, Any], *_, **__):
"""
Use actor network to evaluate the log-likelihood of a given
action in the current state.
Returns:
Anything produced by actor.
"""
return safe_call(self.actor, state, action)
def _criticize(self, state: Dict[str, Any], *_, **__):
"""
Use critic network to evaluate current value.
Returns:
Value of shape ``[batch_size, 1]``
"""
return safe_call(self.critic, state)[0]
def store_episode(self, episode: List[Union[Transition, Dict]]):
"""
Add a full episode of transition samples to the replay buffer.
"value" and "gae" are automatically calculated.
"""
episode[-1]["value"] = episode[-1]["reward"]
# calculate value for each transition
for i in reversed(range(1, len(episode))):
episode[i - 1]["value"] = (
episode[i]["value"] * self.discount + episode[i - 1]["reward"]
)
# calculate advantage
if self.gae_lambda == 1.0:
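# gae_lambda == 1: advantage is the discounted return minus the
# critic's value estimate (Monte Carlo advantage).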
for trans in episode:
trans["gae"] = trans["value"] - self._criticize(trans["state"]).item()
elif self.gae_lambda == 0.0:
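# gae_lambda == 0: advantage is the one-step TD error.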
for trans in episode:
trans["gae"] = (
trans["reward"]
+ self.discount
* (1 - float(trans["terminal"]))
* self._criticize(trans["next_state"]).item()
- self._criticize(trans["state"]).item()
)
else:
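# 0 < gae_lambda < 1: recursive generalized advantage estimation (GAE).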
last_critic_value = 0
last_gae = 0
for trans in reversed(episode):
critic_value = self._criticize(trans["state"]).item()
gae_delta = (
trans["reward"]
+ self.discount * last_critic_value * (1 - float(trans["terminal"]))
- critic_value
)
last_critic_value = critic_value
last_gae = trans["gae"] = (
last_gae
* self.discount
* (1 - float(trans["terminal"]))
* self.gae_lambda
+ gae_delta
)
self.replay_buffer.store_episode(
episode,
required_attrs=(
"state",
"action",
"next_state",
"reward",
"value",
"gae",
"terminal",
),
)
def update(
self, update_value=True, update_policy=True, concatenate_samples=True, **__
):
"""
Update network weights by sampling from buffer. Buffer
will be cleared after update is finished.
Args:
update_value: Whether to update the critic (value) network.
update_policy: Whether to update the actor network.
concatenate_samples: Whether to concatenate the samples.
Returns:
Mean of the estimated policy value, mean of the value loss.
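
Example (a minimal sketch)::

    # after storing one or more episodes:
    mean_policy_value, mean_value_loss = a2c.update()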
"""
sum_act_loss = 0
sum_value_loss = 0
self.actor.train()
self.critic.train()
for _ in range(self.actor_update_times):
# sample a batch
batch_size, (state, action, advantage) = self.replay_buffer.sample_batch(
self.batch_size,
sample_method="random_unique",
concatenate=concatenate_samples,
sample_attrs=["state", "action", "gae"],
additional_concat_custom_attrs=["gae"],
)
# normalize advantage
if self.normalize_advantage:
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-6)
if self.entropy_weight is not None:
__, action_log_prob, new_action_entropy, *_ = self._eval_act(
state, action
)
else:
__, action_log_prob, *_ = self._eval_act(state, action)
new_action_entropy = None
action_log_prob = action_log_prob.view(batch_size, 1)
# calculate policy loss
act_policy_loss = -(action_log_prob * advantage.type_as(action_log_prob))
if new_action_entropy is not None:
act_policy_loss += self.entropy_weight * new_action_entropy.mean()
act_policy_loss = act_policy_loss.mean()
sum_act_loss += act_policy_loss.item()
if self.visualize:
self.visualize_model(act_policy_loss, "actor", self.visualize_dir)
# Update actor network
if update_policy:
self.actor.zero_grad()
self._backward(act_policy_loss)
nn.utils.clip_grad_norm_(self.actor.parameters(), self.gradient_max)
self.actor_optim.step()
for _ in range(self.critic_update_times):
# sample a batch
batch_size, (state, target_value) = self.replay_buffer.sample_batch(
self.batch_size,
sample_method="random_unique",
concatenate=concatenate_samples,
sample_attrs=["state", "value"],
additional_concat_custom_attrs=["value"],
)
# calculate value loss
value = self._criticize(state)
value_loss = (
self.criterion(target_value.type_as(value), value) * self.value_weight
)
sum_value_loss += value_loss.item()
if self.visualize:
self.visualize_model(value_loss, "critic", self.visualize_dir)
# Update critic network
if update_value:
self.critic.zero_grad()
self._backward(value_loss)
nn.utils.clip_grad_norm_(self.critic.parameters(), self.gradient_max)
self.critic_optim.step()
self.replay_buffer.clear()
self.actor.eval()
self.critic.eval()
return (
-sum_act_loss / self.actor_update_times,
sum_value_loss / self.critic_update_times,
)
def update_lr_scheduler(self):
"""
Update learning rate schedulers.
"""
if hasattr(self, "actor_lr_sch"):
self.actor_lr_sch.step()
if hasattr(self, "critic_lr_sch"):
self.critic_lr_sch.step()
@classmethod
def generate_config(cls, config: Union[Dict[str, Any], Config]):
default_values = {
"models": ["Actor", "Critic"],
"model_args": ((), ()),
"model_kwargs": ({}, {}),
"optimizer": "Adam",
"criterion": "MSELoss",
"criterion_args": (),
"criterion_kwargs": {},
"lr_scheduler": None,
"lr_scheduler_args": None,
"lr_scheduler_kwargs": None,
"batch_size": 100,
"actor_update_times": 5,
"critic_update_times": 10,
"actor_learning_rate": 0.001,
"critic_learning_rate": 0.001,
"entropy_weight": None,
"value_weight": 0.5,
"gradient_max": np.inf,
"gae_lambda": 1.0,
"discount": 0.99,
"normalize_advantage": True,
"replay_size": 500000,
"replay_device": "cpu",
"replay_buffer": None,
"visualize": False,
"visualize_dir": "",
}
config = deepcopy(config)
config["frame"] = "A2C"
if "frame_config" not in config:
config["frame_config"] = default_values
else:
config["frame_config"] = {**config["frame_config"], **default_values}
return config
@classmethod
def init_from_config(
cls,
config: Union[Dict[str, Any], Config],
model_device: Union[str, t.device] = "cpu",
):
f_config = deepcopy(config["frame_config"])
models = assert_and_get_valid_models(f_config["models"])
model_args = f_config["model_args"]
model_kwargs = f_config["model_kwargs"]
models = [
m(*arg, **kwarg).to(model_device)
for m, arg, kwarg in zip(models, model_args, model_kwargs)
]
optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
criterion = assert_and_get_valid_criterion(f_config["criterion"])(
*f_config["criterion_args"], **f_config["criterion_kwargs"]
)
lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
f_config["lr_scheduler"]
)
f_config["optimizer"] = optimizer
f_config["criterion"] = criterion
f_config["lr_scheduler"] = lr_scheduler
frame = cls(*models, **f_config)
return frame