Source code for machin.frame.transition

from typing import Union, Dict, Iterable, Any, NewType
from itertools import chain
import torch as t
import numpy as np

# Plain Python scalar types that may be stored directly as sub attribute data.
_ScalarSupertype = Union[int, float, bool]
Scalar = NewType("Scalar", _ScalarSupertype)

class TransitionBase:
    """
    Base class for all transitions
    """

    # Becomes True at the end of ``__init__``. Before that, attribute
    # assignment is unrestricted and unchecked.
    _inited = False

    def __init__(
        self,
        major_attr: Iterable[str],
        sub_attr: Iterable[str],
        custom_attr: Iterable[str],
        major_data: Iterable[Dict[str, t.Tensor]],
        sub_data: Iterable[Union[Scalar, t.Tensor]],
        custom_data: Iterable[Any],
    ):
        """
        Note:
            Major attributes store things like state, action, next_states, etc.
            They are usually **concatenated by their dictionary keys** during
            sampling, and passed as keyword arguments to actors, critics, etc.

            Sub attributes store things like terminal states, reward, etc.
            They are usually **concatenated directly** during sampling, and used
            in different algorithms.

            Custom attributes store not concatenatable values, usually user
            specified states, used in models or as special arguments in
            different algorithms. They will be collected together as a list
            during sampling, **no further concatenation is performed**.

        Args:
            major_attr: A list of major attribute names.
            sub_attr: A list of sub attribute names.
            custom_attr: A list of custom attribute names.
            major_data: Data of major attributes.
            sub_data: Data of sub attributes.
            custom_data: Data of custom attributes.

        Raises:
            ``ValueError`` if the resulting transition fails
            ``_check_validity``.
        """
        self._major_attr = list(major_attr)
        self._sub_attr = list(sub_attr)
        self._custom_attr = list(custom_attr)
        self._keys = self._major_attr + self._sub_attr + self._custom_attr
        self._length = len(self._keys)
        self._batch_size = None

        # Pair the data with the materialized name lists, not the raw
        # ``*_attr`` arguments. Those may be one-shot iterators that the
        # ``list()`` calls above have already exhausted.
        for attr, data in zip(self._keys, chain(major_data, sub_data, custom_data)):
            object.__setattr__(self, attr, data)

        # Routed through ``__setattr__``. Once ``_inited`` is True, that
        # method runs ``_check_validity``.
        self._inited = True
        self._detach()

    def __len__(self):
        """Number of attributes (major + sub + custom)."""
        return self._length

    def __getitem__(self, item):
        """Return the attribute named ``item``."""
        return getattr(self, item)

    def _ensure_known_key(self, key):
        """Raise ``RuntimeError`` if ``key`` is not a declared attribute name."""
        if key not in self._keys:
            raise RuntimeError(
                "You cannot dynamically set new attributes in "
                "a Transition object!"
            )

    def __setitem__(self, key, value):
        """
        Replace an existing attribute and re-validate the transition.

        Raises:
            ``RuntimeError`` if ``key`` is not an existing attribute name.
            ``ValueError`` if the new value makes the transition invalid.
        """
        self._ensure_known_key(key)
        object.__setattr__(self, key, value)
        self._check_validity()

    def __setattr__(self, key, value):
        """
        During ``__init__``, any attribute may be set. Afterwards only
        existing attribute names may be assigned, and every assignment is
        followed by a validity check.

        Raises:
            ``RuntimeError`` if setting a new attribute after init.
            ``ValueError`` if the new value makes the transition invalid.
        """
        if self._inited:
            self._ensure_known_key(key)
        object.__setattr__(self, key, value)
        # Also reached by ``self._inited = True`` at the end of ``__init__``,
        # which triggers the initial validity check.
        if self._inited:
            self._check_validity()

    @property
    def major_attr(self):
        return self._major_attr

    @property
    def sub_attr(self):
        return self._sub_attr

    @property
    def custom_attr(self):
        return self._custom_attr

    def keys(self):
        """
        Returns:
            All attribute names in current transition object.
            Ordered in: "major_attrs, sub_attrs, custom_attrs"
        """
        return self._keys

    def items(self):
        """
        Returns:
            All attribute values in current transition object.
        """
        for k in self._keys:
            yield k, getattr(self, k)

    def has_keys(self, keys: Iterable[str]):
        """
        Args:
            keys: A list of keys

        Returns:
            A bool indicating whether current transition object
            contains all specified keys.
        """
        return all(k in self._keys for k in keys)

    def to(self, device: Union[str, t.device]):
        """
        Move current transition object to another device. Tensors already
        on that device are left as they are.

        Major attribute dicts are updated in place. Tensor sub attributes
        are replaced. Custom attributes are not moved.

        Args:
            device: A valid pytorch device.

        Returns:
            Self.
        """
        for ma in self._major_attr:
            ma_data = getattr(self, ma)
            for k, v in ma_data.items():
                ma_data[k] = v.to(device)
        for sa in self._sub_attr:
            sa_data = getattr(self, sa)
            if t.is_tensor(sa_data):
                # Bypass __setattr__: moving devices does not change
                # validity, so re-checking is unnecessary.
                object.__setattr__(self, sa, sa_data.to(device))
        return self

    def _detach(self):
        """
        Detach all tensors in major attributes and sub attributes.

        Major attribute dicts are updated in place. Tensor sub attributes
        are replaced by their detached versions. Custom attributes are
        left untouched. No data is copied.

        Returns:
            Self.
        """
        for ma in self._major_attr:
            ma_data = getattr(self, ma)
            for k, v in ma_data.items():
                ma_data[k] = v.detach()
        for sa in self._sub_attr:
            sa_data = getattr(self, sa)
            if t.is_tensor(sa_data):
                object.__setattr__(self, sa, sa_data.detach())
        return self

    def _check_validity(self):
        """
        Check validity of current transition object, will check batch size,
        and major attributes' data, sub attributes' data.

        The batch size is the first dimension of the major attribute tensors.
        All major tensors and all tensor sub attributes must share it.
        Scalar sub attributes require it to be 1. If there are no major
        attributes, the batch size stays ``None`` and any sub attribute
        fails the check.

        Raises:
            ``ValueError`` if anything is invalid.
        """
        batch_size = None
        for ma in self._major_attr:
            ma_data = getattr(self, ma)
            for k, v in ma_data.items():
                if not t.is_tensor(v) or v.dim() < 1:
                    raise ValueError(
                        f'Key "{k}" of transition major attribute "{ma}" '
                        "is an invalid tensor"
                    )
                if batch_size is None:
                    batch_size = v.shape[0]
                elif batch_size != v.shape[0]:
                    raise ValueError(
                        f'Key "{k}" of transition major attribute "{ma}" '
                        f"has invalid batch size {v.shape[0]}."
                    )
        for sa in self._sub_attr:
            sa_data = getattr(self, sa)
            if np.isscalar(sa_data):
                # np.isscalar also returns True for the built-in scalar
                # types int, bool and float.
                if batch_size != 1:
                    raise ValueError(
                        "Transition sub attribute "
                        f'"{sa}" is a scalar, but batch size is {batch_size}.'
                    )
            elif t.is_tensor(sa_data):
                if sa_data.dim() < 1:
                    raise ValueError(
                        f'Transition sub attribute "{sa}" is an invalid tensor.'
                    )
                elif sa_data.shape[0] != batch_size:
                    raise ValueError(
                        "Transition sub attribute "
                        f'"{sa}" has invalid batch size {sa_data.shape[0]}.'
                    )
            else:
                raise ValueError(
                    f'Transition sub attribute "{sa}" has invalid '
                    f"value {sa_data}, requires scalar or tensor."
                )
        # Bypass __setattr__: "_batch_size" is not a declared attribute key.
        object.__setattr__(self, "_batch_size", batch_size)
class Transition(TransitionBase):
    """
    The default Transition class.

    Has three major attributes: ``state``, ``action`` and ``next_state``.

    Has two sub attributes: ``reward`` and ``terminal``.

    Stores one transition step of one agent.
    """

    # Class-level placeholders so IDEs can auto-suggest attribute names.
    state = None  # type: Dict[str, t.Tensor]
    action = None  # type: Dict[str, t.Tensor]
    next_state = None  # type: Dict[str, t.Tensor]
    reward = None  # type: Union[float, t.Tensor]
    terminal = None  # type: bool

    def __init__(
        self,
        state: Dict[str, t.Tensor],
        action: Dict[str, t.Tensor],
        next_state: Dict[str, t.Tensor],
        reward: Union[float, t.Tensor],
        terminal: bool,
        **kwargs,
    ):
        """
        Args:
            state: Previous observed state.
            action: Action of agent.
            next_state: Next observed state.
            reward: Reward of agent.
            terminal: Whether environment has reached terminal state.
                Either a python ``bool`` or a tensor of dtype ``torch.bool``.
            **kwargs: Custom attributes. They are ordered in the alphabetic
                order (provided by ``sort()``) when you call ``keys()``.

        Note:
            You should not store any tensor inside ``**kwargs`` as they will
            not be moved to the sample output device.
        """
        extra_names = sorted(kwargs)
        assert isinstance(terminal, bool) or (
            t.is_tensor(terminal) and terminal.dtype == t.bool
        )
        extra_values = [kwargs[name] for name in extra_names]
        super().__init__(
            major_attr=["state", "action", "next_state"],
            sub_attr=["reward", "terminal"],
            custom_attr=extra_names,
            major_data=[state, action, next_state],
            sub_data=[reward, terminal],
            custom_data=extra_values,
        )

    def _check_validity(self):
        """Run the base checks, then require a batch size of exactly 1."""
        super()._check_validity()
        if self._batch_size == 1:
            return
        raise ValueError(
            "Batch size of the default transition "
            f"implementation must be 1, is {self._batch_size}"
        )