from copy import deepcopy
from .a2c import *
class PPO(A2C):
"""
PPO (Proximal Policy Optimization) framework.
"""
def __init__(self,
actor: Union[NeuralNetworkModule, nn.Module],
critic: Union[NeuralNetworkModule, nn.Module],
optimizer: Callable,
criterion: Callable,
*_,
lr_scheduler: Callable = None,
lr_scheduler_args: Tuple[Tuple, Tuple] = (),
lr_scheduler_kwargs: Tuple[Dict, Dict] = (),
batch_size: int = 100,
actor_update_times: int = 5,
critic_update_times: int = 10,
actor_learning_rate: float = 0.001,
critic_learning_rate: float = 0.001,
entropy_weight: float = None,
value_weight: float = 0.5,
surrogate_loss_clip: float = 0.2,
gradient_max: float = np.inf,
gae_lambda: float = 1.0,
discount: float = 0.99,
normalize_advantage: bool = True,
replay_size: int = 500000,
replay_device: Union[str, t.device] = "cpu",
replay_buffer: Buffer = None,
visualize: bool = False,
visualize_dir: str = "",
**__):
"""
See Also:
:class:`.A2C`
Args:
actor: Actor network module.
critic: Critic network module.
optimizer: Optimizer used to optimize ``actor`` and ``critic``.
criterion: Criterion used to evaluate the value loss.
lr_scheduler: Learning rate scheduler of ``optimizer``.
lr_scheduler_args: Arguments of the learning rate scheduler.
lr_scheduler_kwargs: Keyword arguments of the learning
rate scheduler.
batch_size: Batch size used during training.
actor_update_times: Number of times the actor is updated in ``update()``.
critic_update_times: Number of times the critic is updated in ``update()``.
actor_learning_rate: Learning rate of the actor optimizer,
not compatible with ``lr_scheduler``.
critic_learning_rate: Learning rate of the critic optimizer,
not compatible with ``lr_scheduler``.
entropy_weight: Weight of entropy in your loss function. A positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight: Weight of critic value loss.
surrogate_loss_clip: Surrogate loss clipping parameter (the
:math:`\\epsilon` of PPO's clipped surrogate objective).
gradient_max: Maximum gradient norm, used when clipping gradients.
gae_lambda: :math:`\\lambda` used in generalized advantage
estimation.
discount: :math:`\\gamma` used in the Bellman equation.
normalize_advantage: Whether to normalize the computed advantage (GAE)
before it is used to compute the policy loss.
replay_size: Replay buffer size. Not compatible with
``replay_buffer``.
replay_device: Device on which the replay buffer is located. Not
compatible with ``replay_buffer``.
replay_buffer: Custom replay buffer.
visualize: Whether to visualize the network flow in the first pass.
visualize_dir: Visualized graph save directory.
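Example:
A minimal construction sketch; ``SomeActor`` and ``SomeCritic`` are
hypothetical user-defined modules following the actor/critic I/O
conventions of :class:`.A2C`, and ``t`` / ``nn`` refer to
``torch`` / ``torch.nn``::

    actor = SomeActor()    # hypothetical policy network
    critic = SomeCritic()  # hypothetical value network
    ppo = PPO(actor, critic,
              t.optim.Adam,
              nn.MSELoss(reduction="sum"),
              batch_size=64,
              surrogate_loss_clip=0.2)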
"""
super(PPO, self).__init__(actor, critic, optimizer, criterion,
lr_scheduler=lr_scheduler,
lr_scheduler_args=lr_scheduler_args,
lr_scheduler_kwargs=lr_scheduler_kwargs,
batch_size=batch_size,
actor_update_times=actor_update_times,
critic_update_times=critic_update_times,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
entropy_weight=entropy_weight,
value_weight=value_weight,
gradient_max=gradient_max,
gae_lambda=gae_lambda,
discount=discount,
normalize_advantage=normalize_advantage,
replay_size=replay_size,
replay_device=replay_device,
replay_buffer=replay_buffer,
visualize=visualize,
visualize_dir=visualize_dir)
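# epsilon in the clipped surrogate objective (see ``update()`` below)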
self.surr_clip = surrogate_loss_clip
def update(self,
update_value=True,
update_policy=True,
concatenate_samples=True,
**__):
# DOC INHERITED
sum_act_policy_loss = 0
sum_value_loss = 0
# create a temporary copy of the actor before it is updated
tmp_actor = deepcopy(self.actor)
self.actor.train()
self.critic.train()
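# PPO reuses the collected on-policy samples for several actor update
# epochs; the clipped probability ratio computed below keeps each update
# close to the policy that generated the data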
for _ in range(self.actor_update_times):
# sample a batch
batch_size, (state, action, advantage) = \
self.replay_buffer.sample_batch(self.batch_size,
sample_method="random_unique",
concatenate=concatenate_samples,
sample_attrs=[
"state", "action", "gae"],
additional_concat_attrs=[
"gae"
])
# normalize advantage
if self.normalize_advantage:
advantage = ((advantage - advantage.mean()) /
(advantage.std() + 1e-6))
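# the small constant avoids division by zero when the sampled
# advantages are (almost) identical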
# Infer the log probability of the stored actions under the original
# (pre-update) policy.
# TODO:
# This temporary fix is not efficient; it may require an alteration
# of the PPO store API.
with t.no_grad():
self.actor, tmp_actor = tmp_actor, self.actor
__, action_log_prob, *_ = self._eval_act(state, action)
self.actor, tmp_actor = tmp_actor, self.actor
action_log_prob = action_log_prob.view(batch_size, 1)
if self.entropy_weight is not None:
__, new_action_log_prob, new_action_entropy, *_ = \
self._eval_act(state, action)
else:
__, new_action_log_prob, *_ = \
self._eval_act(state, action)
new_action_entropy = None
new_action_log_prob = new_action_log_prob.view(batch_size, 1)
# calculate the surrogate loss
# Clipping the probability ratio limits the influence of actions that
# have become much more or much less likely under the current policy,
# because with every update the current policy distribution drifts
# further away from the old policy that generated the samples.
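# For reference, the clipped surrogate objective from the PPO paper
# (Schulman et al., 2017) that this loop maximizes is approximately:
#     L_clip = E[ min(r * A, clip(r, 1 - eps, 1 + eps) * A) ]
# where r = exp(new_action_log_prob - action_log_prob) is the probability
# ratio, A is the (GAE) advantage, and eps = self.surr_clip; the code
# below negates the elementwise minimum so it can be minimized as a loss.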
sim_ratio = t.exp(new_action_log_prob - action_log_prob)
advantage = advantage.to(sim_ratio.device)
surr_loss_1 = sim_ratio * advantage
surr_loss_2 = t.clamp(sim_ratio,
1 - self.surr_clip,
1 + self.surr_clip) * advantage
# calculate policy loss using surrogate loss
act_policy_loss = -t.min(surr_loss_1, surr_loss_2)
if new_action_entropy is not None:
act_policy_loss += (self.entropy_weight *
new_action_entropy.mean())
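# with a positive entropy_weight this adds mean entropy to the loss,
# i.e. entropy is minimized, as documented in ``__init__``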
act_policy_loss = act_policy_loss.mean()
if self.visualize:
self.visualize_model(act_policy_loss, "actor",
self.visualize_dir)
# Update actor network
if update_policy:
self.actor.zero_grad()
act_policy_loss.backward()
nn.utils.clip_grad_norm_(
self.actor.parameters(), self.grad_max
)
self.actor_optim.step()
sum_act_policy_loss += act_policy_loss.item()
for _ in range(self.critic_update_times):
# sample a batch
batch_size, (state, target_value) = \
self.replay_buffer.sample_batch(self.batch_size,
sample_method="random_unique",
concatenate=concatenate_samples,
sample_attrs=[
"state", "value"],
additional_concat_attrs=[
"value"
])
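# the stored "value" attribute is used as the regression target
# for the critic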
# calculate value loss
value = self._criticize(state)
value_loss = (self.criterion(target_value.to(value.device),
value) *
self.value_weight)
if self.visualize:
self.visualize_model(value_loss, "critic",
self.visualize_dir)
# Update critic network
if update_value:
self.critic.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(
self.critic.parameters(), self.grad_max
)
self.critic_optim.step()
sum_value_loss += value_loss.item()
self.replay_buffer.clear()
self.actor.eval()
self.critic.eval()
return (-sum_act_policy_loss / self.actor_update_times,
sum_value_loss / self.critic_update_times)