Source code for machin.frame.algorithms.a3c

from .a2c import *
from machin.parallel.server import PushPullGradServer
from torch.optim import Adam


class A3C(A2C):
    """
    A3C framework.
    """

    def __init__(self,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 criterion: Callable,
                 grad_server: Tuple[PushPullGradServer,
                                    PushPullGradServer],
                 *_,
                 entropy_weight: Union[float, None] = None,
                 value_weight: float = 0.5,
                 gradient_max: float = np.inf,
                 gae_lambda: float = 1.0,
                 discount: float = 0.99,
                 update_times: int = 50,
                 replay_size: int = 500000,
                 replay_device: Union[str, t.device] = "cpu",
                 replay_buffer: Union[Buffer, None] = None,
                 visualize: bool = False,
                 **__):
        """
        See Also:
            :class:`.A2C`

        Note:
            A3C algorithm relies on parameter servers to synchronize
            parameters of actor and critic models across samplers
            (interact with environment) and trainers (using samples
            to train).

            The parameter server type :class:`.PushPullGradServer`
            used here utilizes gradients calculated by trainers:

            1. perform a "sum" reduction process on the collected
            gradients, then apply this reduced gradient to the model
            managed by its primary reducer

            2. push the parameters of this updated managed model to
            an ordered key-value server so that all processes,
            including samplers and trainers, can access the updated
            parameters.

            The ``grad_server`` argument is a pair of accessors to two
            :class:`.PushPullGradServerImpl` class.

            No optimizer argument is accepted: the optimizer passed to
            :class:`.A2C` is only a placeholder whose ``step()`` is
            disabled, the actual optimizer is set in parameter servers.

        Args:
            actor: Actor network module.
            critic: Critic network module.
            criterion: Criterion used to evaluate the value loss.
            grad_server: Custom gradient sync server accessors, the first
                server accessor is for actor, and the second one is for
                critic.
            entropy_weight: Weight of entropy in your loss function, a
                positive entropy weight will minimize entropy, while a
                negative one will maximize entropy.
            value_weight: Weight of critic value loss.
            gradient_max: Maximum gradient.
            gae_lambda: :math:`\\lambda` used in generalized advantage
                estimation.
            discount: :math:`\\gamma` used in the bellman function.
            update_times: Number of update iterations per sample period.
                Buffer will be cleared after ``update()``
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer locates on, Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            visualize: Whether visualize the network flow in the first pass.
        """
        # Adam is just a placeholder here, the actual optimizer is
        # set in parameter servers
        super().__init__(actor, critic, Adam, criterion,
                         entropy_weight=entropy_weight,
                         value_weight=value_weight,
                         gradient_max=gradient_max,
                         gae_lambda=gae_lambda,
                         discount=discount,
                         update_times=update_times,
                         replay_size=replay_size,
                         replay_device=replay_device,
                         replay_buffer=replay_buffer,
                         visualize=visualize)
        # disable local stepping
        self.actor_optim.step = lambda: None
        self.critic_optim.step = lambda: None
        self.actor_grad_server, self.critic_grad_server = \
            grad_server[0], grad_server[1]
        # When True, act() / _eval_act() / _criticize() pull parameters
        # from the matching server before running the model.
        self.is_syncing = True

    def set_sync(self, is_syncing: bool):
        """
        Enable or disable the automatic parameter pull performed before
        ``act()``, ``_eval_act()`` and ``_criticize()``.

        Args:
            is_syncing: ``True`` to pull before each of these calls,
                ``False`` to skip the pull.
        """
        self.is_syncing = is_syncing

    def manual_sync(self):
        """
        Pull parameters for both the actor and the critic from their
        gradient servers, regardless of the current ``is_syncing`` flag.
        """
        self.actor_grad_server.pull(self.actor)
        self.critic_grad_server.pull(self.critic)

    def act(self, state: Dict[str, Any], **__):
        # DOC INHERITED
        if self.is_syncing:
            self.actor_grad_server.pull(self.actor)
        return super().act(state)

    def _eval_act(self,
                  state: Dict[str, Any],
                  action: Dict[str, Any],
                  **__):
        # DOC INHERITED
        if self.is_syncing:
            self.actor_grad_server.pull(self.actor)
        return super()._eval_act(state, action)

    def _criticize(self, state: Dict[str, Any], *_, **__):
        # DOC INHERITED
        if self.is_syncing:
            self.critic_grad_server.pull(self.critic)
        return super()._criticize(state)

    def update(self,
               update_value=True,
               update_policy=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        org_sync = self.is_syncing
        # Suspend automatic pulling for the duration of the parent update,
        # so that _eval_act() / _criticize() (presumably invoked by it) do
        # not pull from the servers while the update is running.
        self.is_syncing = False
        try:
            # Forward whatever the parent returns instead of discarding it
            # (NOTE(review): parent return value is not visible here).
            result = super().update(update_value, update_policy,
                                    concatenate_samples)
        finally:
            # Restore the flag even if the parent update raises, otherwise
            # automatic pulling would stay disabled permanently.
            self.is_syncing = org_sync
        # Only reached on success: hand the models over to the servers.
        self.actor_grad_server.push(self.actor)
        self.critic_grad_server.push(self.critic)
        return result