Source code for machin.frame.algorithms.a3c

from .a2c import *
from machin.parallel.server import PushPullGradServer
from machin.frame.helpers.servers import grad_server_helper
from .utils import FakeOptimizer, assert_and_get_valid_lr_scheduler


class A3C(A2C):
    """
    A3C framework.
    """

    def __init__(
        self,
        actor: Union[NeuralNetworkModule, nn.Module],
        critic: Union[NeuralNetworkModule, nn.Module],
        criterion: Callable,
        grad_server: Tuple[PushPullGradServer, PushPullGradServer],
        *_,
        batch_size: int = 100,
        actor_update_times: int = 5,
        critic_update_times: int = 10,
        entropy_weight: float = None,
        value_weight: float = 0.5,
        gradient_max: float = np.inf,
        gae_lambda: float = 1.0,
        discount: float = 0.99,
        normalize_advantage: bool = True,
        replay_size: int = 500000,
        replay_device: Union[str, t.device] = "cpu",
        replay_buffer: Buffer = None,
        visualize: bool = False,
        visualize_dir: str = "",
        **__
    ):
        """
        See Also:
            :class:`.A2C`

        Note:
            The A3C algorithm relies on parameter servers to synchronize
            the parameters of actor and critic models across samplers
            (which interact with the environment) and trainers (which use
            samples to train).

            The parameter server type :class:`.PushPullGradServer` used
            here utilizes the gradients calculated by trainers:

            1. Perform a "sum" reduction on the collected gradients, then
               apply the reduced gradient to the model managed by its
               primary reducer.

            2. Push the parameters of the updated managed model to an
               ordered key-value server, so that all processes, including
               samplers and trainers, can access the updated parameters.

            The ``grad_server`` argument is a pair of accessors to two
            :class:`.PushPullGradServerImpl` instances.

        Args:
            actor: Actor network module.
            critic: Critic network module.
            criterion: Criterion used to evaluate the value loss.
            grad_server: Custom gradient sync server accessors; the first
                accessor is for the actor, the second for the critic.
            batch_size: Batch size used during training.
            actor_update_times: Times to update the actor in ``update()``.
            critic_update_times: Times to update the critic in ``update()``.
            entropy_weight: Weight of entropy in your loss function; a
                positive entropy weight will minimize entropy, while a
                negative one will maximize entropy.
            value_weight: Weight of critic value loss.
            gradient_max: Maximum gradient.
            gae_lambda: :math:`\\lambda` used in generalized advantage
                estimation.
            discount: :math:`\\gamma` used in the Bellman function.
            normalize_advantage: Whether to normalize the sampled advantage
                values in the batch.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer is located. Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            visualize: Whether to visualize the network flow in the first
                pass.
            visualize_dir: Visualized graph save directory.
        """
        # FakeOptimizer is just a placeholder here; the actual optimizers
        # are created and stepped inside the parameter servers.
        super().__init__(
            actor,
            critic,
            FakeOptimizer,
            criterion,
            batch_size=batch_size,
            actor_update_times=actor_update_times,
            critic_update_times=critic_update_times,
            entropy_weight=entropy_weight,
            value_weight=value_weight,
            gradient_max=gradient_max,
            gae_lambda=gae_lambda,
            discount=discount,
            normalize_advantage=normalize_advantage,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=replay_buffer,
            visualize=visualize,
            visualize_dir=visualize_dir,
        )
        # Disable local optimizer stepping; reduced gradients are applied
        # by the gradient servers instead.
        self.actor_optim.step = lambda: None
        self.critic_optim.step = lambda: None
        self.actor_grad_server, self.critic_grad_server = grad_server[0], grad_server[1]
        self.is_syncing = True

    @property
    def optimizers(self):
        return []

    @optimizers.setter
    def optimizers(self, optimizers):
        pass

    @property
    def lr_schedulers(self):
        return []

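    # Construction sketch (not part of the original source, for orientation
    # only): the ``grad_server`` pair is usually produced by
    # ``grad_server_helper``, called identically in every worker process
    # after the machin ``World`` and its RPC group are initialized.
    # ``Actor`` and ``Critic`` below are hypothetical user-defined modules.
    #
    #     servers = grad_server_helper(
    #         [lambda: Actor(), lambda: Critic()],
    #         learning_rate=[1e-3, 1e-3],
    #     )
    #     a3c = A3C(Actor(), Critic(), nn.MSELoss(reduction="sum"), servers)
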
    @classmethod
    def is_distributed(cls):
        return True

    def set_sync(self, is_syncing):
        self.is_syncing = is_syncing

    def manual_sync(self):
        self.actor_grad_server.pull(self.actor)
        self.critic_grad_server.pull(self.critic)

    def act(self, state: Dict[str, Any], **__):
        # DOC INHERITED
        if self.is_syncing:
            self.actor_grad_server.pull(self.actor)
        return super().act(state)

    def _eval_act(self, state: Dict[str, Any], action: Dict[str, Any], **__):
        # DOC INHERITED
        if self.is_syncing:
            self.actor_grad_server.pull(self.actor)
        return super()._eval_act(state, action)

    def _criticize(self, state: Dict[str, Any], *_, **__):
        # DOC INHERITED
        if self.is_syncing:
            self.critic_grad_server.pull(self.critic)
        return super()._criticize(state)

    def update(
        self, update_value=True, update_policy=True, concatenate_samples=True, **__
    ):
        # DOC INHERITED
        # Syncing is disabled during the parent update so that gradients
        # are computed against a consistent set of local parameters; the
        # fresh gradients are then pushed to the servers.
        org_sync = self.is_syncing
        self.is_syncing = False
        super().update(update_value, update_policy, concatenate_samples)
        self.is_syncing = org_sync
        self.actor_grad_server.push(self.actor)
        self.critic_grad_server.push(self.critic)

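    # Training-loop sketch (assumption, not original source): samplers keep
    # syncing enabled so ``act()`` pulls the latest actor parameters, and
    # every ``update()`` pushes new gradients. ``env``, ``transitions`` and
    # the episode loop below are hypothetical placeholders.
    #
    #     while episode < max_episodes:
    #         state = env.reset()
    #         ...  # run one episode with a3c.act(), collect transitions
    #         a3c.store_episode(transitions)
    #         a3c.update()  # computes gradients locally, pushes to servers
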
    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        default_values = {
            "grad_server_group_name": "a3c_grad_server",
            "grad_server_members": "all",
            "models": ["Actor", "Critic"],
            "model_args": ((), ()),
            "model_kwargs": ({}, {}),
            "optimizer": "Adam",
            "criterion": "MSELoss",
            "criterion_args": (),
            "criterion_kwargs": {},
            "lr_scheduler": None,
            "lr_scheduler_args": None,
            "lr_scheduler_kwargs": None,
            "batch_size": 100,
            "actor_update_times": 5,
            "critic_update_times": 10,
            "actor_learning_rate": 0.001,
            "critic_learning_rate": 0.001,
            "entropy_weight": None,
            "value_weight": 0.5,
            "gradient_max": np.inf,
            "gae_lambda": 1.0,
            "discount": 0.99,
            "normalize_advantage": True,
            "replay_size": 500000,
            "replay_device": "cpu",
            "replay_buffer": None,
            "visualize": False,
            "visualize_dir": "",
        }
        config = deepcopy(config)
        config["frame"] = "A3C"
        if "frame_config" not in config:
            config["frame_config"] = default_values
        else:
            # Note the merge order: default values take precedence here, so
            # override ``frame_config`` entries after calling this method.
            config["frame_config"] = {**config["frame_config"], **default_values}
        return config

    @classmethod
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        f_config = deepcopy(config["frame_config"])
        model_cls = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg).to(model_device)
            for m, arg, kwarg in zip(model_cls, model_args, model_kwargs)
        ]
        model_creators = [
            # Bind loop variables as default arguments; a bare
            # ``lambda: m(*arg, **kwarg)`` would late-bind and make every
            # creator construct the last model.
            lambda m=m, arg=arg, kwarg=kwarg: m(*arg, **kwarg)
            for m, arg, kwarg in zip(model_cls, model_args, model_kwargs)
        ]
        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        criterion = assert_and_get_valid_criterion(f_config["criterion"])(
            *f_config["criterion_args"], **f_config["criterion_kwargs"]
        )
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        servers = grad_server_helper(
            model_creators,
            group_name=f_config["grad_server_group_name"],
            members=f_config["grad_server_members"],
            optimizer=optimizer,
            learning_rate=[
                f_config["actor_learning_rate"],
                f_config["critic_learning_rate"],
            ],
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=f_config["lr_scheduler_args"] or ((), ()),
            lr_scheduler_kwargs=f_config["lr_scheduler_kwargs"] or ({}, {}),
        )
        del f_config["criterion"]
        frame = cls(*models, criterion, servers, **f_config)
        return frame
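

# Config-driven setup sketch (not part of the original module): the two-step
# ``generate_config`` / ``init_from_config`` flow. The names listed in
# ``frame_config["models"]`` must resolve through
# ``assert_and_get_valid_models``; "Actor" and "Critic" below are assumed to
# be such registered classes, and a machin ``World`` with an RPC group is
# assumed to be initialized in every participating process.
#
#     config = A3C.generate_config({})
#     config["frame_config"]["models"] = ["Actor", "Critic"]
#     config["frame_config"]["model_kwargs"] = (
#         {"state_dim": 4, "action_num": 2},
#         {"state_dim": 4},
#     )
#     a3c = A3C.init_from_config(config, model_device="cpu")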