from typing import Union, Dict, List, Tuple, Any
from threading import RLock
from collections import OrderedDict
from ..transition import TransitionBase
from .prioritized_buffer import PrioritizedBuffer
from machin.parallel.distributed import RpcGroup
import numpy as np
import torch as t
class DistributedPrioritizedBuffer(PrioritizedBuffer):
    def __init__(
        self,
        buffer_name: str,
        group: RpcGroup,
        buffer_size: int = 1000000,
        epsilon: float = 1e-2,
        alpha: float = 0.6,
        beta: float = 0.4,
        beta_increment_per_sampling: float = 0.001,
        **kwargs
    ):
        """
        Create a distributed prioritized replay buffer instance.

        To avoid issues caused by tensor device difference, all transition
        objects are stored in device "cpu".

        Distributed prioritized replay buffer consists of many local buffers
        held per process. Since it is very inefficient to maintain a weight
        tree across processes, each process holds a local buffer (same as
        ``DistributedBuffer``) and a weight tree of the records in it.

        The sampling process first collects the weight sum of every local
        weight tree through rpc. It then randomly splits the batch size among
        the local buffers, so that the expected share of each buffer is
        proportional to its weight sum. Finally it asks each local buffer,
        through rpc, to sample its share and compute importance weights.
        Each rpc service holds the local ``wr_lock`` of the serving process
        for the duration of that call only, which blocks local appends in the
        meantime. The lock is not held between the weight sum step and the
        sampling step.

        During sampling, the tensors in "state", "action" and "next_state"
        dictionaries, along with "reward", will be concatenated in dimension 0.
        Any other custom attributes will not be concatenated, unless they are
        listed in ``additional_concat_custom_attrs`` of ``sample_batch``.

        .. seealso:: :class:`PrioritizedBuffer`

        Note:
            `DistributedPrioritizedBuffer` does not support customizing storage
            as it requires a linear storage.

        Note:
            :class:`DistributedPrioritizedBuffer` is not split into an
            accessor and an implementation, because we would like to operate
            on the buffer directly, when calling "size()" or "append()", to
            increase efficiency (since rpc layer is bypassed).

        Args:
            buffer_name: A unique name of your buffer for registration in the
                group.
            group: Process group which holds this buffer.
            buffer_size: Maximum local buffer size.
            epsilon: A small positive constant used to prevent edge-case
                zero weight transitions from never being visited.
            alpha: Prioritization weight. Used during transition sampling:
                :math:`j \\sim P(j)=p_{j}^{\\alpha} / \
                \\sum_i p_{i}^{\\alpha}`.
                When ``alpha = 0``, all samples have the same probability
                to be sampled.
                When ``alpha = 1``, samples are drawn with probability
                proportional to their priority.
            beta: Bias correcting weight. When ``beta = 1``, bias introduced
                by prioritized replay will be corrected. Used during
                importance weight calculation:
                :math:`w_j=(N \\cdot P(j))^{-\\beta}/max_i w_i`
            beta_increment_per_sampling:
                Beta increase step size, will gradually increase ``beta`` to 1.
        """
        super().__init__(
            buffer_size=buffer_size,
            buffer_device="cpu",
            epsilon=epsilon,
            alpha=alpha,
            beta=beta,
            beta_increment_per_sampling=beta_increment_per_sampling,
            **kwargs
        )
        self.buffer_name = buffer_name
        # One version counter per storage slot; bumped whenever a slot is
        # (re)written so stale priority updates can be detected and dropped.
        self.buffer_version_table = np.zeros([buffer_size], dtype=np.uint64)
        self.group = group
        # The lock must exist before any service is registered: every service
        # enters it, and a remote process may call a service as soon as it is
        # registered.
        self.wr_lock = RLock()
        assert group.is_member()
        # register services, so that we may access other buffers
        prefix = buffer_name + "/" + group.get_cur_name()
        for suffix, service in (
            ("/_size_service", self._size_service),
            ("/_clear_service", self._clear_service),
            ("/_weight_sum_service", self._weight_sum_service),
            ("/_update_priority_service", self._update_priority_service),
            ("/_sample_service", self._sample_service),
        ):
            self.group.register(prefix + suffix, service)

    def store_episode(
        self,
        episode: List[Union[TransitionBase, Dict]],
        priorities: Union[List[float], None] = None,
        required_attrs=("state", "action", "next_state", "reward", "terminal"),
    ):
        # DOC INHERITED
        # Validate up front: a length mismatch would otherwise be silently
        # truncated by ``zip`` and leave some transitions without a priority.
        if priorities is not None and len(priorities) != len(episode):
            raise ValueError(
                f"Length of priorities ({len(priorities)}) does not match "
                f"the episode length ({len(episode)})."
            )
        with self.wr_lock:
            # Bypass PrioritizedBuffer.store_episode (call the next class in
            # the MRO); leaf priorities and version counters are set below
            # while still holding the lock.
            super(PrioritizedBuffer, self).store_episode(episode, required_attrs)
            episode_number = self.episode_counter - 1
            positions = self.episode_transition_handles[episode_number]
            if priorities is None:
                # the initialization method used in the original essay: new
                # transitions receive the current maximum priority. It is read
                # once, so leaves written for this episode do not feed back
                # into the priority of its later transitions.
                initial = self._normalize_priority(self.wt_tree.get_leaf_max())
                leaf_priorities = [initial] * len(positions)
            else:
                leaf_priorities = [self._normalize_priority(p) for p in priorities]
            for leaf_priority, position in zip(leaf_priorities, positions):
                self.wt_tree.update_leaf(leaf_priority, position)
                # increase the version counter to mark it as tainted,
                # later priority update will ignore this position
                self.buffer_version_table[position] += 1

    def size(self):
        """
        Returns:
            Length of current local buffer.
        """
        with self.wr_lock:
            return super().size()

    def all_size(self):
        """
        Returns:
            Total length of all buffers.
        """
        futures = [
            self.group.registered_async(self.buffer_name + "/" + m + "/_size_service")
            for m in self.group.get_group_members()
        ]
        return sum(fut.wait() for fut in futures)

    def clear(self):
        """
        Remove all entries from current local buffer.
        """
        with self.wr_lock:
            super().clear()
            # also clear the version table
            self.buffer_version_table.fill(0)

    def all_clear(self):
        """
        Remove all entries from all local buffers.
        """
        futures = [
            self.group.registered_async(self.buffer_name + "/" + m + "/_clear_service")
            for m in self.group.get_group_members()
        ]
        for fut in futures:
            fut.wait()

    def update_priority(self, priorities: np.ndarray, indexes: OrderedDict):
        # DOC INHERITED
        # Update priority on all local buffers.
        # ``indexes`` is the OrderedDict returned by ``sample_batch``: the key
        # is the sampled process name, the value is a tuple of an index
        # np.ndarray and a version np.ndarray. ``priorities`` is laid out in
        # the same process order, so it is consumed in consecutive slices.
        futures = []
        offset = 0
        for member, (index, version) in indexes.items():
            length = len(index)
            futures.append(
                self.group.registered_async(
                    self.buffer_name + "/" + member + "/_update_priority_service",
                    args=(priorities[offset : offset + length], index, version),
                )
            )
            offset += length
        for fut in futures:
            fut.wait()

    def sample_batch(
        self,
        batch_size: int,
        concatenate: bool = True,
        device: Union[str, t.device] = None,
        sample_attrs: List[str] = None,
        additional_concat_custom_attrs: List[str] = None,
        *_,
        **__
    ) -> Tuple[
        int, Union[None, tuple], Union[None, Dict[str, Any]], Union[None, np.ndarray]
    ]:
        # DOC INHERITED
        if batch_size <= 0:
            return 0, None, None, None
        # Snapshot the member list once, so that weight sums, sample sizes
        # and sample requests all refer to the same processes in one order.
        members = list(self.group.get_group_members())
        # calculate all weight sum
        weight_futures = [
            self.group.registered_async(
                self.buffer_name + "/" + m + "/_weight_sum_service"
            )
            for m in members
        ]
        weights = [fut.wait() for fut in weight_futures]
        all_weight_sum = sum(weights) + 1e-6  # prevent all zero
        # Local sample sizes add up to exactly ``batch_size`` and are
        # proportional to local_weight_sum / all_weight_sum in expectation.
        local_sizes = self._split_batch_size(weights, batch_size)
        # collect samples and their priority
        sample_futures = [
            (
                m,
                self.group.registered_async(
                    self.buffer_name + "/" + m + "/_sample_service",
                    args=(ss, all_weight_sum),
                ),
            )
            for m, ss in zip(members, local_sizes)
            if ss > 0
        ]
        all_batch_len = 0
        all_batch = []
        all_index = OrderedDict()
        all_is_weight = []
        for m, fut in sample_futures:
            batch_len, batch, index, version, is_weight = fut.wait()
            if batch_len == 0:
                continue
            all_batch_len += batch_len
            all_batch += batch
            all_is_weight.append(is_weight)
            # store them together to make API compatible with PrioritizedBuffer
            all_index[m] = (index, version)
        if all_batch_len == 0:
            return 0, None, None, None
        all_batch = self.post_process_batch(
            all_batch, device, concatenate, sample_attrs, additional_concat_custom_attrs
        )
        all_is_weight = np.concatenate(all_is_weight, axis=0)
        return all_batch_len, all_batch, all_index, all_is_weight

    @staticmethod
    def _split_batch_size(weights: List[float], batch_size: int) -> List[int]:
        """
        Randomly split ``batch_size`` among local buffers.

        Counts are drawn from a multinomial distribution whose probabilities
        are the local weight sums divided by their total. The counts therefore
        add up to exactly ``batch_size``, and the expected count of each
        buffer is proportional to its weight sum.

        Args:
            weights: Weight sum of each local buffer, in member order.
            batch_size: Total number of samples to distribute.

        Returns:
            One sample count per entry of ``weights``. All counts are zero
            when every weight sum is zero (or ``weights`` is empty).
        """
        weight_arr = np.asarray(weights, dtype=np.float64)
        total = weight_arr.sum()
        if total <= 0:
            return [0] * len(weight_arr)
        counts = np.random.multinomial(batch_size, weight_arr / total)
        return [int(c) for c in counts]

    def _size_service(self):  # pragma: no cover
        with self.wr_lock:
            return super().size()

    def _clear_service(self):  # pragma: no cover
        with self.wr_lock:
            super().clear()
            # also clear the version table
            self.buffer_version_table.fill(0)

    def _weight_sum_service(self):  # pragma: no cover
        with self.wr_lock:
            return self.wt_tree.get_weight_sum()

    def _update_priority_service(
        self, priorities, indexes, versions
    ):  # pragma: no cover
        with self.wr_lock:
            # compare original entry versions to the current version table
            is_same = self.buffer_version_table[indexes] == versions
            # select unchanged entries; slots rewritten since sampling are
            # skipped so their fresh priority is not overwritten
            priorities = priorities[is_same]
            indexes = indexes[is_same]
            super().update_priority(priorities, indexes)

    def _sample_service(self, batch_size, all_weight_sum):  # pragma: no cover
        # ``batch_size`` here is the local share of the global batch
        with self.wr_lock:
            if batch_size <= 0 or len(self.storage) == 0:
                return 0, None, None, None, None
            index, is_weight = self.sample_index_and_weight(batch_size, all_weight_sum)
            # record slot versions so stale priority updates can be dropped
            version = self.buffer_version_table[index]
            batch = [self.storage[idx] for idx in index]
            return len(batch), batch, index, version, is_weight