# Source code for machin.frame.algorithms.impala

from typing import Union, Dict, List, Tuple, Callable, Any
import numpy as np
import torch as t
import torch.nn as nn

from machin.frame.buffers.buffer_d import Transition, DistributedBuffer
from machin.model.nets.base import NeuralNetworkModule
from .base import TorchFramework
from .utils import safe_call

from machin.parallel.server import PushPullModelServer
from machin.parallel.distributed import RpcGroup


def _make_tensor_from_batch(batch: List[Any], device, concatenate):
    """
    Used to convert compact every attribute of every step of a whole episode
    into a single tensor.

    Args:
        batch: A list of tensor or scalar. If elements are tensors, they will
            be concatenated in dimension 0.
        device: Device to move tensors to.
        concatenate: Whether to perform concatenation or not, only True
            for major and sub attributes in ``Transition``.

    Returns:
        A tensor if ``concatenate`` is ``True``, otherwise the original List.
    """
    if len(batch) == 0:
        return None
    if concatenate:
        item = batch[0]
        batch_size = len(batch)
        if t.is_tensor(item):
            return t.cat([it.to(device) for it in batch], dim=0).to(device)
        else:
            return t.tensor(batch, device=device).view(batch_size, -1)
    else:
        return batch


class EpisodeTransition(Transition):
    """
    Transition holding an entire episode in one object.

    Individual steps are stacked along the batch dimension, so the leading
    dimension of the stored attributes is the episode length instead of 1.
    """

    def _check_validity(self):
        """
        Validate the transition while skipping the batch size check that the
        base :class:`.Transition` performs.
        """
        # ``super(Transition, self)`` begins the MRO lookup *after*
        # ``Transition``, deliberately bypassing its own ``_check_validity``.
        parent_check = super(Transition, self)._check_validity
        parent_check()
class EpisodeDistributedBuffer(DistributedBuffer):
    """
    Distributed buffer whose entries are whole episodes, each one wrapped
    in a single :class:`EpisodeTransition`.
    """

    def append(self, transition: Dict,
               required_attrs=("state", "action", "next_state",
                               "reward", "terminal", "action_log_prob")):
        """
        Wrap an episode dict into an :class:`EpisodeTransition` and store it.

        Args:
            transition: Keyword arguments used to construct the
                :class:`EpisodeTransition`.
            required_attrs: Forwarded unchanged to
                :meth:`DistributedBuffer.append`.
        """
        episode_transition = EpisodeTransition(**transition)
        super().append(episode_transition, required_attrs=required_attrs)
class IMPALA(TorchFramework):
    """
    Massively parallel IMPALA framework.
    """

    _is_top = ["actor", "critic"]
    _is_restorable = ["actor", "critic"]

    def __init__(self,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 optimizer: Callable,
                 criterion: Callable,
                 impala_group: RpcGroup,
                 model_server: Tuple[PushPullModelServer],
                 *_,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple, Tuple] = (),
                 lr_scheduler_kwargs: Tuple[Dict, Dict] = (),
                 batch_size: int = 5,
                 learning_rate: float = 0.001,
                 isw_clip_c: float = 1.0,
                 isw_clip_rho: float = 1.0,
                 entropy_weight: float = None,
                 value_weight: float = 0.5,
                 gradient_max: float = np.inf,
                 discount: float = 0.99,
                 replay_size: int = 500,
                 visualize: bool = False,
                 **__):
        """
        Note:
            Please make sure isw_clip_rho >= isw_clip_c

        Args:
            actor: Actor network module.
            critic: Critic network module.
            optimizer: Optimizer used to optimize ``actor`` and ``critic``;
                one instance is created for each of them.
            criterion: Criterion used to evaluate the value loss.
            impala_group: Group of all processes using the IMPALA framework,
                including all samplers and trainers.
            model_server: Custom model sync server accessor for ``actor``;
                only the first element is used.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Positional arguments of the learning rate
                scheduler, as a pair ``(actor_args, critic_args)``. An empty
                value means no extra positional arguments.
            lr_scheduler_kwargs: Keyword arguments of the learning rate
                scheduler, as a pair ``(actor_kwargs, critic_kwargs)``. An
                empty value means no extra keyword arguments.
            batch_size: Number of episodes sampled per update.
            learning_rate: Learning rate of the optimizer, not compatible
                with ``lr_scheduler``.
            isw_clip_c: :math:`c` used in importance weight clipping.
            isw_clip_rho: :math:`\\rho` used in importance weight clipping.
            entropy_weight: Weight of entropy in your loss function, a
                positive entropy weight will minimize entropy, while a
                negative one will maximize entropy.
            value_weight: Weight of critic value loss.
            gradient_max: Maximum gradient norm used for gradient clipping.
            discount: :math:`\\gamma` used in the bellman function.
            replay_size: Size of the local replay buffer.
            visualize: Whether visualize the network flow in the first pass.
        """
        self.batch_size = batch_size
        self.discount = discount
        # NOTE(review): ``value_weight`` and ``visualize`` are stored but not
        # read by any method of this class; confirm whether the base class
        # or callers use them.
        self.value_weight = value_weight
        self.entropy_weight = entropy_weight
        self.grad_max = gradient_max
        self.isw_clip_c = isw_clip_c
        self.isw_clip_rho = isw_clip_rho
        self.visualize = visualize

        self.impala_group = impala_group

        self.actor = actor
        self.critic = critic
        self.actor_optim = optimizer(self.actor.parameters(),
                                     lr=learning_rate)
        self.critic_optim = optimizer(self.critic.parameters(),
                                      lr=learning_rate)

        # Every entry of this buffer is one whole episode.
        self.replay_buffer = EpisodeDistributedBuffer(
            buffer_name="buffer", group=impala_group,
            buffer_size=replay_size
        )

        self.is_syncing = True
        self.actor_model_server = model_server[0]

        if lr_scheduler is not None:
            # The declared defaults are empty tuples, which cannot be indexed
            # as (actor, critic) pairs; fall back to "no extra arguments".
            if not lr_scheduler_args:
                lr_scheduler_args = ((), ())
            if not lr_scheduler_kwargs:
                lr_scheduler_kwargs = ({}, {})
            self.actor_lr_sch = lr_scheduler(
                self.actor_optim,
                *lr_scheduler_args[0],
                **lr_scheduler_kwargs[0],
            )
            self.critic_lr_sch = lr_scheduler(
                self.critic_optim,
                *lr_scheduler_args[1],
                **lr_scheduler_kwargs[1]
            )

        self.criterion = criterion

        super(IMPALA, self).__init__()

    def set_sync(self, is_syncing):
        """
        Enable or disable pulling the actor from the model server in
        :meth:`act`.
        """
        self.is_syncing = is_syncing

    def manual_sync(self):
        """
        Pull the latest actor parameters from the model server now.
        """
        self.actor_model_server.pull(self.actor)

    def act(self, state: Dict[str, Any], *_, **__):
        """
        Use actor network to give a policy to the current state.

        The actor is pulled from the model server first if syncing is
        enabled (see :meth:`set_sync`).

        Returns:
            Anything produced by actor.
        """
        if self.is_syncing:
            self.actor_model_server.pull(self.actor)
        return safe_call(self.actor, state)

    def _eval_act(self,
                  state: Dict[str, Any],
                  action: Dict[str, Any],
                  *_, **__):
        """
        Use actor network to evaluate the log-likelihood of a given
        action in the current state.

        Returns:
            Anything produced by actor.
        """
        return safe_call(self.actor, state, action)

    def _criticize(self, state: Dict[str, Any], *_, **__):
        """
        Use critic network to evaluate current value.

        Returns:
            Value of shape ``[batch_size, 1]``
        """
        return safe_call(self.critic, state)[0]

    def store_transition(self, transition: Union[Transition, Dict]):
        """
        Warning:
            Not supported in IMPALA due to v-trace requirements.
        """
        raise NotImplementedError

    def store_episode(self, episode: List[Union[Transition, Dict]]):
        """
        Add a full episode of transition samples to the replay buffer.

        To compute v-trace every future reward of the episode must be
        visible, so the whole episode is compacted into a single entry whose
        attributes have shape ``[episode_length, ...]`` (other frameworks
        store one ``[1, ...]`` entry per step).
        """
        if not isinstance(episode[0], Transition):
            episode = [Transition(**trans) for trans in episode]

        device = self.replay_buffer.buffer_device
        cc_episode = {}
        for k, v in episode[0].items():
            if k in ("state", "action", "next_state"):
                # Major attributes are dicts of sub attributes; concatenate
                # each sub attribute across all steps.
                cc_episode[k] = {
                    sub_k: _make_tensor_from_batch(
                        [item[k][sub_k] for item in episode], device, True
                    )
                    for sub_k in v.keys()
                }
            elif k in ("reward", "terminal", "action_log_prob"):
                cc_episode[k] = _make_tensor_from_batch(
                    [item[k] for item in episode], device, True
                )
            # Other (custom) attributes are currently not supported and
            # are dropped.

        self.replay_buffer.append(cc_episode, required_attrs=(
            "state", "action", "next_state",
            "reward", "action_log_prob", "terminal"
        ))

    @staticmethod
    def _compute_vtrace(value, next_value, delta_v, c, discount, all_length):
        """
        Compute v-trace targets :math:`v_s` and their one-step shifts
        :math:`v_{s+1}` for a chain of consecutive episodes.

        Inside each episode the backward recursion
        :math:`v_s - V(x_s) = \\delta_s V + \\gamma c_s (v_{s+1} - V(x_{s+1}))`
        is used, with :math:`v_{s+1} - V(x_{s+1}) = 0` past the last stored
        step. The last step's target is therefore
        :math:`V(x_s) + \\delta_s V`, which bootstraps from ``next_value``.

        Args:
            value: :math:`V(x_s)`, shape ``[sum_length, 1]``.
            next_value: :math:`V(x_{s+1})`, already 0 on terminal steps,
                shape ``[sum_length, 1]``.
            delta_v: :math:`\\delta_s V`, shape ``[sum_length, 1]``.
            c: Truncated importance weights, shape ``[sum_length, 1]``.
            discount: :math:`\\gamma`.
            all_length: Length of each chained episode, in chain order.

        Returns:
            ``(vs, vs_next)``, both shaped like ``value`` and without
            gradient. ``vs_next`` is ``vs`` shifted left by one step inside
            each episode, and ``next_value`` at each episode's last step.
        """
        with t.no_grad():
            vs = t.zeros_like(value)
            vs_next = next_value.clone()
            offset = 0
            for ep_len in all_length:
                # acc holds v_s - V(x_s), accumulated from the episode end.
                acc = 0
                for idx in reversed(range(offset, offset + ep_len)):
                    acc = delta_v[idx] + discount * c[idx] * acc
                    vs[idx] = value[idx] + acc
                vs_next[offset: offset + ep_len - 1] = \
                    vs[offset + 1: offset + ep_len]
                offset += ep_len
        return vs, vs_next

    def update(self, update_value=True, update_policy=True, **__):
        """
        Update network weights by sampling from replay buffer.

        Note:
            Sampled episodes are always chained into one sequence of steps
            before the v-trace targets are computed.

        Args:
            update_value: Whether to update the critic (value) network.
            update_policy: Whether to update the actor network.

        Returns:
            Negated (summed) policy loss, value loss.
        """
        self.actor.train()
        self.critic.train()

        # Sample a batch. Each episode is stored as a single buffer entry,
        # so ``self.batch_size`` is the number of episodes sampled, not the
        # number of steps. ``concatenate`` is False because episodes may
        # have different lengths.
        _, (state, action, reward, next_state, terminal,
            action_log_prob) = self.replay_buffer.sample_batch(
            self.batch_size,
            concatenate=False,
            device="cpu",
            sample_attrs=["state", "action", "reward", "next_state",
                          "terminal", "action_log_prob"],
            additional_concat_attrs=["action_log_prob"],
        )

        # ``state``, ``action`` and ``next_state`` are expected to be dicts:
        #   {"attr1": [Tensor(ep1_length, ...), Tensor(ep2_length, ...)]}
        # ``terminal``, ``reward`` and ``action_log_prob`` lists:
        #   [Tensor(ep1_length, 1), Tensor(ep2_length, 1)]
        # Chain the steps of all episodes together:
        #   ep1_step1, ..., ep1_stepN, ep2_step1, ep2_step2, ...
        # and keep each episode's length to locate the boundaries.
        all_length = [tensor.shape[0] for tensor in terminal]
        sum_length = sum(all_length)

        for major_attr in (state, action, next_state):
            for k, v in major_attr.items():
                major_attr[k] = t.cat(v, dim=0)
                assert major_attr[k].shape[0] == sum_length

        terminal = t.cat(terminal, dim=0).view(sum_length, 1)
        reward = t.cat(reward, dim=0).view(sum_length, 1)
        action_log_prob = t.cat(action_log_prob, dim=0).view(sum_length, 1)

        # Log-likelihood and entropy of the stored actions under the
        # current policy (these keep their graph for the policy loss).
        _, cur_action_log_prob, entropy, *_rest = \
            self._eval_act(state, action)
        cur_action_log_prob = \
            cur_action_log_prob.view(sum_length, 1).to("cpu")
        entropy = entropy.view(sum_length, 1).to("cpu")

        # V(x_t), kept differentiable for the value loss.
        value = self._criticize(state).view(sum_length, 1).to("cpu")

        # Importance weights and v-trace targets are constants w.r.t. the
        # optimized parameters, so no graph is built for them.
        with t.no_grad():
            # similarity = pi(a_t|x_t) / mu(a_t|x_t), then truncated.
            sim = t.exp(cur_action_log_prob - action_log_prob)
            c = t.clamp(sim, max=self.isw_clip_c)
            rho = t.clamp(sim, max=self.isw_clip_rho)

            # V(x_{t+1}) is 0 on terminal steps. ``terminal`` is cast to
            # bool so numeric 0/1 flags act as a mask, not as indices.
            next_value = \
                self._criticize(next_state).view(sum_length, 1).to("cpu")
            next_value = next_value.masked_fill(terminal.to(t.bool), 0)

            # delta_t V = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t))
            delta_v = rho * (reward + self.discount * next_value - value)

            vs, vs_next = self._compute_vtrace(
                value, next_value, delta_v, c, self.discount, all_length
            )

        # Policy gradient weighted by rho with the v-trace advantage
        # r_t + gamma * v_{t+1} - V(x_t).
        advantage = (reward + self.discount * vs_next - value).detach()
        act_policy_loss = -(rho * cur_action_log_prob * advantage)

        if self.entropy_weight is not None:
            act_policy_loss = act_policy_loss + self.entropy_weight * entropy

        act_policy_loss = act_policy_loss.sum()

        value_loss = self.criterion(value, vs.to(value.device))

        # Update actor network
        if update_policy:
            self.actor.zero_grad()
            act_policy_loss.backward()
            nn.utils.clip_grad_norm_(
                self.actor.parameters(), self.grad_max
            )
            self.actor_optim.step()

        # Update critic network
        if update_value:
            self.critic.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(
                self.critic.parameters(), self.grad_max
            )
            self.critic_optim.step()

        # push actor model for samplers
        if isinstance(self.actor, (nn.parallel.DataParallel,
                                   nn.parallel.DistributedDataParallel)):
            self.actor_model_server.push(self.actor.module,
                                         pull_on_fail=False)
        else:
            self.actor_model_server.push(self.actor)

        self.actor.eval()
        self.critic.eval()
        return -act_policy_loss.item(), value_loss.item()

    def update_lr_scheduler(self):
        """
        Update learning rate schedulers.
        """
        if hasattr(self, "actor_lr_sch"):
            self.actor_lr_sch.step()
        if hasattr(self, "critic_lr_sch"):
            self.critic_lr_sch.step()