Source code for machin.frame.algorithms.rainbow

from machin.frame.buffers.prioritized_buffer import PrioritizedBuffer
# pylint: disable=wildcard-import, unused-wildcard-import
from .dqn import *


class RAINBOW(DQN):
    """
    RAINBOW DQN framework.
    """

    def __init__(self,
                 qnet: Union[NeuralNetworkModule, nn.Module],
                 qnet_target: Union[NeuralNetworkModule, nn.Module],
                 optimizer,
                 value_min,
                 value_max,
                 *_,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple] = None,
                 lr_scheduler_kwargs: Tuple[Dict] = None,
                 batch_size: int = 100,
                 update_rate: float = 0.001,
                 learning_rate: float = 0.001,
                 discount: float = 0.99,
                 gradient_max: float = np.inf,
                 reward_future_steps: int = 3,
                 replay_size: int = 500000,
                 replay_device: Union[str, t.device] = "cpu",
                 replay_buffer: Buffer = None,
                 visualize: bool = False,
                 visualize_dir: str = "",
                 **__):
        """
        The RAINBOW framework is described in
        `this <https://arxiv.org/abs/1710.02298>`__ paper.

        Note:
            In the RAINBOW framework, the output shape of your q network
            must be ``[batch_size, action_num, atom_num]`` when given a
            state of shape ``[batch_size, action_dim]``, and the last
            dimension **must be soft-maxed**. The atom number is the number
            of segments of your q value domain.

        See Also:
            :class:`.DQN`

        Args:
            qnet: Q network module.
            qnet_target: Target Q network module.
            optimizer: Optimizer used to optimize ``qnet``.
            value_min: Minimum of the value domain.
            value_max: Maximum of the value domain.
            learning_rate: Learning rate of the optimizer, not compatible
                with ``lr_scheduler``.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            batch_size: Batch size used during training.
            update_rate: :math:`\\tau` used to update target networks.
                Target parameters are updated as:

                :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
            discount: :math:`\\gamma` used in the bellman function.
            gradient_max: Maximum gradient norm, used for gradient clipping.
            reward_future_steps: Number of future steps to be considered when
                the framework calculates value from reward.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer locates on. Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            visualize: Whether to visualize the network flow in the
                first pass.
            visualize_dir: Visualized graph save directory.
        """
        super(RAINBOW, self).__init__(
            qnet, qnet_target, optimizer, lambda: None,
            learning_rate=learning_rate,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            batch_size=batch_size,
            update_rate=update_rate,
            discount=discount,
            gradient_max=gradient_max,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=(PrioritizedBuffer(replay_size, replay_device)
                           if replay_buffer is None
                           else replay_buffer),
            visualize=visualize,
            visualize_dir=visualize_dir
        )
        self.v_min = value_min
        self.v_max = value_max
        self.reward_future_steps = reward_future_steps
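
    # A minimal sketch (in comments, not part of the framework) of a q network
    # satisfying the contract described above; the name ``QNet`` and the layer
    # sizes are illustrative assumptions:
    #
    #     class QNet(nn.Module):
    #         def __init__(self, state_dim, action_num, atom_num):
    #             super().__init__()
    #             self.action_num, self.atom_num = action_num, atom_num
    #             self.fc = nn.Linear(state_dim, action_num * atom_num)
    #
    #         def forward(self, some_state):
    #             logits = self.fc(some_state) \
    #                 .view(-1, self.action_num, self.atom_num)
    #             # the last (atom) dimension must be soft-maxed
    #             return t.softmax(logits, dim=-1)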

    def act_discrete(self,
                     state: Dict[str, Any],
                     use_target: bool = False,
                     **__):
        # DOC INHERITED

        # q value distribution of each action
        # shape: [batch_size, action_num, atom_num]
        if use_target:
            q_dist, *others = safe_call(self.qnet_target, state)
        else:
            q_dist, *others = safe_call(self.qnet, state)

        atom_num = q_dist.shape[-1]

        # support vector, shape: [1, atom_num]
        q_dist_support = t.linspace(self.v_min, self.v_max, atom_num) \
            .view(1, -1)

        # q value of each action, shape: [batch_size, action_num]
        q_value = t.sum(q_dist_support.to(q_dist.device) * q_dist, dim=-1)

        result = t.argmax(q_value, dim=1).view(-1, 1)
        if len(others) == 0:
            return result
        else:
            return (result, *others)

    def act_discrete_with_noise(self,
                                state: Dict[str, Any],
                                use_target: bool = False,
                                **__):
        # DOC INHERITED

        # q value distribution of each action
        # shape: [batch_size, action_num, atom_num]
        if use_target:
            q_dist, *others = safe_call(self.qnet_target, state)
        else:
            q_dist, *others = safe_call(self.qnet, state)

        atom_num = q_dist.shape[-1]

        # support vector, shape: [1, atom_num]
        q_dist_support = t.linspace(self.v_min, self.v_max, atom_num) \
            .view(1, -1)

        # q value of each action, shape: [batch_size, action_num]
        q_value = t.sum(q_dist_support.to(q_dist.device) * q_dist, dim=-1)

        result = t.softmax(q_value, dim=1)
        dist = Categorical(result)
        batch_size = result.shape[0]
        # sample one action per batch element, shape: [batch_size, 1]
        sample = dist.sample().view(batch_size, 1)
        if len(others) == 0:
            return sample
        else:
            return (sample, *others)

    def store_transition(self, transition: Union[Transition, Dict]):
        """
        Add a transition sample to the replay buffer.

        Not recommended, since you will have to calculate the "value"
        attribute yourself.
        """
        self.replay_buffer.append(transition, required_attrs=(
            "state", "action", "next_state", "reward", "value", "terminal"
        ))

    def store_episode(self, episode: List[Union[Transition, Dict]]):
        """
        Add a full episode of transition samples to the replay buffer.

        "value" is automatically calculated.
        """
        episode[-1]["value"] = episode[-1]["reward"]

        # calculate the (truncated) n step value for each transition
        for i in reversed(range(len(episode))):
            value_sum = 0
            # for (virtual) transitions beyond the terminal transition,
            # using "min" to ignore them is equivalent to setting their
            # rewards to zero
            for j in reversed(range(min(
                    self.reward_future_steps, len(episode) - i
            ))):
                value_sum = (value_sum * self.discount
                             + episode[i + j]["reward"])
            episode[i]["value"] = value_sum

        for trans in episode:
            self.replay_buffer.append(trans, required_attrs=(
                "state", "action", "next_state", "reward", "value", "terminal"
            ))
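
    # Worked example of the value calculation above (illustrative numbers):
    # with ``reward_future_steps=3``, ``discount=0.99`` and an episode whose
    # rewards are [1, 1, 1, 1], the stored truncated n-step values are
    #     value[0] = 1 + 0.99 * 1 + 0.99 ** 2 * 1 = 2.9701
    #     value[3] = 1    (the terminal transition only keeps its own reward)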

    def update(self,
               update_value=True,
               update_target=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        # pylint: disable=invalid-name
        self.qnet.train()
        (batch_size,
         (state, action, value, next_state, terminal, others),
         index, is_weight) = \
            self.replay_buffer.sample_batch(self.batch_size,
                                            concatenate_samples,
                                            sample_attrs=[
                                                "state", "action", "value",
                                                "next_state", "terminal", "*"
                                            ],
                                            additional_concat_attrs=[
                                                "value"
                                            ])

        # q_dist is the distribution of q values
        q_dist = self._criticize(state).cpu()
        atom_num = q_dist.shape[-1]

        action = self.action_get_function(action) \
            .to(device="cpu", dtype=t.long).flatten()

        # shape: [batch_size, atom_num]
        q_dist = q_dist[range(batch_size), action]

        # support vector, shape: [atom_num]
        q_dist_support = t.linspace(self.v_min, self.v_max, atom_num)

        with t.no_grad():
            target_next_q_dist = self._criticize(next_state, True).cpu()
            next_action = (self.act_discrete(next_state).flatten()
                           .to(device="cpu", dtype=t.long))

            # shape: [batch_size, atom_num]
            target_next_q_dist = target_next_q_dist[range(batch_size),
                                                    next_action]

            # shape: [1, atom_num]
            q_dist_support = q_dist_support.unsqueeze(dim=0)

            # T_z is the bellman update for atom z_j
            # shape: [batch_size, atom_num]
            T_z = self.reward_function(
                value.cpu(), self.discount ** self.reward_future_steps,
                q_dist_support, terminal.cpu(), others
            )
            # 1e-6 is used to make sure that l != u when T_z == v_min or v_max
            T_z = T_z.clamp(self.v_min + 1e-6, self.v_max - 1e-6)

            # delta_z is the interval length of each atom
            delta_z = (self.v_max - self.v_min) / (atom_num - 1.0)

            # b is the normalized distance of T_z to v_min,
            # l and u are the lower and upper atom indexes
            # b, l, u shape: [batch_size, atom_num]
            b = (T_z - self.v_min) / delta_z
            l, u = b.floor(), b.ceil()

            # idx shape: [batch_size * atom_num]
            # dist shape: [batch_size, atom_num]
            # weight shape: [batch_size * atom_num]
            l_idx, l_dist = l.long().view(-1), b - l
            u_idx, u_dist = u.long().view(-1), u - b
            l_weight = (u_dist * target_next_q_dist).view(-1)
            u_weight = (l_dist * target_next_q_dist).view(-1)

            # offset is used to perform a row-wise index add; since we can
            # only perform index add along one dimension, we must flatten
            # the whole distribution and then add.
            offset = (t.arange(0, batch_size * atom_num, atom_num)
                      .view(-1, 1)
                      .expand(batch_size, atom_num)
                      .flatten())

            # distribute the probability of T_z to its nearest upper
            # and lower atom neighbors, according to its distance to them.
            # shape: [batch_size * atom_num] -> [batch_size, atom_num]
            # Note: index_add_ on CUDA uses atomicAdd, which causes
            # rounding errors and is a source of noise.
            target_dist = t.zeros([batch_size * atom_num])
            target_dist.index_add_(dim=0, index=l_idx + offset,
                                   source=l_weight)
            target_dist.index_add_(dim=0, index=u_idx + offset,
                                   source=u_weight)
            target_dist = target_dist.view(batch_size, atom_num)

        # target_dist is equivalent to y_i in the original dqn.
        # The division by target_dist inside the KL divergence is dropped
        # (i.e. the cross entropy is used instead), because target_dist is
        # a constant and the dropped term carries no gradient; this also
        # prevents the 0/0 situation.
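
        # Worked example of the projection above (illustrative numbers):
        # with v_min=-1, v_max=1 and atom_num=3 the support is [-1, 0, 1]
        # and delta_z = 1. A bellman target T_z = 0.4 gives b = 1.4, so
        # l = 1 and u = 2, and the probability mass of that source atom is
        # split 60% / 40% between the atoms at 0 and 1.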

        # 1e-6 is used to improve numerical stability and prevent nan
        value_loss = -(target_dist * (q_dist + 1e-6).log())
        # shape: [batch_size, 1], matching the importance sampling weights
        value_loss = value_loss.sum(dim=1, keepdim=True)

        abs_error = (t.abs(value_loss) + 1e-6).flatten().detach().numpy()
        self.replay_buffer.update_priority(abs_error, index)

        value_loss = (value_loss *
                      t.from_numpy(is_weight).view([batch_size, 1])).mean()

        if self.visualize:
            self.visualize_model(value_loss, "qnet", self.visualize_dir)

        if update_value:
            self.qnet.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(
                self.qnet.parameters(), self.grad_max
            )
            self.qnet_optim.step()

        # Update target Q network
        if update_target:
            soft_update(self.qnet_target, self.qnet, self.update_rate)

        self.qnet.eval()
        # use .item() to prevent memory leakage
        return value_loss.item()
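

# ---------------------------------------------------------------------------
# Usage sketch (not part of the library source): drives RAINBOW with a toy
# distributional q network on random transitions. The network ``ToyQNet``,
# its sizes and the dict keys ("some_state", "action") are illustrative
# assumptions, not fixed by this module.
if __name__ == "__main__":
    state_dim, action_num, atom_num = 4, 2, 10

    class ToyQNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(state_dim, action_num * atom_num)

        def forward(self, some_state):
            logits = self.fc(some_state).view(-1, action_num, atom_num)
            # the last (atom) dimension must be soft-maxed
            return t.softmax(logits, dim=-1)

    rainbow = RAINBOW(ToyQNet(), ToyQNet(), t.optim.Adam,
                      value_min=-10.0, value_max=10.0,
                      batch_size=16, replay_size=10000)

    # collect one toy episode of random transitions
    episode = []
    for step in range(32):
        old_state, new_state = t.rand(1, state_dim), t.rand(1, state_dim)
        action = rainbow.act_discrete_with_noise({"some_state": old_state})
        episode.append({
            "state": {"some_state": old_state},
            "action": {"action": action},
            "next_state": {"some_state": new_state},
            "reward": t.rand(1).item(),
            "terminal": step == 31,
        })
    rainbow.store_episode(episode)

    # one training step; returns the importance weighted value loss
    print(rainbow.update())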