Source code for machin.frame.buffers.buffer_d

from typing import Union, Dict, List, Tuple, Callable
from threading import RLock
from machin.parallel.distributed import RpcGroup
from ..transition import TransitionBase
from .buffer import Buffer
from .storage import TransitionStorageBase, TransitionStorageBasic

import torch as t
import numpy as np
import itertools as it

def _round_up(num):
    return int(np.ceil(num))

class DistributedBuffer(Buffer):
    def __init__(
        self,
        buffer_name: str,
        group: RpcGroup,
        buffer_size: int = 1000000,
        storage: TransitionStorageBase = None,
        **kwargs,
    ):
        """
        Create a distributed replay buffer instance.

        To avoid issues caused by tensor device difference, all transition
        objects are stored in device "cpu".

        A distributed replay buffer consists of many local buffers, one held
        per process; transmissions between processes only happen during
        sampling.

        During sampling, the tensors in "state", "action" and "next_state"
        dictionaries, along with "reward", will be concatenated in dimension 0.
        Any other custom keys specified in ``**kwargs`` will not be
        concatenated.

        .. seealso:: :class:`.Buffer`

        Note:
            `DistributedBuffer` does not support customizing storage device
            when using the default storage, since it's safer to pass cpu
            tensors between RPC callers and callees.

        Note:
            Since ``store_episode()`` operates on the local buffer, in order
            to store into the distributed buffer correctly, please make sure
            that your actor is also the local buffer holder, i.e. a member of
            the ``group``.

        Args:
            buffer_name: A unique name of your buffer for registration in the
                group.
            group: Process group which holds this buffer.
            buffer_size: Maximum local buffer size.
            storage: Custom storage, not compatible with `buffer_size` and
                `buffer_device`.
        """
        # Create the lock before anything else: once the services below are
        # registered, other processes may call them at any moment, and every
        # service path acquires ``self.wr_lock``.
        self.wr_lock = RLock()
        super().__init__(
            buffer_size=buffer_size, buffer_device="cpu", storage=storage, **kwargs
        )
        self.buffer_name = buffer_name
        self.group = group

        assert group.is_member()

        # Register services so that other members can access this local
        # buffer. Keys have the form "<buffer_name>/<member_name>/<service>".
        _name = "/" + group.get_cur_name()
        self.group.register(buffer_name + _name + "/_size_service", self._size_service)
        self.group.register(
            buffer_name + _name + "/_clear_service", self._clear_service
        )
        self.group.register(
            buffer_name + _name + "/_sample_service", self._sample_service
        )

    def store_episode(
        self,
        episode: List[Union[TransitionBase, Dict]],
        required_attrs=("state", "action", "next_state", "reward", "terminal"),
    ):
        # DOC INHERITED
        # Serialize local writes against concurrent remote sampling/clearing.
        with self.wr_lock:
            super().store_episode(episode, required_attrs=required_attrs)

    def clear(self):
        """
        Clear current local buffer.
        """
        with self.wr_lock:
            return super().clear()

    def all_clear(self):
        """
        Remove all entries from all local buffers.
        """
        # Issue every remote call first, then wait, so the members clear
        # concurrently rather than one after another.
        futures = [
            self.group.registered_async(self.buffer_name + "/" + m + "/_clear_service")
            for m in self.group.get_group_members()
        ]
        for fut in futures:
            fut.wait()

    def size(self):
        """
        Returns:
            Length of current local buffer.
        """
        with self.wr_lock:
            return super().size()

    def all_size(self):
        """
        Returns:
            Total length of all buffers.
        """
        # Launch all queries before waiting on any of them (the list is built
        # eagerly), then add up the per-member sizes.
        futures = [
            self.group.registered_async(self.buffer_name + "/" + m + "/_size_service")
            for m in self.group.get_group_members()
        ]
        return sum(fut.wait() for fut in futures)

    def sample_batch(
        self,
        batch_size: int,
        concatenate: bool = True,
        device: Union[str, t.device] = "cpu",
        sample_method: Union[Callable, str] = "random_unique",
        sample_attrs: List[str] = None,
        additional_concat_custom_attrs: List[str] = None,
        *_,
        **__,
    ) -> Tuple[int, Union[None, tuple]]:
        # DOC INHERITED
        # Each member is asked for ceil(batch_size / group size) samples, so
        # the gathered total can exceed ``batch_size`` by up to
        # (group size - 1); members may also return fewer, depending on
        # ``sample_method`` and their local size.
        p_num = self.group.size()
        local_batch_size = _round_up(batch_size / p_num)

        # NOTE(review): ``sample_method`` is shipped to every member over RPC,
        # so a callable presumably needs to be picklable — confirm.
        futures = [
            self.group.registered_async(
                self.buffer_name + "/" + m + "/_sample_service",
                args=(local_batch_size, sample_method),
            )
            for m in self.group.get_group_members()
        ]

        results = [fut.wait() for fut in futures]
        all_batch_size = sum(r[0] for r in results)
        all_batch = list(it.chain.from_iterable(r[1] for r in results))

        if not all_batch:
            # Every local buffer returned nothing; ``all_batch[0]`` below
            # would raise IndexError. Report an empty result instead, as the
            # return annotation allows.
            return 0, None

        if sample_attrs is None:
            sample_attrs = all_batch[0].keys()
        if additional_concat_custom_attrs is None:
            additional_concat_custom_attrs = []
        return (
            all_batch_size,
            self.post_process_batch(
                all_batch,
                device,
                concatenate,
                sample_attrs,
                additional_concat_custom_attrs,
            ),
        )

    def _size_service(self):  # pragma: no cover
        # RPC entry point: report this member's local size.
        return self.size()

    def _clear_service(self):  # pragma: no cover
        # RPC entry point: clear this member's local buffer.
        self.clear()

    def _sample_service(self, batch_size, sample_method):  # pragma: no cover
        """
        RPC entry point: sample up to ``batch_size`` transitions locally.

        Args:
            batch_size: Number of samples requested from this member.
            sample_method: Either the suffix of a ``sample_method_<name>``
                method on this buffer, or a callable invoked as
                ``sample_method(buffer, batch_size)``.

        Returns:
            ``(local_batch_size, local_batch)`` as returned by the sample
            method, with transitions moved to "cpu" if a custom storage is
            used.

        Raises:
            RuntimeError: If a string ``sample_method`` does not name an
                existing ``sample_method_*`` method.
        """
        if isinstance(sample_method, str):
            if not hasattr(self, "sample_method_" + sample_method):
                raise RuntimeError(
                    f"Cannot find specified sample method: {sample_method}"
                )
            sample_method = getattr(self, "sample_method_" + sample_method)
            with self.wr_lock:
                local_batch_size, local_batch = sample_method(batch_size)
        else:
            with self.wr_lock:
                local_batch_size, local_batch = sample_method(self, batch_size)
        if not isinstance(self.storage, TransitionStorageBasic):
            # for safety: a custom storage may hold tensors on another device,
            # but only cpu tensors are passed between RPC callers and callees.
            local_batch = [transition.to("cpu") for transition in local_batch]
        return local_batch_size, local_batch