from typing import Union, Dict, List, Tuple, Any, Callable
from ..transition import (
    TransitionBase,
    Transition,
    Scalar,
)
from .storage import TransitionStorageBase, TransitionStorageBasic
import torch as t
import random
class Buffer:
    def __init__(
        self,
        buffer_size: int = 1000000,
        buffer_device: Union[str, t.device] = "cpu",
        storage: TransitionStorageBase = None,
        **__,
    ):
"""
Create a buffer instance.
Buffer stores a series of transition objects and functions
as a ring buffer. **It is not thread-safe**.
See Also:
:class:`.Transition`
Args:
buffer_size: Maximum buffer size.
buffer_device: Device where buffer is stored.
storage: Custom storage, not compatible with `buffer_size` and
`buffer_device`.
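
        Example:
            A minimal construction sketch (the size and device values here
            are illustrative, not required)::

                buffer = Buffer(buffer_size=10000, buffer_device="cpu")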
"""
        self.storage = (
            TransitionStorageBasic(buffer_size, buffer_device)
            if storage is None
            else storage
        )
        self.transition_episode_number = {}  # type: Dict[Any, int]
        self.episode_transition_handles = {}  # type: Dict[int, List[Any]]
        self.episode_counter = 0
    def store_episode(
        self,
        episode: List[Union[TransitionBase, Dict]],
        required_attrs=("state", "action", "next_state", "reward", "terminal"),
    ):
"""
Store an episode to the buffer.
Note:
If you pass in a dict type transition object, it will be automatically
converted to ``Transition``, which requires attributes "state", "action"
"next_state", "reward" and "terminal" to be present in the dict keys.
Args:
episode: A list of transition objects.
required_attrs: Required attributes. Could be an empty tuple if
no attribute is required.
Raises:
``ValueError`` if episode is empty.
``ValueError`` if any transition object in the episode doesn't have
required attributes in ``required_attrs``.
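
        Example:
            A minimal sketch of storing a single-step episode as a dict
            transition (the tensor shapes and inner key names are
            illustrative assumptions)::

                buffer.store_episode([{
                    "state": {"state": t.zeros(1, 4)},
                    "action": {"action": t.zeros(1, 1)},
                    "next_state": {"state": t.zeros(1, 4)},
                    "reward": 0.0,
                    "terminal": True,
                }])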
"""
        if len(episode) == 0:
            raise ValueError("Episode must be non-empty.")
        episode_number = self.episode_counter
        self.episode_counter += 1

        for idx, transition in enumerate(episode):
            if isinstance(transition, dict):
                transition = Transition(**transition)
            elif isinstance(transition, TransitionBase):
                pass
            else:  # pragma: no cover
                raise ValueError(
                    "Transition object must be a dict or an instance"
                    " of the Transition class."
                )
            if not transition.has_keys(required_attrs):
                missing_keys = set(required_attrs) - set(transition.keys())
                raise ValueError(
                    f"Transition object missing attributes: {missing_keys}, "
                    f"object is {transition}."
                )
            episode[idx] = transition

        # update episode version record
        handles = self.storage.store_episode(episode)
        for handle in handles:
            try:
                old_episode = self.transition_episode_number[handle]
            except (KeyError, IndexError):
                old_episode = None
            # evict the old episode that previously occupied this position
            if old_episode is not None:
                for old_position in self.episode_transition_handles[old_episode]:
                    self.transition_episode_number.pop(old_position)
                self.episode_transition_handles.pop(old_episode)
            self.transition_episode_number[handle] = episode_number
        self.episode_transition_handles[episode_number] = handles
    def size(self):
        """
        Returns:
            Length of current buffer.
        """
        return len(self.storage)
    def clear(self):
        """
        Remove all entries from the buffer.
        """
        self.storage.clear()
        # also drop episode records so stale handles cannot be sampled
        self.transition_episode_number.clear()
        self.episode_transition_handles.clear()
    def sample_batch(
        self,
        batch_size: int,
        concatenate: bool = True,
        device: Union[str, t.device] = "cpu",
        sample_method: Union[
            Callable[["Buffer", int], Tuple[int, List[Any]]], str
        ] = "random_unique",
        sample_attrs: List[str] = None,
        additional_concat_custom_attrs: List[str] = None,
        *_,
        **__,
    ) -> Tuple[int, Union[None, tuple]]:
"""
Sample a random batch from buffer, and perform concatenation.
See Also:
Default sample methods are defined as instance methods.
:meth:`.Buffer.sample_method_random_unique`
:meth:`.Buffer.sample_method_random`
:meth:`.Buffer.sample_method_all`
Note:
"Concatenation" means ``torch.cat([list of tensors], dim=0)`` for tensors,
and ``torch.tensor([list of scalars]).view(batch_size, 1)`` for scalars.
By default, only major and sub attributes will be concatenated, in order to
concatenate custom attributes, specify their names in
`additional_concat_custom_attrs`.
Warnings:
Custom attributes must not contain tensors. And only scalar custom
attributes can be concatenated, such as ``int``, ``float``,
``bool``.
Args:
batch_size: A hint size of the result sample. actual sample size
depends on your sample method.
sample_method: Sample method, could be one of:
``"random", "random_unique", "all"``,
or a function:
``func(buffer, batch_size)->(list, result_size)``
concatenate: Whether perform concatenation on major, sub and custom
attributes.
If ``True``, for each value in dictionaries of major
attributes. and each value of sub attributes, returns
a concatenated tensor. Custom Attributes specified in
``additional_concat_custom_attrs`` will also be concatenated.
If ``False``, performs no concatenation.
device: Device to move tensors in the batch to.
sample_attrs: If sample_keys is specified, then only specified keys
of the transition object will be sampled. You may use
``"*"`` as a wildcard to collect remaining
**custom keys** as a ``dict``, you cannot collect major
and sub attributes using this.
Invalid sample attributes will be ignored.
additional_concat_custom_attrs: additional **custom keys** needed to be
concatenated, will only work if ``concatenate`` is
``True``.
Returns:
1. Batch size, Sampled attribute values in the same order as
``sample_keys``.
2. Sampled attribute values is a tuple. Or ``None`` if sampled
batch size is zero (E.g.: if buffer is empty or your sample
size is 0 and you are not sampling using the "all" method).
- For major attributes, result are dictionaries of tensors with
the same keys in your transition objects.
- For sub attributes, result are tensors.
- For custom attributes, if they are not in
``additional_concat_custom_attrs``, then lists, otherwise tensors.
- For wildcard selector, result is a dictionary containing unused custom
attributes, if they are not in ``additional_concat_custom_attrs``,
the values are lists, otherwise values are tensors.
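
        Example:
            A minimal sampling sketch (assumes transitions like the ones in
            :meth:`.Buffer.store_episode`; names and shapes are
            illustrative)::

                bsize, (state, action, reward) = buffer.sample_batch(
                    batch_size=32,
                    sample_attrs=["state", "action", "reward"],
                )
                # state["state"] has shape [bsize, 4],
                # reward has shape [bsize, 1]

            A custom sample method is any callable returning
            ``(result_size, list_of_transitions)``; for example, a sketch
            that samples the most recently stored handles (approximate for a
            ring buffer, since reused positions keep their old dict order)::

                def sample_latest(buffer, batch_size):
                    # take up to batch_size most recently inserted handles
                    handles = list(buffer.transition_episode_number.keys())
                    chosen = handles[-batch_size:] if batch_size > 0 else []
                    return len(chosen), [buffer.storage[h] for h in chosen]

                bsize, batch = buffer.sample_batch(8, sample_method=sample_latest)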
"""
        if isinstance(sample_method, str):
            if not hasattr(self, "sample_method_" + sample_method):
                raise RuntimeError(
                    f"Cannot find specified sample method: {sample_method}"
                )
            sample_method = getattr(self, "sample_method_" + sample_method)
            batch_size, batch = sample_method(batch_size)
        else:
            batch_size, batch = sample_method(self, batch_size)
        return (
            batch_size,
            self.post_process_batch(
                batch, device, concatenate, sample_attrs, additional_concat_custom_attrs
            ),
        )
    def sample_method_random_unique(
        self, batch_size: int
    ) -> Tuple[int, List[Transition]]:
        """
        Sample unique random samples from the buffer (without replacement).

        Note:
            Sampled size could be any value from 0 to ``batch_size``.
        """
        batch_size = min(len(self.transition_episode_number), batch_size)
        batch_handles = random.sample(
            list(self.transition_episode_number.keys()), k=batch_size
        )
        batch = [self.storage[bh] for bh in batch_handles]
        return batch_size, batch
    def sample_method_random(self, batch_size: int) -> Tuple[int, List[Transition]]:
        """
        Sample random samples from the buffer (with replacement).

        Note:
            Sampled size could be any value from 0 to ``batch_size``.
        """
        batch_size = min(len(self.transition_episode_number), batch_size)
        batch_handles = random.choices(
            list(self.transition_episode_number.keys()), k=batch_size
        )
        batch = [self.storage[bh] for bh in batch_handles]
        return batch_size, batch
    def sample_method_all(self, _) -> Tuple[int, List[Transition]]:
        """
        Sample all samples from the buffer; the ``batch_size`` parameter is
        ignored.
        """
        batch = [self.storage[bh] for bh in self.transition_episode_number.keys()]
        return len(self.transition_episode_number), batch
    def post_process_batch(
        self,
        batch: List[Transition],
        device: Union[str, t.device],
        concatenate: bool,
        sample_attrs: List[str],
        additional_concat_custom_attrs: List[str],
    ):
        """
        Post-process a sampled batch.

        Each sampled attribute goes through the pipeline
        :meth:`.Buffer.pre_process_attribute` ->
        :meth:`.Buffer.make_tensor_from_batch` ->
        :meth:`.Buffer.post_process_attribute`.
        """
        result = []
        used_keys = []

        if len(batch) == 0:
            return None
        if sample_attrs is None:
            sample_attrs = batch[0].keys() if batch else []
        if additional_concat_custom_attrs is None:
            additional_concat_custom_attrs = []

        major_attr = set(batch[0].major_attr)
        sub_attr = set(batch[0].sub_attr)
        custom_attr = set(batch[0].custom_attr)
        for attr in sample_attrs:
            if attr in major_attr:
                tmp_dict = {}
                for sub_k in batch[0][attr].keys():
                    tmp_dict[sub_k] = self.post_process_attribute(
                        attr,
                        sub_k,
                        self.make_tensor_from_batch(
                            self.pre_process_attribute(
                                attr,
                                sub_k,
                                [item[attr][sub_k].to(device) for item in batch],
                            ),
                            device,
                            concatenate,
                        ),
                    )
                result.append(tmp_dict)
                used_keys.append(attr)
            elif attr in sub_attr:
                result.append(
                    self.post_process_attribute(
                        attr,
                        None,
                        self.make_tensor_from_batch(
                            self.pre_process_attribute(
                                attr, None, [item[attr] for item in batch]
                            ),
                            device,
                            concatenate,
                        ),
                    )
                )
                used_keys.append(attr)
            elif attr in custom_attr:
                result.append(
                    self.post_process_attribute(
                        attr,
                        None,
                        self.make_tensor_from_batch(
                            self.pre_process_attribute(
                                attr, None, [item[attr] for item in batch]
                            ),
                            device,
                            concatenate and attr in additional_concat_custom_attrs,
                        ),
                    )
                )
                used_keys.append(attr)
            elif attr == "*":
                # select remaining custom keys
                tmp_dict = {}
                for remain_k in custom_attr:
                    if remain_k not in used_keys:
                        # pass the actual custom key (not the "*" wildcard)
                        # to the pre/post-process hooks
                        tmp_dict[remain_k] = self.post_process_attribute(
                            remain_k,
                            None,
                            self.make_tensor_from_batch(
                                self.pre_process_attribute(
                                    remain_k, None, [item[remain_k] for item in batch]
                                ),
                                device,
                                concatenate
                                and remain_k in additional_concat_custom_attrs,
                            ),
                        )
                        used_keys.append(remain_k)
                result.append(tmp_dict)
        return tuple(result)
    def pre_process_attribute(
        self, attribute: Any, sub_key: Any, values: List[Union[Scalar, t.Tensor]]
    ):
        """
        Pre-process attribute items. The method
        :meth:`.Buffer.make_tensor_from_batch` will use the result of this
        function, and assumes the processed attribute items to be one of:

        1. A list of tensors that are concatenable in dimension 0.
        2. A list of values that are transformable to a tensor.

        In case you want to implement custom padding for each item of an
        attribute, or some other custom pre-processing, please override this
        method.

        See Also:
            `This issue <https://github.com/iffiX/machin/issues/8>`_

        Args:
            attribute: Attribute key, such as "state", "next_state", etc.
            sub_key: Sub key in attribute if attribute is a major attribute,
                ``None`` if attribute is a sub attribute or a custom attribute.
            values: Sampled lists of attribute items.
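
        Example:
            A minimal sketch of a subclass that right-pads variable-length
            state tensors so they become concatenable (the attribute names
            and padding scheme are illustrative assumptions)::

                class PaddedBuffer(Buffer):
                    def pre_process_attribute(self, attribute, sub_key, values):
                        if attribute in ("state", "next_state"):
                            # pad each [1, length] tensor to the max length
                            max_len = max(v.shape[-1] for v in values)
                            values = [
                                t.nn.functional.pad(v, (0, max_len - v.shape[-1]))
                                for v in values
                            ]
                        return values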
"""
        return values
    def make_tensor_from_batch(
        self,
        batch: List[Union[Scalar, t.Tensor]],
        device: Union[str, t.device],
        concatenate: bool,
    ):
        """
        Make a tensor from a batch of data.
        Will concatenate input tensors in dimension 0,
        or create a tensor of shape ``(batch_size, 1)`` for scalars.

        Args:
            batch: Batch data.
            device: Device to move data to.
            concatenate: Whether to perform concatenation.

        Returns:
            The original batch if it is empty or if ``concatenate`` is
            ``False``; otherwise a tensor built from the batch data.
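
        Example:
            A sketch of the two concatenation cases (shapes are
            illustrative)::

                buffer.make_tensor_from_batch(
                    [t.zeros(1, 3), t.ones(1, 3)], "cpu", True
                )
                # -> tensor of shape [2, 3]
                buffer.make_tensor_from_batch([0.0, 1.0, 2.0], "cpu", True)
                # -> tensor([[0.], [1.], [2.]]) of shape [3, 1]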
"""
        if concatenate and len(batch) != 0:
            item = batch[0]
            batch_size = len(batch)
            if t.is_tensor(item):
                batch = [it.to(device) for it in batch]
                return t.cat(batch, dim=0).to(device)
            else:
                try:
                    return t.tensor(batch, device=device).view(batch_size, -1)
                except Exception as e:
                    raise ValueError(f"Batch not concatenable: {batch}") from e
        else:
            return batch
    def post_process_attribute(
        self,
        attribute: Any,
        sub_key: Any,
        values: Union[List[Union[Scalar, t.Tensor]], t.Tensor],
    ):
        """
        Post-process (possibly concatenated) attribute items. Values are the
        processed results from the method
        :meth:`.Buffer.make_tensor_from_batch`, either a list of
        unconcatenated values, or a concatenated tensor.

        Args:
            attribute: Attribute key, such as "state", "next_state", etc.
            sub_key: Sub key in attribute if attribute is a major attribute,
                ``None`` if attribute is a sub attribute or a custom attribute.
            values: Attribute items, possibly concatenated.
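
        Example:
            A minimal sketch of an override that flattens concatenated reward
            tensors (the attribute name is an illustrative assumption)::

                class FlatRewardBuffer(Buffer):
                    def post_process_attribute(self, attribute, sub_key, values):
                        if attribute == "reward" and t.is_tensor(values):
                            # flatten [batch_size, 1] rewards to [batch_size]
                            return values.flatten()
                        return values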
"""
        return values