from machin.frame.buffers.prioritized_buffer import PrioritizedBuffer
from machin.utils.logging import default_logger
# pylint: disable=wildcard-import, unused-wildcard-import
from .ddpg import *
class DDPGPer(DDPG):
    """
    DDPG with prioritized experience replay.

    Warning:
        Your criterion must return a tensor of shape ``[batch_size,1]``
        when given two tensors of shape ``[batch_size,1]``, since we
        need to multiply the loss with importance sampling weight
        element-wise.

        If you are using loss modules given by pytorch, it is always
        safe to use them without any modification: their ``reduction``
        attribute is switched to ``"none"`` automatically on construction.
    """

    def __init__(self,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 actor_target: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 critic_target: Union[NeuralNetworkModule, nn.Module],
                 optimizer: Callable,
                 criterion,
                 *_,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple, Tuple] = None,
                 lr_scheduler_kwargs: Tuple[Dict, Dict] = None,
                 batch_size: int = 100,
                 update_rate: float = 0.005,
                 actor_learning_rate: float = 0.0005,
                 critic_learning_rate: float = 0.001,
                 discount: float = 0.99,
                 gradient_max: float = np.inf,
                 replay_size: int = 500000,
                 replay_device: Union[str, t.device] = "cpu",
                 replay_buffer: Buffer = None,
                 visualize: bool = False,
                 visualize_dir: str = "",
                 **__):
        # DOC INHERITED
        # All arguments are forwarded unchanged to ``DDPG.__init__``, except
        # that a ``PrioritizedBuffer`` is created when the caller does not
        # supply a replay buffer. Extra positional / keyword arguments
        # (``*_`` / ``**__``) are accepted and ignored.
        #
        # Raises:
        #     RuntimeError: if the criterion has no ``reduction`` attribute.
        super().__init__(
            actor, actor_target, critic, critic_target, optimizer, criterion,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            batch_size=batch_size,
            update_rate=update_rate,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            discount=discount,
            gradient_max=gradient_max,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=(PrioritizedBuffer(replay_size, replay_device)
                           if replay_buffer is None
                           else replay_buffer),
            visualize=visualize,
            visualize_dir=visualize_dir
        )

        # The criterion's reduction must be the string 'none': ``update``
        # weights the per-sample loss with the importance sampling weights
        # before averaging, so the criterion itself must not reduce.
        # (``self.criterion`` is presumably assigned by ``DDPG.__init__``.)
        if not hasattr(self.criterion, "reduction"):
            raise RuntimeError("Criterion does not have the "
                               "'reduction' property")
        # A loss defined in ``torch.nn.modules.loss``
        if self.criterion.reduction != "none":
            default_logger.warning(
                "The reduction property of criterion is not 'none', "
                "automatically corrected."
            )
            self.criterion.reduction = "none"

    def update(self,
               update_value=True,
               update_policy=True,
               update_target=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        # Performs one training step on a batch drawn from the prioritized
        # replay buffer and returns a tuple
        # ``(mean critic value of the actor's actions, critic loss)``
        # as python floats.
        self.actor.train()
        self.critic.train()

        # Besides the usual transition attributes, the buffer's
        # ``sample_batch`` also yields the sampled indexes and the
        # importance sampling weights (passed to ``t.from_numpy`` below).
        (batch_size,
         (state, action, reward, next_state, terminal, others),
         index, is_weight) = \
            self.replay_buffer.sample_batch(self.batch_size,
                                            concatenate_samples,
                                            sample_attrs=[
                                                "state", "action",
                                                "reward", "next_state",
                                                "terminal", "*"
                                            ])

        # Update critic network first.
        # Generate value reference :math: `y_i` using target actor and
        # target critic.
        with t.no_grad():
            next_action = self.action_transform_function(
                self._act(next_state, True), next_state, others
            )
            next_value = self._criticize(next_state, next_action, True)
            next_value = next_value.view(batch_size, -1)
            y_i = self.reward_function(
                reward, self.discount, next_value, terminal, others
            )

        # critic loss: per-sample loss scaled element-wise by the
        # importance sampling weight, then averaged over the batch.
        cur_value = self._criticize(state, action)
        value_loss = self.criterion(cur_value, y_i.to(cur_value.device))
        value_loss = (value_loss *
                      t.from_numpy(is_weight).view([batch_size, 1])
                      .to(value_loss.device))
        value_loss = value_loss.mean()

        if self.visualize:
            self.visualize_model(value_loss, "critic", self.visualize_dir)

        if update_value:
            self.critic.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(
                self.critic.parameters(), self.grad_max
            )
            self.critic_optim.step()

        # actor loss
        cur_action = self.action_transform_function(
            self._act(state), state, others
        )
        act_value = self._criticize(state, cur_action)

        # "-" is applied because we want to maximize J_b(u),
        # but optimizer works by minimizing the target
        act_policy_loss = -act_value.mean()

        # update priority
        # NOTE(review): the priority is |act_value - y_i|, i.e. it uses the
        # critic's value of the actor's *current* action (evaluated after
        # the critic step above), not ``cur_value`` of the stored action.
        # Prioritized replay is usually described with the TD error of the
        # stored transition -- confirm this choice is intentional.
        abs_error = (t.sum(t.abs(act_value - y_i.to(act_value.device)), dim=1)
                     .flatten().detach().cpu().numpy())
        self.replay_buffer.update_priority(abs_error, index)

        if self.visualize:
            self.visualize_model(act_policy_loss, "actor", self.visualize_dir)

        if update_policy:
            self.actor.zero_grad()
            act_policy_loss.backward()
            nn.utils.clip_grad_norm_(
                self.actor.parameters(), self.grad_max
            )
            self.actor_optim.step()

        # Update target networks
        if update_target:
            soft_update(self.actor_target, self.actor, self.update_rate)
            soft_update(self.critic_target, self.critic, self.update_rate)

        # Switch back to evaluation mode once the training step is done.
        self.actor.eval()
        self.critic.eval()
        # use .item() to prevent memory leakage
        return -act_policy_loss.item(), value_loss.item()