Source code for machin.frame.algorithms.dqn_per

from machin.frame.buffers.prioritized_buffer import PrioritizedBuffer
from machin.utils.logging import default_logger

# pylint: disable=wildcard-import, unused-wildcard-import
from .dqn import *


class DQNPer(DQN):
    """
    DQN with prioritized replay. It is based on Double DQN.

    Warning:
        Your criterion must return a tensor of shape ``[batch_size, 1]``
        when given two tensors of shape ``[batch_size, 1]``, since we need
        to multiply the loss with the importance sampling weight
        element-wise. If you are using the loss modules provided by pytorch,
        it is always safe to use them without any modification.
    """

    def __init__(
        self,
        qnet: Union[NeuralNetworkModule, nn.Module],
        qnet_target: Union[NeuralNetworkModule, nn.Module],
        optimizer: Callable,
        criterion: Callable,
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple] = None,
        lr_scheduler_kwargs: Tuple[Dict] = None,
        batch_size: int = 100,
        epsilon_decay: float = 0.9999,
        update_rate: Union[float, None] = 0.005,
        update_steps: Union[int, None] = None,
        learning_rate: float = 0.001,
        discount: float = 0.99,
        gradient_max: float = np.inf,
        replay_size: int = 500000,
        replay_device: Union[str, t.device] = "cpu",
        replay_buffer: Buffer = None,
        visualize: bool = False,
        visualize_dir: str = "",
        **__
    ):
        # DOC INHERITED
        super().__init__(
            qnet,
            qnet_target,
            optimizer,
            criterion,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            batch_size=batch_size,
            epsilon_decay=epsilon_decay,
            update_rate=update_rate,
            update_steps=update_steps,
            learning_rate=learning_rate,
            discount=discount,
            gradient_max=gradient_max,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=(
                PrioritizedBuffer(replay_size, replay_device)
                if replay_buffer is None
                else replay_buffer
            ),
            mode="double",
            visualize=visualize,
            visualize_dir=visualize_dir,
        )
        # The criterion's reduction must be "none" so that per-sample losses
        # are kept and can be weighted by the importance sampling weights.
        if not hasattr(self.criterion, "reduction"):
            raise RuntimeError(
                "Criterion does not have the 'reduction' property, "
                "are you using a custom criterion?"
            )
        else:
            # A loss defined in ``torch.nn.modules.loss``
            if getattr(self.criterion, "reduction") != "none":
                default_logger.warning(
                    "The reduction property of criterion is not 'none', "
                    "automatically corrected."
                )
                self.criterion.reduction = "none"
    def update(
        self, update_value=True, update_target=True, concatenate_samples=True, **__
    ):
        # DOC INHERITED
        self.qnet.train()
        (
            batch_size,
            (state, action, reward, next_state, terminal, others),
            index,
            is_weight,
        ) = self.replay_buffer.sample_batch(
            self.batch_size,
            concatenate_samples,
            sample_attrs=["state", "action", "reward", "next_state", "terminal", "*"],
        )

        with t.no_grad():
            next_q_value = self._criticize(next_state)
            target_next_q_value = self._criticize(next_state, True)
            target_next_q_value = target_next_q_value.gather(
                dim=1, index=t.max(next_q_value, dim=1)[1].unsqueeze(1)
            )

        q_value = self._criticize(state)

        # gather requires a long tensor, int32 is not accepted
        action_value = q_value.gather(
            dim=1,
            index=self.action_get_function(action).to(
                device=q_value.device, dtype=t.long
            ),
        )

        # Generate value reference :math:`y_i`.
        y_i = self.reward_function(
            reward, self.discount, target_next_q_value, terminal, others
        )
        value_loss = self.criterion(action_value, y_i.to(action_value.device))
        value_loss = value_loss * t.from_numpy(is_weight).view([batch_size, 1]).to(
            value_loss.device
        )
        value_loss = value_loss.mean()

        abs_error = (
            t.sum(t.abs(action_value - y_i.to(action_value.device)), dim=1)
            .flatten()
            .detach()
            .cpu()
            .numpy()
        )
        self.replay_buffer.update_priority(abs_error, index)

        if self.visualize:
            self.visualize_model(value_loss, "qnet", self.visualize_dir)

        if update_value:
            self.qnet.zero_grad()
            self._backward(value_loss)
            nn.utils.clip_grad_norm_(self.qnet.parameters(), self.grad_max)
            self.qnet_optim.step()

        # Update target Q network
        if update_target:
            if self.update_rate is not None:
                soft_update(self.qnet_target, self.qnet, self.update_rate)
            else:
                self._update_counter += 1
                if self._update_counter % self.update_steps == 0:
                    hard_update(self.qnet_target, self.qnet)

        self.qnet.eval()
        # use .item() to prevent memory leakage
        return value_loss.item()
    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        config = DQN.generate_config(config)
        config["frame"] = "DQNPer"
        return config
    @classmethod
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        f_config = deepcopy(config["frame_config"])
        models = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg).to(model_device)
            for m, arg, kwarg in zip(models, model_args, model_kwargs)
        ]
        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        criterion = assert_and_get_valid_criterion(f_config["criterion"])(
            *f_config["criterion_args"], **f_config["criterion_kwargs"]
        )
        criterion.reduction = "none"
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        f_config["optimizer"] = optimizer
        f_config["criterion"] = criterion
        f_config["lr_scheduler"] = lr_scheduler
        frame = cls(*models, **f_config)
        return frame
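
The examples below are usage sketches added for illustration; they are not part of the Machin source. First, construction: ``DQNPer`` needs a criterion that returns per-sample losses, so a PyTorch loss created with ``reduction="none"`` is the natural choice (the ``__init__`` check above enforces this). ``QNet`` is a hypothetical user-defined network; the import path follows the module path of this page.

# Construction sketch: DQNPer with a per-sample criterion.
import torch as t
import torch.nn as nn

from machin.frame.algorithms.dqn_per import DQNPer


class QNet(nn.Module):
    # Hypothetical Q network: maps a state batch to one Q value per action.
    def __init__(self, state_dim, action_num):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, action_num)

    def forward(self, state):
        return self.fc2(t.relu(self.fc1(state)))


qnet = QNet(state_dim=4, action_num=2)
qnet_target = QNet(state_dim=4, action_num=2)

# reduction="none" keeps one loss value per sample so it can be multiplied
# element-wise by the importance sampling weights; a "mean"/"sum" reduction
# would be corrected automatically with a warning, as seen in __init__ above.
dqn_per = DQNPer(qnet, qnet_target, t.optim.Adam, nn.MSELoss(reduction="none"))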
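
Next, a minimal training-loop sketch. It assumes the ``dqn_per`` object from the previous sketch, a gym-style environment ``env`` with a 4-dimensional observation and 2 discrete actions, and that ``act_discrete_with_noise`` and ``store_transition`` are inherited unchanged from ``DQN``; treat those names as assumptions rather than guarantees.

# Training-loop sketch: collect one episode, then update.
state = t.tensor(env.reset(), dtype=t.float32).view(1, 4)
terminal = False
while not terminal:
    with t.no_grad():
        old_state = state
        # The dict key must match the forward() argument name of QNet.
        action = dqn_per.act_discrete_with_noise({"state": old_state})
    state, reward, terminal, _ = env.step(action.item())
    state = t.tensor(state, dtype=t.float32).view(1, 4)
    dqn_per.store_transition(
        {
            "state": {"state": old_state},
            "action": {"action": action},
            "next_state": {"state": state},
            "reward": float(reward),
            "terminal": terminal,
        }
    )

# In practice, wait until the buffer holds at least ``batch_size`` samples.
# update() samples transitions by priority, weights the loss with the
# importance sampling weights, and writes the new absolute TD errors back
# into the prioritized buffer (see update() above).
loss = dqn_per.update()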
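
Finally, a sketch of the config-driven path through ``generate_config`` and ``init_from_config``. The ``frame_config`` keys are the ones read in ``init_from_config`` above; that ``generate_config`` fills sensible defaults when given an empty dict, and that entries of ``"models"`` may be given as names resolvable by ``assert_and_get_valid_models``, are assumptions here.

# Config-based construction sketch.
config = DQNPer.generate_config({})  # also sets config["frame"] = "DQNPer"
config["frame_config"]["models"] = ["QNet", "QNet"]
config["frame_config"]["model_kwargs"] = [
    {"state_dim": 4, "action_num": 2},
    {"state_dim": 4, "action_num": 2},
]
dqn_per = DQNPer.init_from_config(config, model_device="cpu")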