Source code for machin.frame.algorithms.impala

from typing import Union, Dict, List, Tuple, Callable, Any
from copy import deepcopy
import random
import numpy as np
import torch as t
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from machin.frame.buffers.buffer_d import DistributedBuffer
from machin.frame.transition import Transition
from machin.frame.helpers.servers import model_server_helper
from machin.model.nets.base import NeuralNetworkModule
from machin.parallel.server import PushPullModelServer
from machin.parallel.distributed import RpcGroup, get_world
from .base import TorchFramework, Config
from .utils import (
    safe_call,
    assert_and_get_valid_models,
    assert_and_get_valid_optimizer,
    assert_and_get_valid_criterion,
    assert_and_get_valid_lr_scheduler,
)


def _disable_update(*_, **__):
    return None, None


class IMPALABuffer(DistributedBuffer):
    """
    Samples full episodes for batch_size instead of steps.
    """

    def sample_batch(
        self,
        batch_size: int,
        concatenate: bool = True,
        device: Union[str, t.device] = "cpu",
        sample_attrs: List[str] = None,
        additional_concat_custom_attrs: List[str] = None,
        *_,
        **__,
    ) -> Any:
        """
        Sample ``batch_size`` **episodes** from the buffer.

        Any extra positional or keyword arguments are ignored. The parent
        ``sample_batch`` is always called with ``sample_method="episode"``,
        which presumably dispatches to :meth:`sample_method_episode`.
        NOTE(review): confirm against ``DistributedBuffer``.

        Args:
            batch_size: Number of episodes to sample.
            concatenate: Forwarded to the parent ``sample_batch``.
            device: Forwarded to the parent ``sample_batch``.
            sample_attrs: Forwarded to the parent ``sample_batch``.
            additional_concat_custom_attrs: Forwarded to the parent
                ``sample_batch``.

        Returns:
            Whatever the parent ``sample_batch`` returns.
        """
        return super().sample_batch(
            batch_size=batch_size,
            concatenate=concatenate,
            device=device,
            sample_method="episode",
            sample_attrs=sample_attrs,
            additional_concat_custom_attrs=additional_concat_custom_attrs,
        )

    def sample_method_episode(self, batch_size: int) -> Tuple[int, List[Transition]]:
        """
        Args:
            batch_size: Number of **episodes** to sample. It is clamped to
                the range ``[0, number of stored episodes]``.

        Returns:
            The number of sampled episodes, and a flat list holding the
            transitions of every sampled episode, one episode after another.
        """
        episode_keys = list(self.episode_transition_handles.keys())
        batch_size = max(0, min(len(episode_keys), batch_size))
        # Sample distinct episodes. Clamping to the population size only
        # makes sense without replacement; sampling with replacement could
        # repeat an episode and over-weight it within the batch.
        episodes = random.sample(episode_keys, k=batch_size)
        batch = [
            self.storage[handle]
            for episode in episodes
            for handle in self.episode_transition_handles[episode]
        ]
        return batch_size, batch
class IMPALA(TorchFramework):
    """
    Massively parallel IMPALA framework.
    """

    _is_top = ["actor", "critic"]
    _is_restorable = ["actor", "critic"]

    def __init__(
        self,
        actor: Union[NeuralNetworkModule, nn.Module],
        critic: Union[NeuralNetworkModule, nn.Module],
        optimizer: Callable,
        criterion: Callable,
        impala_group: RpcGroup,
        model_server: Tuple[PushPullModelServer],
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple, Tuple] = (),
        lr_scheduler_kwargs: Tuple[Dict, Dict] = (),
        batch_size: int = 5,
        learning_rate: float = 0.001,
        isw_clip_c: float = 1.0,
        isw_clip_rho: float = 1.0,
        entropy_weight: float = None,
        value_weight: float = 0.5,
        gradient_max: float = np.inf,
        discount: float = 0.99,
        replay_size: int = 500,
        **__,
    ):
        """
        Note:
            Please make sure ``isw_clip_rho >= isw_clip_c``.

        Args:
            actor: Actor network module.
            critic: Critic network module.
            optimizer: Optimizer class used to optimize ``actor`` and
                ``critic``; one instance is created per network.
            criterion: Criterion used to evaluate the value loss.
            impala_group: Group of all processes using the IMPALA framework,
                including all samplers and trainers.
            model_server: Custom model sync server accessor for ``actor``;
                only the first element is used.
            lr_scheduler: Learning rate scheduler class of ``optimizer``.
            lr_scheduler_args: Pair of positional-argument tuples for the
                actor and critic schedulers. ``()`` or ``None`` means no
                extra arguments.
            lr_scheduler_kwargs: Pair of keyword-argument dicts for the actor
                and critic schedulers. ``()`` or ``None`` means no extra
                arguments.
            batch_size: Number of **episodes** sampled per update.
            learning_rate: Learning rate of the optimizer, not compatible with
                ``lr_scheduler``.
            isw_clip_c: :math:`\\bar{c}` used in importance weight clipping
                (trace cutting coefficient of the V-trace target).
            isw_clip_rho: :math:`\\bar{\\rho}` used in importance weight
                clipping (weights of the temporal difference and of the
                policy gradient).
            entropy_weight: Weight of entropy in your loss function, a positive
                entropy weight will minimize entropy, while a negative one will
                maximize entropy.
            value_weight: Weight of critic value loss.
            gradient_max: Maximum gradient norm, used for gradient clipping.
            discount: :math:`\\gamma` used in the bellman function.
            replay_size: Size of the local replay buffer.
        """
        self.batch_size = batch_size
        self.discount = discount
        self.value_weight = value_weight
        self.entropy_weight = entropy_weight
        self.grad_max = gradient_max
        self.isw_clip_c = isw_clip_c
        self.isw_clip_rho = isw_clip_rho
        self.impala_group = impala_group

        self.actor = actor
        self.critic = critic
        self.actor_optim = optimizer(self.actor.parameters(), lr=learning_rate)
        self.critic_optim = optimizer(self.critic.parameters(), lr=learning_rate)
        self.replay_buffer = IMPALABuffer(
            buffer_name="buffer", group=impala_group, buffer_size=replay_size
        )
        self.is_syncing = True
        self.actor_model_server = model_server[0]

        if lr_scheduler is not None:
            # Both the empty-tuple defaults and the ``None`` that
            # ``init_from_config`` forwards for unset arguments would break
            # the ``[0]`` / ``[1]`` indexing below.
            lr_scheduler_args = lr_scheduler_args or ((), ())
            lr_scheduler_kwargs = lr_scheduler_kwargs or ({}, {})
            self.actor_lr_sch = lr_scheduler(
                self.actor_optim,
                *lr_scheduler_args[0],
                **lr_scheduler_kwargs[0],
            )
            self.critic_lr_sch = lr_scheduler(
                self.critic_optim, *lr_scheduler_args[1], **lr_scheduler_kwargs[1]
            )

        self.criterion = criterion
        # A DP/DDP-wrapped actor never pulls from the model server
        # (see ``act`` and ``manual_sync``).
        self._is_using_DP_or_DDP = isinstance(
            self.actor, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
        )
        super().__init__()

    @property
    def optimizers(self):
        return [self.actor_optim, self.critic_optim]

    @optimizers.setter
    def optimizers(self, optimizers):
        self.actor_optim, self.critic_optim = optimizers

    @property
    def lr_schedulers(self):
        if hasattr(self, "actor_lr_sch") and hasattr(self, "critic_lr_sch"):
            return [self.actor_lr_sch, self.critic_lr_sch]
        return []

    @classmethod
    def is_distributed(cls):
        return True

    def set_sync(self, is_syncing):
        """
        Enable or disable pulling the actor from the model server in ``act``.
        """
        self.is_syncing = is_syncing

    def manual_sync(self):
        """
        Pull the latest actor from the model server (skipped for DP/DDP actors).
        """
        if not self._is_using_DP_or_DDP:
            self.actor_model_server.pull(self.actor)

    def act(self, state: Dict[str, Any], *_, **__):
        """
        Use actor network to give a policy to the current state.

        Returns:
            Anything produced by actor.
        """
        if self.is_syncing and not self._is_using_DP_or_DDP:
            self.actor_model_server.pull(self.actor)
        return safe_call(self.actor, state)

    def _eval_act(self, state: Dict[str, Any], action: Dict[str, Any], *_, **__):
        """
        Use actor network to evaluate the log-likelihood of a given
        action in the current state.

        Returns:
            Anything produced by actor.
        """
        return safe_call(self.actor, state, action)

    def _criticize(self, state: Dict[str, Any], *_, **__):
        """
        Use critic network to evaluate current value.

        Returns:
            Value of shape ``[batch_size, 1]``
        """
        return safe_call(self.critic, state)[0]

    def store_episode(self, episode: List[Union[Transition, Dict]]):
        """
        Add a full episode of transition samples to the replay buffer.

        Each transition in ``episode`` gets an ``episode_length`` attribute,
        set in place: the first transition records the episode length, the
        others record 0. ``update`` uses it to find episode boundaries.

        Raises:
            ValueError: If ``episode`` is empty.
        """
        if len(episode) == 0:
            raise ValueError("Episode must be non-empty.")
        # The first step records episode length, other steps records 0
        episode[0]["episode_length"] = len(episode)
        for transition in episode[1:]:
            transition["episode_length"] = 0

        self.replay_buffer.store_episode(
            episode,
            required_attrs=(
                "state",
                "action",
                "next_state",
                "reward",
                "action_log_prob",
                "terminal",
                "episode_length",
            ),
        )

    @staticmethod
    def _vtrace_targets(value, next_value, reward, c, rho, discount, episode_lengths):
        """
        Compute V-trace value targets for chained episodes.

        With ``delta_s = rho_s * (r_s + discount * V(x_{s+1}) - V(x_s))``,
        every step ``s`` of an episode gets::

            v_s = V(x_s) + delta_s + discount * c_s * (v_{s+1} - V(x_{s+1}))

        where ``v_{s+1} - V(x_{s+1})`` is 0 past the last step of an episode.
        The last step therefore gets ``v_s = V(x_s) + delta_s``.

        Inputs are expected to carry no autograd history (call under
        ``torch.no_grad()`` with detached values).

        Args:
            value: ``V(x_s)`` of every chained step, ``[total_length, 1]``.
            next_value: Bootstrap values ``V(x_{s+1})``, ``[total_length, 1]``,
                already zeroed on terminal steps.
            reward: Per-step rewards, ``[total_length, 1]``.
            c: Clipped importance weights used for trace cutting.
            rho: Clipped importance weights used for the temporal difference.
            discount: Discount factor.
            episode_lengths: Lengths (ints) of the chained episodes, in order.

        Returns:
            ``(vs, vs_next)``. ``vs`` holds the V-trace targets. ``vs_next``
            holds the targets shifted left by one step inside each episode;
            for the last step of an episode it holds ``next_value``, i.e. the
            bootstrap value.
        """
        delta_v = rho * (reward + discount * next_value - value)
        vs = t.zeros_like(value)
        vs_next = t.zeros_like(value)
        offset = 0
        for ep_len in episode_lengths:
            end = offset + ep_len
            # acc holds v_{s+1} - V(x_{s+1}); it is zero beyond the episode end.
            acc = t.zeros_like(value[0])
            for idx in reversed(range(offset, end)):
                acc = delta_v[idx] + discount * c[idx] * acc
                vs[idx] = value[idx] + acc
            vs_next[offset : end - 1] = vs[offset + 1 : end]
            vs_next[end - 1] = next_value[end - 1]
            offset = end
        return vs, vs_next

    def update(self, update_value=True, update_policy=True, **__):
        """
        Update network weights by sampling from replay buffer.

        Note:
            Will always concatenate samples.

        Args:
            update_value: Whether to update the critic network.
            update_policy: Whether to update the actor network.

        Returns:
            The negated policy loss, summed over all sampled steps (it
            includes the entropy term when ``entropy_weight`` is set), and
            the value loss scaled by ``value_weight``.
        """
        self.actor.train()
        self.critic.train()

        # `batch_size` here means the number of episodes sampled, not the
        # number of steps. The batch dimension of every sampled attribute is
        # the summed length of the sampled episodes,
        # eg: total_length = ep1_length + ep2_length + ...
        (
            _sampled_episodes,
            (
                state,
                action,
                reward,
                next_state,
                terminal,
                action_log_prob,
                episode_length,
            ),
        ) = self.replay_buffer.sample_batch(
            self.batch_size,
            device="cpu",
            sample_attrs=[
                "state",
                "action",
                "reward",
                "next_state",
                "terminal",
                "action_log_prob",
                "episode_length",
            ],
            additional_concat_custom_attrs=["action_log_prob"],
        )

        # Episodes are chained together like:
        # ep1_step1, ..., ep1_stepN, ep2_step1, ep2_step2 ...
        # Only the first step of each episode has a non-zero
        # "episode_length" (see store_episode), which gives the boundaries.
        all_length = [int(length) for length in episode_length if length != 0]
        sum_length = sum(all_length)
        if sum_length != terminal.shape[0]:
            raise RuntimeError(
                "Sum length is unequal to tensor total length,"
                " an unknown error has occurred."
            )

        # Force per-step tensors to [total_length, 1] so they combine
        # element-wise with the network outputs, and use a bool mask.
        reward = reward.reshape(sum_length, 1)
        terminal = terminal.reshape(sum_length, 1).to(t.bool)
        action_log_prob = action_log_prob.reshape(sum_length, 1)

        _, cur_action_log_prob, entropy, *_rest = self._eval_act(state, action)
        cur_action_log_prob = cur_action_log_prob.view(sum_length, 1).to("cpu")
        entropy = entropy.view(sum_length, 1).to("cpu")
        value = self._criticize(state).view(sum_length, 1).to("cpu")

        # Below is the V-trace process; none of it needs gradients.
        with t.no_grad():
            # similarity = pi(a_t|x_t) / mu(a_t|x_t), then clipped
            sim = t.exp(cur_action_log_prob - action_log_prob)
            c = t.clamp(sim, max=self.isw_clip_c)
            rho = t.clamp(sim, max=self.isw_clip_rho)
            # Bootstrap values V(x_{t+1}), zero after terminal steps.
            next_value = self._criticize(next_state).view(sum_length, 1).to("cpu")
            next_value = next_value.masked_fill(terminal, 0)
            vs, vs_next = self._vtrace_targets(
                value.detach(),
                next_value,
                reward,
                c,
                rho,
                self.discount,
                all_length,
            )

        # Policy gradient weighted by clipped rho, with the V-trace
        # advantage r_t + gamma * v_{t+1} - V(x_t) treated as a constant.
        act_policy_loss = -(
            rho
            * cur_action_log_prob
            * (reward + self.discount * vs_next - value.detach())
        )
        if self.entropy_weight is not None:
            act_policy_loss = act_policy_loss + self.entropy_weight * entropy
        act_policy_loss = act_policy_loss.sum()

        value_loss = self.criterion(value, vs.type_as(value)) * self.value_weight

        # Update actor network
        if update_policy:
            self.actor.zero_grad()
            self._backward(act_policy_loss)
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.grad_max)
            self.actor_optim.step()

        # Update critic network
        if update_value:
            self.critic.zero_grad()
            self._backward(value_loss)
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_max)
            self.critic_optim.step()

        # push actor model for samplers
        if self._is_using_DP_or_DDP:
            self.actor_model_server.push(self.actor.module, pull_on_fail=False)
        else:
            self.actor_model_server.push(self.actor)

        self.actor.eval()
        self.critic.eval()
        return -act_policy_loss.item(), value_loss.item()

    def update_lr_scheduler(self):
        """
        Update learning rate schedulers.
        """
        if hasattr(self, "actor_lr_sch"):
            self.actor_lr_sch.step()
        if hasattr(self, "critic_lr_sch"):
            self.critic_lr_sch.step()

    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        """
        Return a deep copy of ``config`` completed with IMPALA defaults.

        Sets ``frame`` to ``"IMPALA"`` and overwrites ``batch_num``. Values
        already present in ``config["frame_config"]`` take precedence over
        the defaults.
        """
        default_values = {
            "learner_process_number": 1,
            "model_server_group_name": "impala_model_server",
            "model_server_members": "all",
            "impala_group_name": "impala",
            "impala_members": "all",
            "models": ["Actor", "Critic"],
            "model_args": ((), ()),
            "model_kwargs": ({}, {}),
            "optimizer": "Adam",
            "criterion": "MSELoss",
            "criterion_args": (),
            "criterion_kwargs": {},
            "lr_scheduler": None,
            "lr_scheduler_args": None,
            "lr_scheduler_kwargs": None,
            "batch_size": 5,
            "learning_rate": 0.001,
            "isw_clip_c": 1.0,
            "isw_clip_rho": 1.0,
            "entropy_weight": None,
            "value_weight": 0.5,
            "gradient_max": np.inf,
            "discount": 0.99,
            "replay_size": 500,
        }
        config = deepcopy(config)
        config["frame"] = "IMPALA"
        config["batch_num"] = {"sampler": 10, "learner": 1}
        if "frame_config" not in config:
            config["frame_config"] = default_values
        else:
            # Later keys win in a dict merge: user values must come last so
            # they override the defaults rather than being overwritten.
            config["frame_config"] = {**default_values, **config["frame_config"]}
        return config

    @classmethod
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        """
        Build an IMPALA frame from ``config["frame_config"]``.

        Processes with ``world.rank < learner_process_number`` become
        learners, with models wrapped in ``DistributedDataParallel``. All
        other processes become samplers whose ``update`` is a no-op.
        """
        world = get_world()
        f_config = deepcopy(config["frame_config"])
        impala_group = world.create_rpc_group(
            group_name=f_config["impala_group_name"],
            members=(
                world.get_members()
                if f_config["impala_members"] == "all"
                else f_config["impala_members"]
            ),
        )

        models = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg).to(model_device)
            for m, arg, kwarg in zip(models, model_args, model_kwargs)
        ]
        # wrap models in DistributedDataParallel when running in learner mode
        max_learner_id = f_config["learner_process_number"]

        learner_group = world.create_collective_group(ranks=list(range(max_learner_id)))

        if world.rank < max_learner_id:
            models = [
                DistributedDataParallel(module=m, process_group=learner_group.group)
                for m in models
            ]

        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        criterion = assert_and_get_valid_criterion(f_config["criterion"])(
            *f_config["criterion_args"], **f_config["criterion_kwargs"]
        )
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        servers = model_server_helper(
            model_num=1,
            group_name=f_config["model_server_group_name"],
            members=f_config["model_server_members"],
        )
        del f_config["optimizer"]
        del f_config["criterion"]
        del f_config["lr_scheduler"]
        frame = cls(
            *models,
            optimizer,
            criterion,
            impala_group,
            servers,
            lr_scheduler=lr_scheduler,
            **f_config,
        )
        if world.rank >= max_learner_id:
            frame.role = "sampler"
            frame.update = _disable_update
        else:
            frame.role = "learner"
        return frame