# Source code for machin.frame.algorithms.apex

from .dqn_per import *
from .ddpg_per import *
from ..buffers.prioritized_buffer_d import DistributedPrioritizedBuffer
from machin.parallel.server import PushPullModelServer
from machin.parallel.distributed import RpcGroup


class DQNApex(DQNPer):
    """
    Massively parallel version of a Double DQN with prioritized replay.

    The pull function is invoked before using
    ``act_discrete`` and ``act_discrete_with_noise``.

    The push function is invoked after ``update``.
    """

    def __init__(self,
                 qnet: Union[NeuralNetworkModule, nn.Module],
                 qnet_target: Union[NeuralNetworkModule, nn.Module],
                 optimizer: Callable,
                 criterion: Callable,
                 apex_group: RpcGroup,
                 model_server: Tuple[PushPullModelServer],
                 *_,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple] = (),
                 lr_scheduler_kwargs: Tuple[Dict] = (),
                 batch_size: int = 100,
                 update_rate: float = 0.005,
                 learning_rate: float = 0.001,
                 discount: float = 0.99,
                 replay_size: int = 500000,
                 **__):
        """
        See Also:
            :class:`.DQNPer`

        Note:
            The apex framework supports multiple workers (samplers) and
            only one trainer; you may use ``DistributedDataParallel`` in
            the trainer. If you use ``DistributedDataParallel``, you must
            call ``update()`` in all member processes of
            ``DistributedDataParallel``.

        Args:
            qnet: Q network module.
            qnet_target: Target Q network module.
            optimizer: Optimizer used to optimize ``qnet``.
            criterion: Criterion used to evaluate the value loss.
            apex_group: Group of all processes using the apex-DQN framework,
                including all samplers and trainers.
            model_server: Custom model sync server accessor for ``qnet``.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            batch_size: Batch size used during training.
            update_rate: :math:`\\tau` used to update target networks.
                Target parameters are updated as:

                :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
            learning_rate: Learning rate of the optimizer, not compatible
                with ``lr_scheduler``.
            discount: :math:`\\gamma` used in the Bellman function.
            replay_size: Local replay buffer size of a single worker.
        """
        super(DQNApex, self).__init__(qnet, qnet_target,
                                      optimizer,
                                      criterion,
                                      lr_scheduler=lr_scheduler,
                                      lr_scheduler_args=lr_scheduler_args,
                                      lr_scheduler_kwargs=lr_scheduler_kwargs,
                                      batch_size=batch_size,
                                      update_rate=update_rate,
                                      learning_rate=learning_rate,
                                      discount=discount)
        # Sharing the rpc group between instances is not supported,
        # so using a static buffer_name is ok here.
        self.replay_buffer = DistributedPrioritizedBuffer(
            buffer_name="buffer", group=apex_group,
            buffer_size=replay_size
        )
        self.apex_group = apex_group
        self.qnet_model_server = model_server[0]
        self.is_syncing = True
    def set_sync(self, is_syncing):
        self.is_syncing = is_syncing
    def manual_sync(self):
        self.qnet_model_server.pull(self.qnet)
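    # Usage sketch (an assumption about typical driving code, not part of
    # this module): a sampler that prefers to sync once per episode rather
    # than on every action can disable automatic syncing and pull manually:
    #
    #     dqn_apex.set_sync(False)   # stop pulling inside act_* methods
    #     dqn_apex.manual_sync()     # pull qnet from the model server once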
    def act_discrete(self,
                     state: Dict[str, Any],
                     use_target: bool = False,
                     **__):
        # DOC INHERITED
        if self.is_syncing and not use_target:
            self.qnet_model_server.pull(self.qnet)
        return super(DQNApex, self).act_discrete(state, use_target)
    def act_discrete_with_noise(self,
                                state: Dict[str, Any],
                                use_target: bool = False,
                                **__):
        # DOC INHERITED
        if self.is_syncing and not use_target:
            self.qnet_model_server.pull(self.qnet)
        return super(DQNApex, self).act_discrete_with_noise(state, use_target)
    def update(self,
               update_value=True,
               update_target=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        result = super(DQNApex, self).update(update_value, update_target,
                                             concatenate_samples)
        if isinstance(self.qnet, (nn.parallel.DataParallel,
                                  nn.parallel.DistributedDataParallel)):
            self.qnet_model_server.push(self.qnet.module, pull_on_fail=False)
        else:
            self.qnet_model_server.push(self.qnet)
        return result
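# A minimal end-to-end sketch of how DQNApex is typically wired up, assuming
# an already-initialized rpc world; the process names, the "QNet" module and
# the model_server_helper() helper are illustrative, not defined in this file:
#
#     from machin.frame.helpers.servers import model_server_helper
#
#     apex_group = world.create_rpc_group("apex", ["worker", "trainer"])
#     servers = model_server_helper(model_num=1)
#     dqn_apex = DQNApex(QNet(observe_dim, action_num),
#                        QNet(observe_dim, action_num),
#                        torch.optim.Adam,
#                        nn.MSELoss(reduction="sum"),
#                        apex_group,
#                        servers)
#
#     # On a sampler process: act_discrete_with_noise() pulls the newest
#     # qnet weights, and store_episode() feeds the shared distributed
#     # prioritized buffer.
#     action = dqn_apex.act_discrete_with_noise({"state": state})
#     dqn_apex.store_episode(episode)
#
#     # On the trainer process: update() samples from the shared buffer
#     # and pushes updated weights to the model server.
#     dqn_apex.update()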
class DDPGApex(DDPGPer):
    """
    Massively parallel version of a DDPG with prioritized replay.

    The pull function is invoked before using
    ``act``, ``act_with_noise``, ``act_discrete``
    and ``act_discrete_with_noise``.

    The push function is invoked after ``update``.
    """

    def __init__(self,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 actor_target: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 critic_target: Union[NeuralNetworkModule, nn.Module],
                 optimizer: Callable,
                 criterion: Callable,
                 apex_group: RpcGroup,
                 model_server: Tuple[PushPullModelServer],
                 *_,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple, Tuple] = (),
                 lr_scheduler_kwargs: Tuple[Dict, Dict] = (),
                 batch_size: int = 100,
                 update_rate: float = 0.005,
                 learning_rate: float = 0.001,
                 discount: float = 0.99,
                 replay_size: int = 500000,
                 **__):
        """
        See Also:
            :class:`.DDPGPer`

        TODO:
            implement truncated n-step returns, just like the one
            used in :class:`.RAINBOW`.

        Hint:
            Your push and pull function will be called like::

                function(actor_model, "actor")

            The default implementation of pull and push functions
            is provided by :class:`.PushPullModelServer`.

        Args:
            actor: Actor network module.
            actor_target: Target actor network module.
            critic: Critic network module.
            critic_target: Target critic network module.
            optimizer: Optimizer used to optimize ``actor`` and ``critic``.
            criterion: Criterion used to evaluate the value loss.
            apex_group: Group of all processes using the apex-DDPG framework,
                including all samplers and trainers.
            model_server: Custom model sync server accessor for ``actor``.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            batch_size: Batch size used during training.
            update_rate: :math:`\\tau` used to update target networks.
                Target parameters are updated as:

                :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
            learning_rate: Learning rate of the optimizer, not compatible
                with ``lr_scheduler``.
            discount: :math:`\\gamma` used in the Bellman function.
            replay_size: Local replay buffer size of a single worker.
        """
        super(DDPGApex, self).__init__(actor, actor_target,
                                       critic, critic_target,
                                       optimizer,
                                       criterion,
                                       lr_scheduler=lr_scheduler,
                                       lr_scheduler_args=lr_scheduler_args,
                                       lr_scheduler_kwargs=lr_scheduler_kwargs,
                                       batch_size=batch_size,
                                       update_rate=update_rate,
                                       learning_rate=learning_rate,
                                       discount=discount)
        # Sharing the rpc group between instances is not supported,
        # so using a static buffer_name is ok here.
        self.replay_buffer = DistributedPrioritizedBuffer(
            buffer_name="buffer", group=apex_group,
            buffer_size=replay_size
        )
        self.apex_group = apex_group
        self.actor_model_server = model_server[0]
        self.is_syncing = True
    def set_sync(self, is_syncing):
        self.is_syncing = is_syncing
    def manual_sync(self):
        self.actor_model_server.pull(self.actor)
    def act(self,
            state: Dict[str, Any],
            use_target: bool = False,
            **__):
        # DOC INHERITED
        if self.is_syncing and not use_target:
            self.actor_model_server.pull(self.actor)
        return super(DDPGApex, self).act(state, use_target)
    def act_with_noise(self,
                       state: Dict[str, Any],
                       noise_param: Tuple = (0.0, 1.0),
                       ratio: float = 1.0,
                       mode: str = "uniform",
                       use_target: bool = False,
                       **__):
        # DOC INHERITED
        if self.is_syncing and not use_target:
            self.actor_model_server.pull(self.actor)
        return super(DDPGApex, self).act_with_noise(state,
                                                    noise_param=noise_param,
                                                    ratio=ratio,
                                                    mode=mode,
                                                    use_target=use_target)
    def act_discrete(self,
                     state: Dict[str, Any],
                     use_target: bool = False,
                     **__):
        # DOC INHERITED
        if self.is_syncing and not use_target:
            self.actor_model_server.pull(self.actor)
        return super(DDPGApex, self).act_discrete(state, use_target)
    def act_discrete_with_noise(self,
                                state: Dict[str, Any],
                                use_target: bool = False,
                                **__):
        # DOC INHERITED
        if self.is_syncing and not use_target:
            self.actor_model_server.pull(self.actor)
        return super(DDPGApex, self).act_discrete_with_noise(state,
                                                             use_target)
    def update(self,
               update_value=True,
               update_policy=True,
               update_target=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        result = super(DDPGApex, self).update(update_value, update_policy,
                                              update_target,
                                              concatenate_samples)
        if isinstance(self.actor, (nn.parallel.DataParallel,
                                   nn.parallel.DistributedDataParallel)):
            self.actor_model_server.push(self.actor.module,
                                         pull_on_fail=False)
        else:
            self.actor_model_server.push(self.actor)
        return result
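# A minimal usage sketch for DDPGApex, mirroring the DQNApex example above
# (the "Actor"/"Critic" modules and their constructor arguments are
# hypothetical placeholders). Note that the model server only holds the
# actor: samplers pull it before acting, and the trainer pushes it after
# each update(); the critic lives only on the trainer.
#
#     apex_group = world.create_rpc_group("apex", ["worker", "trainer"])
#     servers = model_server_helper(model_num=1)
#     ddpg_apex = DDPGApex(Actor(...), Actor(...),
#                          Critic(...), Critic(...),
#                          torch.optim.Adam,
#                          nn.MSELoss(reduction="sum"),
#                          apex_group,
#                          servers)
#
#     # On a sampler process:
#     action = ddpg_apex.act_with_noise({"state": state},
#                                       noise_param=(0.0, 0.1),
#                                       mode="normal")
#     ddpg_apex.store_episode(episode)
#
#     # On the trainer process:
#     ddpg_apex.update()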