Data flow in machin¶
Author: Muhan Li
Data flow is the main thing you need to be careful with when using the Machin library, especially:
- Data types
- Tensor shapes
- How to correctly store transitions
- How to correctly update your model
If you are using distributed algorithms such as A3C, IMPALA, etc., you should additionally take care of:
- How to set up the distributed world correctly
- How to set up the distributed framework correctly
- How to synchronize and pass data between processes
In this tutorial, we are not going to cover the distributed part, and will focus on the data flow in single agent RL algorithms.
The big picture¶
To give you a general idea of the data flow model in single agent RL algorithms, we will take the DQN framework as an example and use a diagram to illustrate everything:
There are mainly three types of arrows in the diagram:
- The normal gray arrow: represents data passed to functions as arguments, keyword arguments, etc., and data returned by functions.
- The dashed gray arrow: the dashed gray arrow between the Q network and the target Q network means “soft_update”, which updates the parameters of the target Q network by interpolating between the target Q network and the online Q network.
- The circle gray arrows: there are two circle gray arrows:
  N_State —() QNet_target
  State —() QNet
  These two arrows are special network calls named “safe_call”. “safe_call” is an enhanced keyword-argument-like caller. It inspects the arguments of the “forward” function of your network and fills in sub-keys of major attributes like action, state, and next_state from the batched sample. Which major attributes “safe_call” uses depends on the RL framework and the model. If sub-keys defined in major attributes are present in the arguments of the “forward” function, the corresponding sub-values are used; otherwise they are ignored.
  “safe_call” also inspects the input/output device of your model, or tries to determine them automatically if they are not specified. The tensors stored as sub-values of these major attributes can therefore be moved to the correct target device.
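The “soft_update” arrow can be read as a simple interpolation. The following is a minimal sketch that uses plain Python floats in place of parameter tensors; the function name `soft_update_params` and the `update_rate` argument are illustrative assumptions, not Machin's actual API:

```python
from typing import List


def soft_update_params(target: List[float], online: List[float],
                       update_rate: float) -> List[float]:
    """Interpolate target parameters toward online parameters.

    Hypothetical helper for illustration only; a real framework would
    apply the same formula to every parameter tensor of both networks.
    """
    return [
        (1.0 - update_rate) * tgt + update_rate * onl
        for tgt, onl in zip(target, online)
    ]


target_q = [0.0, 1.0]
online_q = [1.0, 3.0]
# With update_rate=0.5, each target parameter moves halfway toward
# its online counterpart.
new_target_q = soft_update_params(target_q, online_q, update_rate=0.5)
```

A small `update_rate` keeps the target network close to its previous parameters, while `update_rate=1.0` copies the online network outright.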
Dive deeper¶
So what exactly is happening under the hood? How do we pass observations and actions to the framework, then expect it to perform some magical operation and train our models behind the scenes? In this section, we are going to cover these questions in the following order, which is also the direction of our data flow:
Transition -> Buffer -> Algorithm
Transition¶
Now let’s take a step back and reexamine the process of an MDP (Markov Decision Process). An MDP can be described as a chain of transition steps.
In Machin, we store each transition step as a TransitionBase object. This class manages all data of a user-defined transition step by categorizing the data into three types: major attributes, sub attributes, and custom attributes.
- Major attributes: Dict[str, t.Tensor], used to describe complex state and action information.
- Sub attributes: Union[Scalar, t.Tensor], used to store less complex states such as reward, terminal status, etc.
- Custom attributes: Any, used to store custom data structures describing environment-specific states. They must not have tensors inside.
The default transition implementation is Transition, which has 5 attributes:
- state (major attribute)
- action (major attribute)
- next_state (major attribute)
- reward (sub attribute)
- terminal (sub attribute)
Note: The first dimension of tensors stored in major attributes and sub attributes must be the batch size (scalar sub attributes are safe). Currently, the constructor of the default transition implementation Transition requires the batch size to be 1, and all algorithms are only tested and validated with a batch size of 1. Scalar sub attributes, like reward and terminal, are treated as tensors of shape [1, 1].
Now we have a very general transition data structure, which supports storing:
- complex state information, such as visual (RGB-D), audio, physical (position, velocity, etc.), internal states of recurrent networks, etc.
- complex action information, whether discrete or continuous, a single space or a combination of multiple spaces, by storing them under different keys of the dictionary.
- complex reward, whether a scalar reward or a vectorized reward.
We may use this class to store the transition steps of a full MDP. Transition can be constructed like:
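import torch as t
from machin.frame.transition import Transition

# Every tensor has a leading batch dimension of size 1.
old_state = state = t.zeros([1, 5])
action = t.zeros([1, 2])
transition = {
    "state": {"state": old_state},
    "action": {"action": action},
    "next_state": {"state": state},
    "reward": 1.0,
    "terminal": False
}
transition = Transition(**transition)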
During Transition instance initialization, tensors stored in major and sub attributes are cloned and then detached, and custom attributes are deep copied.
Transition also provides the Transition.to() method to move its internal tensors to the target PyTorch device.
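To see why copying at construction time matters, here is a minimal sketch using only the standard library. The class `StoredStep` is a hypothetical stand-in for a transition object, not Machin's Transition; it only mimics the deep copy applied to custom attributes:

```python
import copy


class StoredStep:
    """Hypothetical stand-in for a transition object (not Machin's class)."""

    def __init__(self, **custom_attrs):
        # Deep copy so that later mutation by the caller cannot
        # change the data that has already been stored.
        self.custom = copy.deepcopy(custom_attrs)


info = {"visited": [1, 2]}
step = StoredStep(info=info)
info["visited"].append(3)  # the caller keeps mutating its own object
stored_visited = step.custom["info"]["visited"]
```

Here `stored_visited` still holds the value from the moment of construction, which is the behavior you want when the same objects are reused across environment steps.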
Buffer¶
The buffer (replay memory) is one of the core parts of the Machin library. Machin provides a sophisticated but clear implementation of replay memory to accommodate the needs of different frameworks. In The big picture section, we showed that the buffer instance encapsulated in the DQN framework has two major APIs: “append” and “sample”.
Append¶
Append is encapsulated by every framework in its “store_*” APIs. Some frameworks add new attributes to the constructed transition object in their “store_*” APIs, then call the “append” API of the buffer to add one or more transition objects to the buffer.
There are multiple buffer implementations. The basic Buffer class implements a simple ring buffer. PrioritizedBuffer extends the basic Buffer class with a prioritized weight tree. Distributed buffers are more interesting and complex because data is distributed across all process members.
In conclusion, the “append” API just stores one or more transition objects in the buffer. Many internal events happen behind the scenes, but you do not need to worry about them.
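As a rough idea of what a ring buffer does on append, consider the following minimal sketch. `RingBuffer` is a hypothetical class written for illustration and is not Machin's Buffer implementation:

```python
class RingBuffer:
    """Hypothetical fixed-size ring buffer; not Machin's Buffer class."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.storage = []
        self.index = 0  # next position to overwrite once the buffer is full

    def append(self, item):
        if len(self.storage) < self.capacity:
            self.storage.append(item)
        else:
            # Buffer is full: overwrite the oldest entry.
            self.storage[self.index] = item
        self.index = (self.index + 1) % self.capacity


buf = RingBuffer(capacity=3)
for step in range(5):
    buf.append(step)
```

After five appends into a buffer of capacity three, the two oldest entries have been overwritten in place, so memory use stays bounded no matter how long training runs.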
Sample¶
Sampling is the first step performed in almost every framework. It may look like:
batch_size, (state, action, reward, next_state, terminal, others) = \
    self.replay_buffer.sample_batch(self.batch_size,
                                    concatenate_samples,
                                    sample_method="random_unique",
                                    sample_attrs=[
                                        "state", "action",
                                        "reward", "next_state",
                                        "terminal", "*"
                                    ])
What does this segment of code do internally? Nothing other than “sampling” and “concatenation”. The argument sample_method indicates the sample selection method, and sample_attrs indicates which attributes of each sample we are going to acquire. “*” is a wildcard selector that picks up all unspecified attributes.
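The sampling half of this step can be sketched in a few lines. The function `sample_attrs_from` below is a hypothetical helper that treats each stored step as a plain dictionary; it imitates unique random sampling and the “*” wildcard, and is not the code Machin runs:

```python
import random


def sample_attrs_from(buffer, batch_size, sample_attrs, seed=None):
    """Hypothetical sketch of unique random sampling plus attribute selection."""
    rng = random.Random(seed)
    batch_size = min(batch_size, len(buffer))
    # random.sample draws without replacement, so each stored step
    # appears at most once in the batch.
    batch = rng.sample(buffer, batch_size)
    result = []
    for attr in sample_attrs:
        if attr == "*":
            # Wildcard: collect every attribute not explicitly requested.
            result.append([
                {k: v for k, v in step.items() if k not in sample_attrs}
                for step in batch
            ])
        else:
            result.append([step[attr] for step in batch])
    return batch_size, tuple(result)


buffer = [
    {"state": i, "reward": float(i), "info": "step-%d" % i}
    for i in range(10)
]
size, (states, rewards, others) = sample_attrs_from(
    buffer, 4, ["state", "reward", "*"], seed=0
)
```

The returned tuple follows the order of `sample_attrs`, which is why the call site can unpack it positionally.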
Then what does “concatenation” mean? Put simply, it only affects the “major attributes” and “sub attributes” of each sample. If you have specified additional_concat_attrs, then custom attributes can also be concatenated into a tensor. We may use a diagram to explain this process as it happens in the basic Buffer class:
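As a rough textual sketch of concatenation, the snippet below merges one major attribute of several samples along the first dimension. Nested lists of shape [1, N] stand in for tensors, and `concatenate_major` is a hypothetical helper, so this shows the idea rather than Machin's implementation:

```python
def concatenate_major(samples, attr):
    """Merge one major attribute of many samples along the first dimension.

    Hypothetical sketch: nested lists of shape [1, N] stand in for tensors,
    so list concatenation plays the role of concatenating on dimension 0.
    """
    keys = samples[0][attr].keys()
    return {
        key: [row for sample in samples for row in sample[attr][key]]
        for key in keys
    }


samples = [
    {"state": {"state": [[0.0] * 5]}},  # shape [1, 5]
    {"state": {"state": [[1.0] * 5]}},  # shape [1, 5]
    {"state": {"state": [[2.0] * 5]}},  # shape [1, 5]
]
# The result has shape [3, 5]: one row per sampled transition.
batched_state = concatenate_major(samples, "state")
```

Each key of the dictionary is concatenated independently, which is why every sample must store the same keys with matching trailing dimensions.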
Apart from the simplest Buffer, there are also PrioritizedBuffer (for prioritized experience replay), DistributedBuffer, used in IMPALA, and DistributedPrioritizedBuffer, used in DQNApex and DDPGApex.
We will not discuss the internal implementations of distributed buffers here.
Algorithm¶
Now that the algorithms have samples from the buffers, they can start training their models. Each of the three types of model-free RL algorithms supported by Machin has its own internal data path.
For more detailed descriptions of data paths and model requirements of all RL algorithms, please refer to Algorithm model requirements.
In order to bridge the gap between models and algorithms, Machin uses a function named safe_call() to pass data from algorithms to your models, and uses different class methods defined in algorithms, like DDPG.action_transform_function(), to pack up raw data from your models before using them in the algorithm framework. With this design, Machin is able to achieve API consistency between algorithms while maintaining code simplicity.
Again, let's take the classic DQN framework as an example. We will use mode="double" here, so that a double DQN framework is initialized. The models used in the DQN framework are Q networks. A Q network should accept a state and return a value for each possible discrete action. Ideally, we would like to define the model exactly according to this description, like the one below, which accepts a single state argument in its forward() function and returns a value tensor:
import torch as t
import torch.nn as nn


class QNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state):
        # state: tensor of shape [batch_size, state_dim]
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        # One Q value per action: shape [batch_size, action_num]
        return self.fc3(a)
Now, in the DQN.update() method, we have sampled a batch of state, action, next_state, etc. to train this Q network:
batch_size, (state, action, reward, next_state, terminal, others) = \
    self.replay_buffer.sample_batch(self.batch_size,
                                    concatenate_samples,
                                    sample_method="random_unique",
                                    sample_attrs=[
                                        "state", "action",
                                        "reward", "next_state",
                                        "terminal", "*"
                                    ])
where major attributes like state, action, and next_state are dictionaries of tensors, while sub attributes like reward and terminal are two tensors of shape [batch_size, 1]. We will ignore others for now, because unless you inherit from the DQN framework and write your own DQN.reward_func(), others does nothing.
In order to get the target Q value, which is used as a value estimate of the next state, we must use the Q network / the target Q network to criticize the sampled next_state:
q_value = self.criticize(state)
DQN.criticize() internally calls safe_call():
if use_target:
    return safe_call(self.qnet_target, state)
else:
    return safe_call(self.qnet, state)
safe_call() is a relatively complex function. In general, it does the following things:
- Check the input and output device of your model. If they are not defined, try to automatically determine them by checking the locations of all parameters.
- Check the argument names of the forward method of your model. This step will fail if forward is not defined or if your model is a JIT model compiled by PyTorch.
- Try to resolve the values of arguments by looking them up in the passed dictionaries. Additional keys in the dictionaries that do not belong to the arguments are ignored.
Therefore, the sampled state must have the required key “state”, and “state” is the first argument (excluding self) of QNet.forward.
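The name-based argument resolution described above can be sketched with the standard `inspect` module. The helper `kwargs_call` is hypothetical and deliberately simplified: it only covers the lookup step and skips the device handling that safe_call() performs:

```python
import inspect


def kwargs_call(func, *dicts):
    """Call func with arguments looked up by name in the given dicts.

    Hypothetical, simplified illustration of name-based argument
    resolution; it does not perform any device handling.
    """
    merged = {}
    for d in dicts:
        merged.update(d)
    arg_names = inspect.signature(func).parameters.keys()
    # Keys that are not parameters of func are silently dropped.
    kwargs = {name: merged[name] for name in arg_names if name in merged}
    return func(**kwargs)


def forward(state):
    # Stand-in for a network's forward function.
    return [2 * x for x in state]


sampled_state = {"state": [1, 2, 3], "unused_key": "ignored"}
q_values = kwargs_call(forward, sampled_state)
```

Because the lookup is driven by the parameter names of `forward`, renaming an argument of your model without renaming the matching dictionary key would leave that argument unfilled.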
After forwarding, the Q network passes the predicted Q values back to the DQN framework, and the data path is complete. The resulting Q values of the next step are passed to DQN.reward_func() to calculate target Q values, and these target values are then used to train the online Q network.
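For intuition, a target Q value is commonly computed as a one-step temporal-difference target. The sketch below assumes that standard formula and works on a single transition with plain floats; the function name `td_target` and its signature are illustrative and are not taken from DQN.reward_func():

```python
def td_target(reward, discount, next_value, terminal):
    """One-step TD target for a single transition (illustrative helper)."""
    # A terminal step has no future value, so only the reward remains.
    return reward + discount * (0.0 if terminal else next_value)


target_mid = td_target(reward=1.0, discount=0.5, next_value=4.0, terminal=False)
target_end = td_target(reward=1.0, discount=0.5, next_value=4.0, terminal=True)
```

In a batched setting the same arithmetic would be applied element-wise to tensors of shape [batch_size, 1].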
Summary¶
Generally speaking, just treat the whole process above as an “advanced kwargs call”. During sampling, you interact with your environment and store state tensors as values in a dictionary:
old_state = state = t.zeros([1, 5])
action = t.zeros([1, 2])
for _ in range(100):
    dqn.store_transition({
        "state": {"state": old_state},
        "action": {"action": action},
        "next_state": {"state": state},
        "reward": 1.0,
        "terminal": False
    })
Then during training, you will invoke the update method of your framework, and it will concatenate states, actions, and next states in the first dimension:
batch_size, (state, action, reward, next_state, terminal, others) = \
    self.replay_buffer.sample_batch(self.batch_size,
                                    concatenate_samples,
                                    sample_method="random_unique",
                                    sample_attrs=[
                                        "state", "action",
                                        "reward", "next_state",
                                        "terminal", "*"
                                    ])
# state = {"state": t.zeros([batch_size, 5])}
# action = {"action": t.zeros([batch_size, 2])}
# next_state = {"state": t.zeros([batch_size, 5])}
Then states, actions, and next states will be passed to your networks, safely, since tensors will be automatically moved to your model’s input device, and input device can be automatically determined or manually specified:
# DQN
q_value = self.criticize(state)
# DDPG
next_value = self.criticize(next_state, next_action, True)
# PPO
__, new_action_log_prob, new_action_entropy, *_ = \
    self.eval_act(state, action)
...
value = self.criticize(state)
The criticized values are then used to update your networks, and the data flow is complete.