from .a2c import *
class PPO(A2C):
    """
    PPO framework.

    Extends :class:`.A2C` by replacing the actor loss with a clipped
    surrogate loss, controlled by ``surrogate_loss_clip``.
    """

    def __init__(
        self,
        actor: Union[NeuralNetworkModule, nn.Module],
        critic: Union[NeuralNetworkModule, nn.Module],
        optimizer: Callable,
        criterion: Callable,
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple, Tuple] = (),
        lr_scheduler_kwargs: Tuple[Dict, Dict] = (),
        batch_size: int = 100,
        actor_update_times: int = 5,
        critic_update_times: int = 10,
        actor_learning_rate: float = 0.001,
        critic_learning_rate: float = 0.001,
        entropy_weight: float = None,
        value_weight: float = 0.5,
        surrogate_loss_clip: float = 0.2,
        gradient_max: float = np.inf,
        gae_lambda: float = 1.0,
        discount: float = 0.99,
        normalize_advantage: bool = True,
        replay_size: int = 500000,
        replay_device: Union[str, t.device] = "cpu",
        replay_buffer: Buffer = None,
        visualize: bool = False,
        visualize_dir: str = "",
        **__
    ):
        """
        Extra positional arguments (``*_``) and unknown keyword arguments
        (``**__``) are accepted and ignored.

        See Also:
            :class:`.A2C`

        Args:
            actor: Actor network module.
            critic: Critic network module.
            optimizer: Optimizer used to optimize ``actor`` and ``critic``.
            criterion: Criterion used to evaluate the value loss.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            batch_size: Batch size used during training.
            actor_update_times: Times to update actor in ``update()``.
            critic_update_times: Times to update critic in ``update()``.
            actor_learning_rate: Learning rate of the actor optimizer,
                not compatible with ``lr_scheduler``.
            critic_learning_rate: Learning rate of the critic optimizer,
                not compatible with ``lr_scheduler``.
            entropy_weight: Weight of entropy in your loss function, a positive
                entropy weight will minimize entropy, while a negative one will
                maximize entropy. ``None`` disables the entropy term.
            value_weight: Weight of critic value loss.
            surrogate_loss_clip: Surrogate loss clipping parameter in PPO.
            gradient_max: Maximum gradient.
            gae_lambda: :math:`\\lambda` used in generalized advantage
                estimation.
            discount: :math:`\\gamma` used in the bellman function.
            normalize_advantage: Whether to normalize sampled advantage values in
                the batch.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer locates on, Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            visualize: Whether visualize the network flow in the first pass.
            visualize_dir: Visualized graph save directory.
        """
        super().__init__(
            actor,
            critic,
            optimizer,
            criterion,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            batch_size=batch_size,
            actor_update_times=actor_update_times,
            critic_update_times=critic_update_times,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            entropy_weight=entropy_weight,
            value_weight=value_weight,
            gradient_max=gradient_max,
            gae_lambda=gae_lambda,
            discount=discount,
            normalize_advantage=normalize_advantage,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=replay_buffer,
            visualize=visualize,
            visualize_dir=visualize_dir,
        )
        # The only PPO-specific hyperparameter; everything else is handled
        # by the A2C constructor.
        self.surr_clip = surrogate_loss_clip

    def update(
        self, update_value=True, update_policy=True, concatenate_samples=True, **__
    ):
        """
        Run ``actor_update_times`` actor steps followed by
        ``critic_update_times`` critic steps on batches sampled from the
        replay buffer, then clear the replay buffer.

        Args:
            update_value: Whether to apply the critic optimizer step. The
                value loss is still computed and accumulated when ``False``.
            update_policy: Whether to apply the actor optimizer step. The
                policy loss is still computed and accumulated when ``False``.
            concatenate_samples: Forwarded as ``concatenate`` to
                ``replay_buffer.sample_batch``.

        Returns:
            A tuple ``(mean_policy_objective, mean_value_loss)``: the
            **negated** mean of the per-step actor losses, and the mean of
            the per-step (``value_weight``-scaled) critic losses. Both are
            ``0.0`` when the corresponding update count is not positive.
        """
        sum_act_policy_loss = 0
        sum_value_loss = 0

        # Frozen snapshot of the actor taken before any update; it provides
        # the "old policy" log probabilities for the probability ratio.
        old_actor = deepcopy(self.actor)

        self.actor.train()
        self.critic.train()

        for _ in range(self.actor_update_times):
            # sample a batch
            batch_size, (state, action, advantage) = self.replay_buffer.sample_batch(
                self.batch_size,
                sample_method="random_unique",
                concatenate=concatenate_samples,
                sample_attrs=["state", "action", "gae"],
                additional_concat_custom_attrs=["gae"],
            )

            # Normalize advantage. A single-element batch is skipped because
            # its (Bessel-corrected) standard deviation is NaN, which would
            # propagate into the loss and the actor parameters.
            if self.normalize_advantage and advantage.numel() > 1:
                advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-6)

            # Infer original action log probability.
            # TODO:
            # This temporary fix is not efficient, maybe requires
            # PPO store API alternation.
            with t.no_grad():
                # ``_eval_act`` is inherited and takes no actor argument, so
                # the snapshot is temporarily installed as ``self.actor``
                # (presumably ``_eval_act`` reads ``self.actor``).
                self.actor, old_actor = old_actor, self.actor
                try:
                    __, action_log_prob, *_ = self._eval_act(state, action)
                finally:
                    # Always restore the live actor, even if evaluation
                    # fails, so the object is never left holding the copy.
                    self.actor, old_actor = old_actor, self.actor
                action_log_prob = action_log_prob.view(batch_size, 1)

            if self.entropy_weight is not None:
                __, new_action_log_prob, new_action_entropy, *_ = self._eval_act(
                    state, action
                )
            else:
                __, new_action_log_prob, *_ = self._eval_act(state, action)
                new_action_entropy = None

            new_action_log_prob = new_action_log_prob.view(batch_size, 1)

            # Calculate surrogate loss.
            # Clipping the probability ratio ignores actions that are
            # unlikely under the current actor's policy distribution,
            # because with each update the old policy distribution
            # diverges further from the current one.
            sim_ratio = t.exp(new_action_log_prob - action_log_prob)
            advantage = advantage.type_as(sim_ratio)
            surr_loss_1 = sim_ratio * advantage
            surr_loss_2 = (
                t.clamp(sim_ratio, 1 - self.surr_clip, 1 + self.surr_clip) * advantage
            )

            # calculate policy loss using surrogate loss
            act_policy_loss = -t.min(surr_loss_1, surr_loss_2)

            if new_action_entropy is not None:
                act_policy_loss += self.entropy_weight * new_action_entropy.mean()

            act_policy_loss = act_policy_loss.mean()

            if self.visualize:
                self.visualize_model(act_policy_loss, "actor", self.visualize_dir)

            # Update actor network
            if update_policy:
                self.actor.zero_grad()
                self._backward(act_policy_loss)
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.gradient_max)
                self.actor_optim.step()
            sum_act_policy_loss += act_policy_loss.item()

        for _ in range(self.critic_update_times):
            # sample a batch
            batch_size, (state, target_value) = self.replay_buffer.sample_batch(
                self.batch_size,
                sample_method="random_unique",
                concatenate=concatenate_samples,
                sample_attrs=["state", "value"],
                additional_concat_custom_attrs=["value"],
            )

            # calculate value loss
            value = self._criticize(state)
            value_loss = (
                self.criterion(target_value.type_as(value), value) * self.value_weight
            )

            if self.visualize:
                self.visualize_model(value_loss, "critic", self.visualize_dir)

            # Update critic network
            if update_value:
                self.critic.zero_grad()
                self._backward(value_loss)
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.gradient_max)
                self.critic_optim.step()
            sum_value_loss += value_loss.item()

        # Samples are discarded after a single round of updates.
        self.replay_buffer.clear()
        self.actor.eval()
        self.critic.eval()

        # Clamp the divisors so that zero configured updates yields 0.0
        # instead of raising ZeroDivisionError.
        actor_steps = max(self.actor_update_times, 1)
        critic_steps = max(self.critic_update_times, 1)
        return (
            # The policy loss is negated, so a larger value is better.
            -sum_act_policy_loss / actor_steps,
            sum_value_loss / critic_steps,
        )

    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        """
        Build a PPO configuration on top of the one produced by
        ``A2C.generate_config``.

        Args:
            config: Configuration passed through ``A2C.generate_config``.

        Returns:
            The configuration returned by ``A2C.generate_config`` with
            ``"frame"`` set to ``"PPO"`` and
            ``frame_config["surrogate_loss_clip"]`` set to ``0.2``.
        """
        config = A2C.generate_config(config)
        config["frame"] = "PPO"
        config["frame_config"]["surrogate_loss_clip"] = 0.2
        return config