from machin.frame.buffers.prioritized_buffer import PrioritizedBuffer
# pylint: disable=wildcard-import, unused-wildcard-import
from .dqn import *
class RAINBOW(DQN):
"""
RAINBOW DQN framework.
"""
def __init__(
self,
qnet: Union[NeuralNetworkModule, nn.Module],
qnet_target: Union[NeuralNetworkModule, nn.Module],
optimizer,
value_min,
value_max,
*_,
lr_scheduler: Callable = None,
lr_scheduler_args: Tuple[Tuple] = None,
lr_scheduler_kwargs: Tuple[Dict] = None,
batch_size: int = 100,
epsilon_decay: float = 0.9999,
update_rate: float = 0.001,
update_steps: Union[int, None] = None,
learning_rate: float = 0.001,
discount: float = 0.99,
gradient_max: float = np.inf,
reward_future_steps: int = 3,
replay_size: int = 500000,
replay_device: Union[str, t.device] = "cpu",
replay_buffer: Buffer = None,
visualize: bool = False,
visualize_dir: str = "",
**__
):
"""
The RAINBOW framework is described in
`this <https://arxiv.org/abs/1710.02298>`__ paper.
Note:
In the RAINBOW framework, the output shape of your q network
must be ``[batch_size, action_num, atom_num]`` when given a
state of shape ``[batch_size, action_dim]``, and the last
dimension **must be soft-maxed**. The atom number is the number
of discrete support points (atoms) used to represent the q value
distribution over the domain ``[value_min, value_max]``.
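Example:
A minimal sketch of a q network producing the required output
(an assumption for illustration, not part of this library; the
class name and layer sizes are arbitrary)::

    class QNet(nn.Module):
        def __init__(self, state_dim, action_num, atom_num=10):
            super().__init__()
            self.fc1 = nn.Linear(state_dim, 16)
            self.fc2 = nn.Linear(16, action_num * atom_num)
            self.action_num, self.atom_num = action_num, atom_num

        def forward(self, state):
            a = t.relu(self.fc1(state))
            # reshape to [batch_size, action_num, atom_num] and
            # softmax the atom dimension
            return t.softmax(
                self.fc2(a).view(-1, self.action_num, self.atom_num),
                dim=-1,
            )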
See Also:
:class:`.DQN`
Args:
qnet: Q network module.
qnet_target: Target Q network module.
optimizer: Optimizer used to optimize ``qnet``.
value_min: Minimum of value domain.
value_max: Maximum of value domain.
learning_rate: Learning rate of the optimizer, not compatible with
``lr_scheduler``.
lr_scheduler: Learning rate scheduler of ``optimizer``.
lr_scheduler_args: Arguments of the learning rate scheduler.
lr_scheduler_kwargs: Keyword arguments of the learning
rate scheduler.
batch_size: Batch size used during training.
epsilon_decay: Epsilon decay rate per noisy acting step.
The ``epsilon`` attribute is multiplied by this value every
time ``act_discrete_with_noise`` is called.
update_rate: :math:`\\tau` used to update target networks.
Target parameters are updated as:
:math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
update_steps: Training step number used to update target networks.
discount: :math:`\\gamma` used in the Bellman function.
gradient_max: Maximum gradient norm, used when clipping gradients.
reward_future_steps: Number of future steps to be considered when
the framework calculates value from reward.
replay_size: Replay buffer size. Not compatible with
``replay_buffer``.
replay_device: Device on which the replay buffer is located. Not
compatible with ``replay_buffer``.
replay_buffer: Custom replay buffer.
visualize: Whether to visualize the network flow in the first pass.
visualize_dir: Visualized graph save directory.
"""
super().__init__(
qnet,
qnet_target,
optimizer,
lambda: None,
learning_rate=learning_rate,
lr_scheduler=lr_scheduler,
lr_scheduler_args=lr_scheduler_args,
lr_scheduler_kwargs=lr_scheduler_kwargs,
batch_size=batch_size,
epsilon_decay=epsilon_decay,
update_rate=update_rate,
update_steps=update_steps,
discount=discount,
gradient_max=gradient_max,
replay_size=replay_size,
replay_device=replay_device,
replay_buffer=(
PrioritizedBuffer(replay_size, replay_device)
if replay_buffer is None
else replay_buffer
),
visualize=visualize,
visualize_dir=visualize_dir,
)
self.v_min = value_min
self.v_max = value_max
self.reward_future_steps = reward_future_steps
def act_discrete(self, state: Dict[str, Any], use_target: bool = False, **__):
# DOC INHERITED
# q value distribution of each action
# shape: [batch_size, action_num, atom_num]
if use_target:
q_dist, *others = safe_call(self.qnet_target, state)
else:
q_dist, *others = safe_call(self.qnet, state)
atom_num = q_dist.shape[-1]
# support vector, shape: [1, atom_num]
q_dist_support = t.linspace(self.v_min, self.v_max, atom_num).view(1, -1)
# q value of each action, shape: [batch_size, action_num]
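# e.g. with support [-1.0, 0.0, 1.0] and a single action's
# distribution [0.2, 0.5, 0.3], the q value is
# 0.2 * (-1.0) + 0.5 * 0.0 + 0.3 * 1.0 = 0.1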
q_value = t.sum(q_dist_support.type_as(q_dist) * q_dist, dim=-1)
result = t.argmax(q_value, dim=1).view(-1, 1)
if len(others) == 0:
return result
else:
return (result, *others)
def act_discrete_with_noise(
self,
state: Dict[str, Any],
use_target: bool = False,
decay_epsilon: bool = True,
**__
):
# DOC INHERITED
# q value distribution of each action
# shape: [batch_size, action_num, atom_num]
if use_target:
q_dist, *others = safe_call(self.qnet_target, state)
else:
q_dist, *others = safe_call(self.qnet, state)
atom_num = q_dist.shape[-1]
# support vector, shape: [1, atom_num]
q_dist_support = t.linspace(self.v_min, self.v_max, atom_num).view(1, -1)
# q value of each action, shape: [batch_size, action_num]
q_value = t.sum(q_dist_support.type_as(q_dist) * q_dist, dim=-1)
action_dim = q_value.shape[1]
result = t.argmax(q_value, dim=1).view(-1, 1)
if t.rand([1]).item() < self.epsilon:
result = t.randint(0, action_dim, [result.shape[0], 1])
if decay_epsilon:
self.epsilon *= self.epsilon_decay
if len(others) == 0:
return result
else:
return (result, *others)
def store_episode(self, episode: List[Union[Transition, Dict]]):
"""
Add a full episode of transition samples to the replay buffer.
"value" is automatically calculated.
"""
episode[-1]["value"] = episode[-1]["reward"]
# calculate (truncated) n step value for each transition
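# e.g. with rewards [1, 1, 1], discount 0.9 and
# reward_future_steps = 2, the stored n-step values are
# [1 + 0.9 * 1, 1 + 0.9 * 1, 1] = [1.9, 1.9, 1.0]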
for i in reversed(range(len(episode))):
value_sum = 0
# for (virtual) transitions beyond the terminal transition,
# using "min" to ignore them is equivalent to setting their
# rewards to zero
for j in reversed(range(min(self.reward_future_steps, len(episode) - i))):
value_sum = value_sum * self.discount + episode[i + j]["reward"]
episode[i]["value"] = value_sum
self.replay_buffer.store_episode(
episode,
required_attrs=(
"state",
"action",
"next_state",
"reward",
"value",
"terminal",
),
)
def update(
self, update_value=True, update_target=True, concatenate_samples=True, **__
):
# DOC INHERITED
# pylint: disable=invalid-name
self.qnet.train()
(
batch_size,
(state, action, value, next_state, terminal, others),
index,
is_weight,
) = self.replay_buffer.sample_batch(
self.batch_size,
concatenate_samples,
sample_attrs=["state", "action", "value", "next_state", "terminal", "*"],
additional_concat_custom_attrs=["value"],
)
# q_dist is the distribution of q values
q_dist = self._criticize(state).cpu()
atom_num = q_dist.shape[-1]
action = (
self.action_get_function(action).to(device="cpu", dtype=t.long).flatten()
)
# shape: [batch_size, atom_num]
q_dist = q_dist[range(batch_size), action]
# support vector, shape: [atom_num]
q_dist_support = t.linspace(self.v_min, self.v_max, atom_num)
with t.no_grad():
target_next_q_dist = self._criticize(next_state, True).cpu()
next_action = (
self.act_discrete(next_state).flatten().to(device="cpu", dtype=t.long)
)
# shape: [batch_size, atom_num]
target_next_q_dist = target_next_q_dist[range(batch_size), next_action]
# shape: [1, atom_num]
q_dist_support = q_dist_support.unsqueeze(dim=0)
# T_z is the bellman update for atom z_j
# shape: [batch_size, atom_num]
T_z = self.reward_function(
value.cpu(),
self.discount ** self.reward_future_steps,
q_dist_support,
terminal.cpu(),
others,
)
# 1e-6 is used to make sure that l != u when T_z == v_min or v_max
T_z = T_z.clamp(self.v_min + 1e-6, self.v_max - 1e-6)
# delta_z is the interval length of each atom
delta_z = (self.v_max - self.v_min) / (atom_num - 1.0)
# b is the normalized distance of T_z to v_min,
# l and u are the lower and upper atom indexes
# b, l, u shape: [batch_size, atom_num]
b = (T_z - self.v_min) / delta_z
l, u = b.floor(), b.ceil()
# idx shape: [batch_size * atom_num]
# dist shape: [batch_size, atom_num]
# weight shape: [batch_size * atom_num]
l_idx, l_dist = l.long().view(-1), b - l
u_idx, u_dist = u.long().view(-1), u - b
l_weight = (u_dist * target_next_q_dist).view(-1)
u_weight = (l_dist * target_next_q_dist).view(-1)
# offset is used to perform a row-wise index add; since index_add_
# only works along a single dimension, we must flatten the whole
# distribution and then add.
offset = (
t.arange(0, batch_size * atom_num, atom_num)
.view(-1, 1)
.expand(batch_size, atom_num)
.flatten()
)
# distribute T_z probability to its nearest upper
# and lower atom neighbors, using its distance to them.
# shape: [batch_size * atom_num] -> [batch_size, atom_num]
# Note: index_add_ on CUDA uses atomicAdd, which can cause
# rounding errors and non-deterministic results.
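# e.g. with v_min = -1, v_max = 1 and atom_num = 5 (support
# [-1, -0.5, 0, 0.5, 1], delta_z = 0.5), a projected value
# T_z = 0.3 gives b = (0.3 + 1) / 0.5 = 2.6, l = 2, u = 3, so its
# probability mass p is split as 0.4 * p onto atom index 2 (value 0)
# and 0.6 * p onto atom index 3 (value 0.5)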
target_dist = t.zeros([batch_size * atom_num], dtype=l_weight.dtype)
target_dist.index_add_(dim=0, index=l_idx + offset, source=l_weight)
target_dist.index_add_(dim=0, index=u_idx + offset, source=u_weight)
target_dist = target_dist.view(batch_size, atom_num)
# target_dist is equivalent to y_i in the original DQN.
# Since target_dist is a constant, the constant term of the KL
# divergence is dropped and the loss reduces to a cross entropy;
# this also avoids the 0/0 situation in the p/q ratio.
# 1e-6 is used to improve numerical stability and prevent nan
value_loss = -(target_dist * (q_dist + 1e-6).log())
# keepdim so the shape matches is_weight below ([batch_size, 1])
value_loss = value_loss.sum(dim=1, keepdim=True)
abs_error = (t.abs(value_loss) + 1e-6).flatten().detach().numpy()
self.replay_buffer.update_priority(abs_error, index)
value_loss = (value_loss * t.from_numpy(is_weight).view([batch_size, 1])).mean()
if self.visualize:
self.visualize_model(value_loss, "qnet", self.visualize_dir)
if update_value:
self.qnet.zero_grad()
self._backward(value_loss)
nn.utils.clip_grad_norm_(self.qnet.parameters(), self.grad_max)
self.qnet_optim.step()
# Update target Q network
if update_target:
if self.update_rate is not None:
soft_update(self.qnet_target, self.qnet, self.update_rate)
else:
self._update_counter += 1
if self._update_counter % self.update_steps == 0:
hard_update(self.qnet_target, self.qnet)
self.qnet.eval()
# use .item() to prevent memory leakage
return value_loss.item()
@classmethod
def generate_config(cls, config: Union[Dict[str, Any], Config]):
config = DQN.generate_config(config)
config["frame"] = "RAINBOW"
config["frame_config"]["value_min"] = -1.0
config["frame_config"]["value_max"] = 1.0
config["frame_config"]["reward_future_steps"] = 3
return config
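# ---------------------------------------------------------------------------
# A rough usage sketch (an assumption for illustration, not part of the
# library): it uses a toy 4-dimensional state, 2 discrete actions, random
# transitions and illustrative hyper parameters; the network, shapes and
# episode data are made up. ``t`` and ``nn`` come from the wildcard import
# above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class QNet(nn.Module):
        # distributional q network: output [batch_size, action_num, atom_num],
        # soft-maxed over the atom dimension
        def __init__(self, state_dim, action_num, atom_num=10):
            super().__init__()
            self.fc = nn.Linear(state_dim, action_num * atom_num)
            self.action_num, self.atom_num = action_num, atom_num

        def forward(self, state):
            return t.softmax(
                self.fc(state).view(-1, self.action_num, self.atom_num), dim=-1
            )

    rainbow = RAINBOW(
        QNet(4, 2),
        QNet(4, 2),
        t.optim.Adam,
        value_min=-10.0,
        value_max=10.0,
        batch_size=16,
    )

    # collect two short (random) episodes; "value" (the truncated n-step
    # return) is filled in automatically by store_episode
    for _ in range(2):
        episode, state = [], t.zeros([1, 4])
        for _ in range(10):
            action = rainbow.act_discrete_with_noise({"state": state})
            next_state = t.rand([1, 4])
            episode.append(
                {
                    "state": {"state": state},
                    "action": {"action": action},
                    "next_state": {"state": next_state},
                    "reward": 1.0,
                    "terminal": False,
                }
            )
            state = next_state
        rainbow.store_episode(episode)

    print("loss:", rainbow.update())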