Machin is a reinforcement learning library based purely on PyTorch.
It is designed with three things in mind:
Easy to understand.
Easy to extend.
Easy to reuse.
The first goal is achieved through a clear structural design, thorough documentation,
and concise descriptions of use cases. The second goal is achieved by
adding an extra layer on top of the basic APIs provided in the distributed module of
PyTorch; this layer offers additional fault tolerance mechanisms and
eliminates hassles that occur in distributed programming. The last goal is
the result of modular design, careful API arrangement, and experience
gathered from other similar projects.
Compared to other versatile and powerful reinforcement learning frameworks,
Machin tries to offer a pleasant programming experience, smoothing out
as many obstacles involved in reinforcement learning and distributed
programming as possible. Some essential functions, such as automated tuning and
neural architecture search, are not offered in this package; we strongly
recommend you take a look at these amazing projects and take a piggyback ride:
If you are using PIP to manage your python packages, you may directly type:
pip install machin
If you are using conda to manage your Python packages, you are advised to create a
virtual environment first, to prevent pip from changing your packages without letting
conda know:
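A minimal example (the environment name and Python version below are arbitrary choices):

conda create -n machin-env python=3.7
conda activate machin-env
pip install machin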
The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find an
official leaderboard with various algorithms and visualizations at the
Gym website.
As the agent observes the current state of the environment and chooses
an action, the environment transitions to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep and the environment
terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for longer durations, accumulating larger returns.
The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
However, neural networks can solve the task purely by looking at the
scene, so we’ll use a patch of the screen centered on the cart as an
input. Because of this, our results aren’t directly comparable to the
ones from the official leaderboard - our task is much harder.
Unfortunately this does slow down the training, because we have to
render all the frames.
Strictly speaking, we will present the state as the difference between
the current screen patch and the previous one. This will allow the agent
to take the velocity of the pole into account from one image.
Our environment is deterministic, so all equations presented here are
also formulated deterministically for the sake of simplicity. In the
reinforcement learning literature, they would also contain expectations
over stochastic transitions in the environment.
Our aim will be to train a policy that tries to maximize the discounted,
cumulative reward
\(R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t\), where
\(R_{t_0}\) is also known as the return. The discount,
\(\gamma\), should be a constant between \(0\) and \(1\)
that ensures the sum converges. It makes rewards from the uncertain far
future less important for our agent than the ones in the near future
that it can be fairly confident about.
The main idea behind Q-learning is that if we had a function
\(Q^*: State \times Action \rightarrow \mathbb{R}\), that could tell
us what our return would be, if we were to take an action in a given
state, then we could easily construct a policy that maximizes our
rewards:
\[\pi^*(s) = \arg\!\max_a \ Q^*(s, a)\]
However, we don’t know everything about the world, so we don’t have
access to \(Q^*\). But, since neural networks are universal function
approximators, we can simply create one and train it to resemble
\(Q^*\).
For our training update rule, we'll use the fact that every \(Q\)
function for some policy obeys the Bellman equation:
\[Q^{\pi}(s, a) = r + \gamma Q^{\pi}(s', \pi(s'))\]
The difference between the two sides of the equality is known as the
temporal difference error, \(\delta\):
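\[\delta = Q(s, a) - \left(r + \gamma \max_a Q(s', a)\right)\]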
The DQN framework is defined in machin.frame.algorithms.dqn; you may import it
with either of the following statements:
from machin.frame.algorithms import DQN
# Or with the following statement
from machin.frame.algorithms.dqn import DQN
The DQN framework is one of the three major types of model-free reinforcement learning
methods supported by Machin. To initialize it, you must at least provide a Q network, a
target Q network, an optimizer used to optimize the first Q network, and a
criterion used to measure the distance between the estimated Q value and the target
Q value we would like to reach.
The DQN framework supports multiple modes; the mode parameter can be one of
“vanilla”, “fixed_target” or “double”. For more detailed explanations of these
modes, please refer to DQN.
Depending on the Q framework mode, your network configuration might be a little
different, but generally speaking, your Q network should accept a state and then
output estimated Q values for each action. A simple example would be:
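(This sketch mirrors the QNet used in the full training program later in this section.)

import torch as t
import torch.nn as nn

class QNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, some_state):
        # output one estimated Q value per action
        a = t.relu(self.fc1(some_state))
        a = t.relu(self.fc2(a))
        return self.fc3(a)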
Please pay attention to the function signature of forward, because the names of
its arguments will be examined when the DQN framework performs a forward
operation on your Q network, during training or inference.
Now, please remember the name of the state argument: “some_state”.
In order to optimize your model, you must specify an optimizer and a criterion.
Usually the optimizer is torch.optim.Adam. We are going to use the good old
MSE loss nn.MSELoss here.
We have all the ingredients required to start the ignition sequence of the DQN
framework; let's mix these parts together:
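(A minimal sketch, with observe_dim and action_num defined as in the full program below; the same call appears there.)

q_net = QNet(observe_dim, action_num)
q_net_t = QNet(observe_dim, action_num)
dqn = DQN(q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction='sum'))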
The framework will print two warnings about the input/output device of the Q
networks not being set, but let's ignore that for now. You may quiet Machin down
either by:
# to mark the input/output device Manually
# will not work if you move your model to other devices
# after wrapping
q_net = static_module_wrapper(q_net, "cpu", "cpu")
q_net_t = static_module_wrapper(q_net_t, "cpu", "cpu")
Or by:
# to mark the input/output device Automatically
# will not work if your model is located on multiple devices
q_net = dynamic_module_wrapper(q_net)
q_net_t = dynamic_module_wrapper(q_net_t)
static_module_wrapper and dynamic_module_wrapper can be imported from
machin.model.nets
The DQN framework encapsulates a replay buffer inside. In order to interact with
the internal replay buffer, you may use either one of the following APIs, according to your
needs:
store_transition stores a single transition step of your MDP, while
store_episode stores all transitions of an episode.
When you are using other frameworks, both APIs may be supported, or only one of
them, depending on the internal implementation of the framework and
the requirements of the algorithm.
Now let's take DQN as an example: each Transition object describes a single step of
an MDP and consists of 5 attributes:
state: State observed by your agent when transition begins.
action: Action taken by your agent in this transition step.
next_state: Next state observed by your agent, when action is taken.
reward: Incremental reward given to your agent, due to the taken action.
terminal: Whether the next state is the terminal state of current MDP.
Suppose the observation of your agent has 5 continuous dimensions,
within range \((-\infty, +\infty)\), and the total number of available discrete actions is 3;
then an example transition step would be:
# some states observed by your agent
old_state = state = t.zeros([1, 5])
# suppose action taken by your agent is 2, available actions are 0, 1, 2
action = t.full([1, 1], 2, dtype=t.int)

dqn.store_transition({
    "state": {"some_state": old_state},
    "action": {"action": action},
    "next_state": {"some_state": state},
    "reward": 0.1,
    "terminal": False
})
Please take note that the sub-key of attributes “state” and “next_state”
must match the name of the state argument “some_state” in your Q network
mentioned above, and the sub-key of attribute “action” must be “action”.
We will come back to this seemingly strange naming requirement in the Buffer
section of Data flow in Machin. For now, please make sure that the shapes and
dictionary keys of your tensors are exactly the same as in the example.
With all the necessary parts, we can construct a full training program now:
from machin.frame.algorithms import DQN
from machin.utils.logging import default_logger as logger
import torch as t
import torch.nn as nn
import gym

# configurations
env = gym.make("CartPole-v0")
observe_dim = 4
action_num = 2
max_episodes = 1000
max_steps = 200
solved_reward = 190
solved_repeat = 5


# model definition
class QNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, some_state):
        a = t.relu(self.fc1(some_state))
        a = t.relu(self.fc2(a))
        return self.fc3(a)


if __name__ == "__main__":
    q_net = QNet(observe_dim, action_num)
    q_net_t = QNet(observe_dim, action_num)
    dqn = DQN(q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction='sum'))

    episode, step, reward_fulfilled = 0, 0, 0
    smoothed_total_reward = 0
    terminal = False

    while episode < max_episodes:
        episode += 1
        total_reward = 0
        terminal = False
        step = 0
        state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

        while not terminal and step <= max_steps:
            step += 1
            with t.no_grad():
                old_state = state
                # agent model inference
                action = dqn.act_discrete_with_noise({"some_state": old_state})
                state, reward, terminal, _ = env.step(action.item())
                state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
                total_reward += reward

                dqn.store_transition({
                    "state": {"some_state": old_state},
                    "action": {"action": action},
                    "next_state": {"some_state": state},
                    "reward": reward,
                    "terminal": terminal or step == max_steps
                })

        # update, update more if episode is longer, else less
        if episode > 100:
            for _ in range(step):
                dqn.update()

        # show reward
        smoothed_total_reward = (smoothed_total_reward * 0.9 +
                                 total_reward * 0.1)
        logger.info("Episode {} total reward={:.2f}"
                    .format(episode, smoothed_total_reward))

        if smoothed_total_reward > solved_reward:
            reward_fulfilled += 1
            if reward_fulfilled >= solved_repeat:
                logger.info("Environment solved!")
                exit(0)
        else:
            reward_fulfilled = 0
And your Q network should be successfully trained within about 300 episodes:
[2020-07-26 22:45:53,764] <INFO>:default_logger:Episode 226 total reward=188.18
[2020-07-26 22:45:54,405] <INFO>:default_logger:Episode 227 total reward=189.36
[2020-07-26 22:45:55,091] <INFO>:default_logger:Episode 228 total reward=190.42
[2020-07-26 22:45:55,729] <INFO>:default_logger:Episode 229 total reward=191.38
[2020-07-26 22:45:56,372] <INFO>:default_logger:Episode 230 total reward=192.24
[2020-07-26 22:45:57,012] <INFO>:default_logger:Episode 231 total reward=193.02
[2020-07-26 22:45:57,658] <INFO>:default_logger:Episode 232 total reward=193.72
[2020-07-26 22:45:57,658] <INFO>:default_logger:Environment solved!
It's a pain to lay down every detail by hand, so why don't we do it automatically?
We have PyTorch Lightning, a powerful
machine learning programming utility that enables users to write template-like code
and leave details such as check-pointing, logging and metric evaluation to the Lightning
engine hidden behind it. The Machin library also supports coding with the Lightning library
to simplify your RL workflow.
Then you need to modify this config file to suit your needs. First, we need to modify
the framework config stored under the sub-key “frame_config”:
You need to define your QNet model in some file; suppose you defined it in qnet.py.
Then, in the same directory, set key “models” to [“qnet.QNet”, “qnet.QNet”]. Currently
this is the only way when you use Lightning with the automatic launcher; there are other
ways, which are described below.
If your model has any initialization args and kwargs, you will also need to
specify them for each of your models.
The other keys correspond to the initialization arguments of DQN; please refer
to its docstring for more information.
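As an illustration only, the “frame_config” section could look roughly like the sketch below. “frame_config” and “models” are described above; the remaining key names and values are assumptions and should be checked against the generated config file and the DQN docstring:

config["frame_config"] = {
    # two Q networks (online and target), defined in qnet.py
    "models": ["qnet.QNet", "qnet.QNet"],
    # per-model initialization args/kwargs (key names assumed)
    "model_args": [[], []],
    "model_kwargs": [{"state_dim": 4, "action_num": 2},
                     {"state_dim": 4, "action_num": 2}],
    # remaining keys map to DQN initialization arguments, e.g.:
    "batch_size": 100,
    "learning_rate": 0.001,
}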
After modifying the framework config, you may also need to modify the environment config. The
environment config is provided to PyTorch datasets which wrap the target environments defined in machin.auto.envs.
For example, the openai_gym module defines two dataset classes for discrete and continuous actions:
RLGymDiscActDataset and RLGymContActDataset.
For the OpenAI Gym environment, “act_kwargs” is the keyword arguments passed to act or act_*_with_noise,
depending on your framework; please refer to RLGymDiscActDataset and RLGymContActDataset
for their specific usage.
Finally, you may also want to modify other configuration keys passed to the Lightning framework:
“early_stopping_patience”: the maximum number of epochs where total reward does not increase before
terminating training; this value is passed to the EarlyStopping hook in Lightning.
“episode_per_epoch”: Number of episodes to run per epoch.
“max_episodes”: Number of maximum training episodes.
“root_dir”: Root directory to use for check-pointing, logging, etc. in this training. Must be unique
for each experiment, otherwise your results will be overwritten.
“gpus”: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node,
passed to pytorch_lightning.Trainer.
For distributed frameworks such as IMPALA, there are two additional configuration keys:
“num_nodes”: number of nodes (int) for distributed training,
passed to pytorch_lightning.Trainer.
“num_processes”: number of processes to run (int) on each node for distributed training.
This is only necessary when you set key “gpus” to null, indicating that you will train
on CPU only; otherwise the number of created processes equals the number of specified GPUs.
The main limitation is that you cannot use any custom environment except the ones already provided. It is also
inflexible for hyperparameter searching when you want to fine-tune your model. Therefore, we will first introduce
using the auto module programmatically instead of declaratively, then instruct you on creating your own
environment extension.
Suppose that you want to sweep the hyperparameter space using some tuning library like
Microsoft NNI; then your program can be divided into
the following pseudo program:
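A rough sketch only: DQN.generate_config and DQN.init_from_config are the class methods described in the next paragraph, nni.get_next_parameter and nni.report_final_result are standard NNI calls, and the remaining glue code (training loop, final metric) is an assumption:

import nni
from machin.frame.algorithms import DQN

# get one trial's hyperparameter set from NNI
params = nni.get_next_parameter()

# generate a default config, then fill in the searched hyperparameters
config = DQN.generate_config({})
config["frame_config"].update(params)

# initialize the framework from the config
dqn = DQN.init_from_config(config)

# ... run your own sampling / training loop with dqn here ...
# finally report the metric of this trial back to NNI
nni.report_final_result(final_smoothed_reward)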
This method generates some PyTorch Lightning specific configurations which
are not provided by the generate_config method of the algorithm classes;
the generate_config class method is designed only to produce the configuration needed
to initialize the algorithm framework with the init_from_config class method of that algorithm class.
All framework related configurations are under the sub-key “frame_config”.
Note
You can pass the model class defined by you to the framework in the following
ways:
An nn.Module subclass
A string name of a model class defined globally in any frame
of your call stack (not available if the framework is distributed).
A string name of an importable model class, e.g. foo.baz.model
example:
class QNet(nn.Module):
    ...

# specify directly
config["frame_config"]["models"] = [QNet, QNet]

# specify as a string, since it is defined globally in the current stack
config["frame_config"]["models"] = ["QNet", "QNet"]

# specify as an importable name, current module is "__main__"
config["frame_config"]["models"] = ["__main__.QNet", "__main__.QNet"]
Note
For optimizer, lr_scheduler, criterion, etc., you can specify them in the same way you
specify your models; there is also an additional way to define them: a valid string name of
the respective class in the PyTorch library. Please refer to
assert_and_get_valid_optimizer(), assert_and_get_valid_lr_scheduler()
and assert_and_get_valid_criterion().
You can fill the hyperparameters provided by NNI into the “frame_config” section.
All environment adaptors are located in machin.auto.envs, to create an environment extension,
you need to:
Create a Python file named after your environment, such as “some_env.py”.
Update __init__.py in machin.auto.envs to import your environment
module as a whole; this is used to look up available environments.
For your environment module, you need to define 4 things (a bare-bones skeleton is sketched after this list):
A dataset class which inherits and implements methods defined in
RLDataset, when __next__ method is called, it must return
a sampled episode of type DatasetResult.
A dataset creator function which takes in a framework instance (such as
an instance of DQN) and passes it to the dataset, so that the framework can be
used internally to interact with your environment. It must return a dataset
class instance.
A function named generate_env_config which takes in a previous config
and adds three keys: “env”, “train_env_config”, and “test_env_config”.
“env” is your environment name, and the two configs are used to initialize
the train and test environments.
A launch function which takes in a config object and a list of PyTorch Lightning
callbacks; it is used to launch the experiment with the PyTorch Lightning Trainer.
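A bare-bones skeleton under stated assumptions: RLDataset, DatasetResult, generate_env_config and the launch function come from the list above, while the import path, function bodies, and exact signatures are assumptions and should be checked against an existing environment module such as openai_gym:

# some_env.py
from machin.auto.dataset import RLDataset, DatasetResult   # assumed import path

class SomeEnvDataset(RLDataset):
    def __init__(self, frame, env):
        super().__init__()
        self.frame = frame   # the framework instance, e.g. a DQN instance
        self.env = env

    def __next__(self):
        result = DatasetResult()
        # run one episode with self.frame acting in self.env,
        # fill `result` with observations and logs, then:
        return result

def some_env_dataset_creator(frame, env_config):
    # build the environment from the config and wrap it in the dataset
    return SomeEnvDataset(frame, make_env(env_config))

def generate_env_config(config=None):
    config = config or {}
    config["env"] = "some_env"
    config["train_env_config"] = {}
    config["test_env_config"] = {}
    return config

def launch(config, pl_callbacks=None):
    # build the framework, the datasets and a pytorch_lightning.Trainer,
    # then start training
    ...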
To give you a general idea of the data flow model in single agent RL algorithms,
we will take the DQN framework as an example and use a diagram to illustrate everything:
So what is happening under the hood exactly? How do we pass the observations
and actions to the framework, then expect it to perform some magical operation
and train our models behind the scenes? In this section, we are going to cover
all of these questions in the following order, which is also the direction
of our data flow:
Now let's take a step back and reexamine the process of an MDP (Markov Decision Process).
An MDP can be described as a chain of transition steps.
In Machin, we store each transition step as a TransitionBase object; this
class manages all data of a user defined transition step by categorizing data into
three types: major attributes, sub attributes and custom attributes.
Major attribute: Dict[str,t.Tensor], used to describe complex state and action information.
Sub attributes: Union[Scalar,t.Tensor], used to store less complex states such as reward, terminal status, etc.
Custom attributes: Any, used to store custom data structures describing environmental specific states, must not have tensors inside.
The default transition implementation is Transition, which has 5 attributes:
state (major attribute)
action (major attribute)
next_state (major attribute)
reward (sub attribute)
terminal (sub attribute)
Note: The first dimension of tensors stored in major attributes and sub attributes
must be the batch size (scalar sub attributes are safe). Currently, the constructor of the default
transition implementation Transition requires the batch size to be 1, and all algorithms
are only tested and validated with batch size equal to 1. Scalar sub attributes, like
reward and terminal, will be considered as tensors with shape [1, 1].
Now we have a very general transition data structure, which supports storing:
complex state information, such as visual (RGB-D), audio, and physical (position, velocity, etc.) observations,
internal states of recurrent networks, etc.
complex action information, whether discrete or continuous, a single space or a combination
of multiple spaces, by storing them in different keys of the dictionary.
complex rewards, whether scalar or vectorized.
We may use this class to store the transition steps of a full MDP. A Transition can
be constructed like:
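A sketch mirroring the store_transition call shown earlier; the keyword-argument constructor form and the import path are assumptions:

from machin.frame.transition import Transition   # assumed import path

old_state = state = t.zeros([1, 5])
action = t.full([1, 1], 2, dtype=t.int)

transition = Transition(
    state={"some_state": old_state},
    action={"action": action},
    next_state={"some_state": state},
    reward=0.1,
    terminal=False,
)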
During Transition instance initialization, tensors stored in major and sub attributes
will be cloned and then detached, while custom attributes will be deep copied.
Transition also supports Transition.to() method to move
internal tensors to the target pytorch device.
Buffers (replay memory) are one of the core parts of the Machin library. Machin provides
a sophisticated but clear implementation of replay memory to accommodate the needs
of different frameworks. In The big picture section, we have shown that the
buffer instance encapsulated in the DQN framework has two major APIs: “append” and “sample”.
“append” is wrapped by every framework in its “store_*” APIs; some frameworks
might add new attributes to the constructed transition object in their “store_*” APIs,
then call the “append” API of the buffer to add one or more transition objects to the buffer.
There are multiple buffer implementations: the basic Buffer class implements a simple
ring buffer, and PrioritizedBuffer extends the basic Buffer class with
a prioritized weight tree. Distributed buffers are more interesting and complex because data
are distributed on all process members.
In conclusion, the “append” API just stores one or more transition objects into the buffer;
many internal events happen behind the scenes, but you need not worry about them.
So what does the framework's internal sampling code do exactly? Well, nothing
other than “sampling” and “concatenation”. The argument sample_method indicates
the sample selection method, sample_attrs indicates which attributes of each
sample we are going to acquire, and “*” is a wildcard selector picking
up all unspecified attributes.
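For reference, a hedged sketch of what such an internal sampling call looks like; the argument values are illustrative, not the exact ones each framework uses:

batch_size, (state, action, reward, next_state, terminal, others) = \
    buffer.sample_batch(
        100,                               # number of samples to draw
        sample_method="random_unique",     # how samples are selected
        sample_attrs=["state", "action", "reward",
                      "next_state", "terminal", "*"],  # "*" picks up the rest
    )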
Then what does “concatenation” mean? To put it simply, it only affects the “major attributes”
and “sub attributes” of each sample; if you have specified additional_concat_attrs, then
custom attributes can also be concatenated into a tensor. We may use a graph to explain this
process as it happens in the basic Buffer class:
Now that algorithms have got samples from the buffers, they can
start training their models. The three types of model-free RL algorithms
supported by Machin have three respective internal data paths.
For more detailed descriptions of data paths and model requirements of all RL algorithms,
please refer to Algorithm model requirements.
In order to bridge the gap between models and algorithms, Machin uses a function named safe_call()
to pass data from algorithms to your models, and uses different class methods defined in
algorithms like DDPG.action_transform_function() to pack up raw data from your models
before using them in the algorithm framework. With this design, Machin is able to
achieve API consistency between algorithms while maintaining code simplicity.
Again, let's take the classic DQN framework as an example. We will use mode="double"
here, so that a double DQN framework will be initialized. The models used
in the DQN framework are Q networks; a Q network should accept a state and
return a value for each possible discrete action. Ideally we would like to define the model
according to this description exactly, like the one below, which accepts a single state
argument in its forward() function and returns a value tensor:
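For reference, a Q network of that shape (note the argument name “state”, which matters below; this mirrors the QNet used in the DQN tutorial):

class QNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        return self.fc3(a)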
In the sampled batch, major attributes like state, action and next_state are dictionaries of tensors,
while sub attributes like reward and terminal are tensors of shape [batch_size, 1].
We will ignore the other sampled attributes for now, because unless you inherit from the DQN framework and
write your own DQN.reward_func(), they do nothing.
In order to get the target Q value, which is used as a value estimate of the next state, we
must use the Q network / the target Q network to criticize the sampled next_state:
safe_call() is a relatively complex function; in general it does the following things:
Check the input & output device of your model; if they are not defined, try to
determine them automatically by checking the locations of all parameters.
Check the argument names of the forward method of your model; this step will fail
if forward is not defined or if your model is a JIT model compiled by PyTorch.
Try to resolve the values of the arguments by looking them up in the passed dictionaries;
additional keys in the dictionaries that do not belong to the arguments will be ignored.
Therefore, the sampled state must have the required key “state”, and “state” is the
first argument (excluding self) of QNet.forward.
After forwarding, the Q network passes the predicted Q values back to the DQN framework,
and the data path is complete: the resulting Q values of the next step are passed to DQN.reward_func()
to calculate target Q values, which are then used to train the online Q network.
Generally speaking, just treat the whole process above as an “advanced kwargs call”.
During sampling, you will interact with your environment and store some state tensors as values
in a dictionary:
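For example, following the same structure as the store_transition calls earlier (the state dimension and values are illustrative):

old_state = state = t.zeros([1, 4])          # shape [1, dim]: batch size 1
action = t.full([1, 1], 1, dtype=t.int)
dqn.store_transition({
    "state": {"state": old_state},
    "action": {"action": action},
    "next_state": {"state": state},
    "reward": 0.1,
    "terminal": False,
})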
Then during training, you will invoke the update method of your framework, and it will
concatenate states, actions, and next states in the first dimension:
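A sketch of this concatenation (batch_size and observe_dim are illustrative):

observe_dim, batch_size = 4, 100
stored_states = [t.rand(1, observe_dim) for _ in range(batch_size)]   # one tensor per stored step
batched_state = {"state": t.cat(stored_states, dim=0)}                # shape [batch_size, observe_dim]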
Then states, actions, and next states will be passed to your networks safely: tensors
will be automatically moved to your model's input device, and the input device can
be automatically determined or manually specified:
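A sketch of roughly what safe_call does for the Q network above (the device here is an illustrative value):

input_device = "cpu"                            # determined automatically or specified manually
batched_state = {"state": t.zeros([100, 4])}    # the concatenated sample from the buffer

# tensors are moved to the model's input device, then matched to the
# argument names of forward() and passed as keyword arguments
batched_state = {k: v.to(input_device) for k, v in batched_state.items()}
q_values = q_net(state=batched_state["state"])  # roughly what safe_call(q_net, batched_state) does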
From the perspective of traditional parallel computation, there are many levels
of parallelism supported by Machin, based on PyTorch, from fine to coarse:
Element level parallelism, based on multidimensional tensor computations.
Task level parallelism, achieved by multi-threading, either provided by
python threads, or the JIT fork mechanism of PyTorch (with no GIL).
Task level parallelism, achieved by multi-processing, either on the same
node, or on different nodes.
For element level parallelism, we can either use existing tensor operators,
or use more flexible operators such as torch.einsum to make customized operators,
or write our own CUDA kernels. We can even use torch.jit to compile our
models and get some performance improvements over plain python APIs. Machin doesn’t
provide any utility in this area.
For task level parallelism, the basic Python libraries threading and
multiprocessing already provide enough functionality to achieve the
latter two kinds of parallelism. Machin provides the following enhancements:
Watch for exceptions happening in threads/processes.
Process/Thread pools with local function execution ability, accurate control over tensor serialization policy.
Process/Thread pools with contexts, allow users to pass hard-to-construct objects before executing tasks.
Inter-process queues with accurate control over tensor serialization policy.
Neural network perspective
From the perspective of neural networks, there are some parallelism
paradigms we would like to achieve, with traditional parallel architectures:
Model level parallelism in small batch inference of many small models.
Model level parallelism in large batch inference of one potentially huge model.
Model level “parallelism” in storing an extremely large model across multiple devices or nodes.
Currently, there is no perfect way to deal with the first scenario, because threads
in Python are constrained by the GIL, while processes are too slow. In MADDPG,
Machin chooses to utilize the JIT functionality provided by PyTorch and uses compiled JIT
models to work around the GIL restriction; this method has shown roughly a
50% speed advantage over regular thread pools.
The second scenario can be dealt with using DistributedDataParallel in PyTorch, by
splitting the large batch into several smaller batches, then performing inference on
different processes asynchronously.
The last scenario is also known as “model sharding”, which means splitting a huge model
up into several smaller models. It would be more favorable to users if this could be
done automatically by the framework. However, due to the design of PyTorch, where
tensors, not models, are the real entities bound to devices, it is not possible to achieve
this directly with PyTorch, as of version 1.5.0. Machin currently does not
provide automatic model sharding either, but our internal implementation does support
implementing such a feature, and it may be added in the future. Currently,
Machin only provides automatic assignment of (split) models, with ModelAssigner.
Reinforcement learning perspective
When it comes to RL algorithms, these parallelisms are usually required:
Environment parallelism, where multiple copies of the same environment are executed synchronously in parallel, to produce larger batches of observations.
Agent parallelism, where multiple agents are learning synchronously or asynchronously, like A3C, DQNApex.
Agent parallelism in multi-agent algorithms, where multiple agents of different types are learning synchronously or asynchronously, like MADDPG
Machin provides parallel environment wrappers for the first scenario, like openai_gym.ParallelWrapperSubProc, which starts
multiple worker processes, creates an environment instance in each worker, then sends commands and receives responses in batches.
The second scenario is trickier, since agents are usually distributed across
“parataxis” (same-level) processes on multiple nodes, rather than “hypotaxis”
sub-processes started in a process pool on the same node. We will discuss this
part in the Distributed section.
The third scenario depends on the RL algorithm framework. In MADDPG, each agent corresponds
to a pair of separate actor and critic; in this case, only thread-based task level parallelism can
be used, because differences in parameters and model architectures make it hard to form batches.
But if we are using a single agent RL algorithm such as DDPG to train a group of
homogeneous agents, then batching is preferred due to its efficiency.
Distributed computing is awesome, as well as extremely painful to deal with, hard to design,
and even harder to debug, because applications are often required to have
crucial features like consistency, availability, partition tolerance, and good performance.
Currently, since Machin relies on the PyTorch RPC framework, it does not provide
any distributed mechanism able to guarantee
consistency, availability or partition tolerance, due to limitations in
the PyTorch RPC framework as of version 1.5.0.
What Machin provides is a more advanced set of RPC APIs: an implementation of RPC groups (namespaces), on which you can
register a service with register or share a resource with pair, like the code below:
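A heavily hedged sketch of what such code might look like. The register and pair method names come from the description above; the group creation call and the lookup method names (registered_sync, get_paired here) are assumptions and should be checked against the RpcGroup documentation:

# on one member process: provide a service and share a resource
group = world.create_rpc_group("workers", ["0", "1", "2"])   # group name/members are illustrative

def count(x):
    return x + 1

group.register("count_service", count)     # register a service under a name
group.pair("shared_counter", counter)      # share a resource under a name

# on any member process: look them up by name, without knowing which
# process actually serves them (method names are assumptions)
result = group.registered_sync("count_service", args=(1,))
shared = group.get_paired("shared_counter")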
This “DNS”-like mechanism enables Machin to abstract away the “name” of processes and of the specific server process;
instead, every process that wants to access the service/resource is faced with a registration
table. This table can be different, depending on the actual process running the service,
and the internal implementation of the service. With this design, Machin is able to provide
some general distributed implementations such as DistributedBuffer, DistributedPrioritizedBuffer,
PushPullGradServer, etc.
Apart from this, Machin just provides a thin layer of encapsulation over the somewhat complex
APIs of torch.distributed and torch.distributed.rpc, to make them less confusing.
In order to fully understand all the functions provided by machin.parallel, we should
read some detailed use cases. This part requires familiarity with, but not a deep understanding of:
threading library of python
multiprocessing library of python
torch.distributed module
torch.distributed.rpc module
If the examples below are not enough for you, please refer to the tests.
A3C is basically a bunch of A2C agents with a gradient reduction server. A3C (A2C)
agents interact with their environment simulators, train their local actors
and critics, then push gradients to the gradient reduction server; the gradient
reduction server applies the reduced gradients to its internal models (the managed actor
and critic networks), then pushes the updated parameters to a key-value server. Agents
can then pull the newest parameters and continue updating.
All A3C agents are fully asynchronous, gradient pushing & parameter pulling are asynchronous
as well.
We will use the “CartPole-v0” environment from OpenAI Gym as an example; the actor network
and critic network are as follows:
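A sketch of what these could look like, based on the way they are constructed below (Actor(observe_dim, action_num), Critic(observe_dim)). The actor's return convention (action, log probability, entropy) is an assumption based on how a3c.act(...)[0] is used below and how impala.act() is unpacked later in this document:

import torch as t
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self, state_dim, action_num):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state, action=None):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        probs = t.softmax(self.fc3(a), dim=1)
        dist = t.distributions.Categorical(probs=probs)
        act = action if action is not None else dist.sample()
        act_entropy = dist.entropy()
        act_log_prob = dist.log_prob(act.flatten())
        return act, act_log_prob, act_entropy

class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, state):
        v = t.relu(self.fc1(state))
        v = t.relu(self.fc2(v))
        return self.fc3(v)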
In order to initialize the A3C framework, we need to first initialize the
distributed world:
# initialize the distributed world first
world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20)
Then provide a PushPullGradServer to it. Machin provides some helpful utility functions
to help inexperienced users initialize the distributed environment easily:
# initialize the distributed world first
world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20)

actor = Actor(observe_dim, action_num)
critic = Critic(observe_dim)

# in all test scenarios, all processes will be used as reducers
servers = grad_server_helper(
    [lambda: Actor(observe_dim, action_num),
     lambda: Critic(observe_dim)],
    learning_rate=5e-3
)
a3c = A3C(actor, critic,
          nn.MSELoss(reduction='sum'),
          servers)
And start training, just as the A2C algorithm:
# manually control syncing to improve performance
a3c.set_sync(False)

# begin training
episode, step, reward_fulfilled = 0, 0, 0
smoothed_total_reward = 0
terminal = False

while episode < max_episodes:
    episode += 1
    total_reward = 0
    terminal = False
    step = 0
    state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

    # manually pull the newest parameters
    a3c.manual_sync()
    tmp_observations = []
    while not terminal and step <= max_steps:
        step += 1
        with t.no_grad():
            old_state = state
            # agent model inference
            action = a3c.act({"state": old_state})[0]
            state, reward, terminal, _ = env.step(action.item())
            state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
            total_reward += reward

            tmp_observations.append({
                "state": {"state": old_state},
                "action": {"action": action},
                "next_state": {"state": state},
                "reward": reward,
                "terminal": terminal or step == max_steps
            })

    # update
    a3c.store_episode(tmp_observations)
    a3c.update()

    # show reward
    smoothed_total_reward = (smoothed_total_reward * 0.9 +
                             total_reward * 0.1)
    logger.info("Process {} Episode {} total reward={:.2f}"
                .format(rank, episode, smoothed_total_reward))

    if smoothed_total_reward > solved_reward:
        reward_fulfilled += 1
        if reward_fulfilled >= solved_repeat:
            logger.info("Environment solved!")
            # will cause torch RPC to complain
            # since other processes may have not finished yet.
            # just for demonstration.
            exit(0)
    else:
        reward_fulfilled = 0
A3C agents should be successfully trained within about 1500 episodes;
they converge much more slowly than A2C agents:
[2020-07-31 00:21:37,690] <INFO>:default_logger:Process 1 Episode 1346 total reward=184.91
[2020-07-31 00:21:37,723] <INFO>:default_logger:Process 0 Episode 1366 total reward=171.22
[2020-07-31 00:21:37,813] <INFO>:default_logger:Process 2 Episode 1345 total reward=190.73
[2020-07-31 00:21:37,903] <INFO>:default_logger:Process 1 Episode 1347 total reward=186.41
[2020-07-31 00:21:37,928] <INFO>:default_logger:Process 0 Episode 1367 total reward=174.10
[2020-07-31 00:21:38,000] <INFO>:default_logger:Process 2 Episode 1346 total reward=191.66
[2020-07-31 00:21:38,000] <INFO>:default_logger:Environment solved!
DQNApex and DDPGApex are actually based on the same architecture, therefore
in this section, we are going to take DQNApex as an example, its distributed architecture
could be described in the following graph:
The Apex architecture decouples the sampling and updating process with the
prioritized replay buffer. There could be several implementations, such as:
using a central replay buffer on a single process
using a distributed buffer, with a central stopping signal.
using a distributed buffer, each buffer with a separate lock.
Machin chooses the third implementation because it is the most efficient:
#1 is slow because each append requires an RPC call to update the global weight tree,
and it also doesn't scale when the number of workers (samplers) grows large, such as
100+ workers.
The central lock used in #2 is meant to protect the importance sampling weight update process:
each buffer maintains a local weight tree; during sampling, the learner signals “STOP” to all workers,
and signals “START” to all workers when the importance weight update is completed. However, this design does
not truly decouple learning and sampling, so most of the time workers are just hanging around waiting for
the learner to complete the update.
The #3 design is the best, because each append operation is completely local (it only needs to acquire
a local lock), and global sampling is completely decoupled from local appending (locks are
released immediately after returning the sampled data, not held until the update completes), as shown in
figure Fig. 7.
However, it can be quite tricky to implement this process, because appending still happens
after sampling and before the learner finishes updating the importance sampling weights (is-weights).
Therefore Machin uses a taint table, which is essentially a table of auto-increment counters;
each counter maps to an entry slot in the underlying ring buffer and is incremented when that entry is
replaced with a new one. This replacement should not happen very often if the buffer has enough
space (100000+ entries), which guarantees the correctness of the importance weight update.
There is one thing to note: it can take an indefinitely long time for the learner to build
the virtual global weight sum tree, which uses the root nodes of all local weight sum trees as its leaves.
Therefore, at the time of sampling, the collected local weight sums are already outdated, and the sampling
probability of each tree will have changed. However, if the size of each local buffer is large enough,
then the relative difference between the old collected local weight sums and the current weight sums should be
acceptable.
Now that we know how the Apex framework is designed, we may try an example. We will use the “CartPole-v0”
environment from OpenAI Gym as an example; the Q network is as follows:
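(For this example we can reuse the same simple network as in the basic DQN tutorial above; note that the state argument is named “state” here, matching the dictionaries stored in the training loop below.)

class QNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        return self.fc3(a)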
Because the Apex frameworks rely on the DistributedPrioritizedBuffer, the learner needs to
know the position and service name of each local buffer, as shown in Fig. 7.
In order to initialize the Apex framework, we need to provide an RPC process group
on which all learner(s) and workers will live:
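A hedged sketch; the World constructor mirrors the A3C example above, while the create_rpc_group call, the group name and the member list are assumptions:

# initialize the distributed world (same style as the A3C example)
world = World(world_size=4, rank=rank, name=str(rank), rpc_timeout=20)

# an RPC group covering all workers and learners
apex_group = world.create_rpc_group("apex", ["0", "1", "2", "3"])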
We will also provide a model server, on which learner(s) store the newest parameters and
from which workers pull the newest parameters. This kind of parameter server is different
from the PushPullGradServer used above, so we will name it PushPullModelServer.
Currently, each PushPullModelServer only manages one model per server instance,
and since there is only one model that needs to be shared in DQN (the online Q network), we only need one
model server instance:
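A hedged sketch; model_server_helper is assumed to live next to the grad_server_helper used in the A3C example, and its exact signature should be checked against the Machin documentation:

# one PushPullModelServer accessor for the single shared model (the online Q network)
servers = model_server_helper(model_num=1)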
The tasks of learner(s) and workers are quite different: learner(s) only need to update
their internal models repeatedly, using samples from the workers' buffers, while workers only need to
repeatedly pull parameters and sample (sync-sample-sync-sample…), so they run different branches in the main program.
Maybe you want to ask: why do we say learner(s), isn't the original paper
stating that there is only one learner and multiple workers? The answer is: Machin supports using DistributedDataParallel
(DataParallel is also supported)
from PyTorch inside DQNApex, so that you may distribute the updating task across multiple learner processes if your model
is way too large to be computed by a single process. It is not sensible to use this technique with small models,
but for pure demonstration purposes, we will use it here:
if rank in (2, 3):
    # learner_group.group is the wrapped torch.distributed.ProcessGroup
    learner_group = world.create_collective_group(ranks=[2, 3])

    # wrap the model with DistributedDataParallel
    # if current process is learner process 2 or 3
    q_net = DistributedDataParallel(module=QNet(observe_dim, action_num),
                                    process_group=learner_group.group)
    q_net_t = DistributedDataParallel(module=QNet(observe_dim, action_num),
                                      process_group=learner_group.group)
else:
    q_net = QNet(observe_dim, action_num)
    q_net_t = QNet(observe_dim, action_num)

# we may use a smaller batch size to train if we are using
# DistributedDataParallel
dqn_apex = DQNApex(q_net, q_net_t,
                   t.optim.Adam,
                   nn.MSELoss(reduction='sum'),
                   apex_group,
                   (servers[0],),
                   batch_size=50)
The main part of the training process is as follows:
# synchronize all processes in the group, make sure
# distributed buffer has been created on all processes
# in apex_group
apex_group.barrier()

# manually control syncing to improve performance
dqn_apex.set_sync(False)
if rank in (0, 1):
    # Process 0 and 1 are workers(samplers)
    # begin training
    episode, step, reward_fulfilled = 0, 0, 0
    smoothed_total_reward = 0

    while episode < max_episodes:
        # sleep to wait for learners keep up
        sleep(0.1)
        episode += 1
        total_reward = 0
        terminal = False
        step = 0

        state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)
        # manually pull the newest parameters
        dqn_apex.manual_sync()
        while not terminal and step <= max_steps:
            step += 1
            with t.no_grad():
                old_state = state
                # agent model inference
                action = dqn_apex.act_discrete_with_noise({"state": old_state})
                state, reward, terminal, _ = env.step(action.item())
                state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
                total_reward += reward

                dqn_apex.store_transition({
                    "state": {"state": old_state},
                    "action": {"action": action},
                    "next_state": {"state": state},
                    "reward": reward,
                    "terminal": terminal or step == max_steps
                })

        smoothed_total_reward = (smoothed_total_reward * 0.9 +
                                 total_reward * 0.1)
        logger.info("Process {} Episode {} total reward={:.2f}"
                    .format(rank, episode, smoothed_total_reward))

        if smoothed_total_reward > solved_reward:
            reward_fulfilled += 1
            if reward_fulfilled >= solved_repeat:
                logger.info("Environment solved!")
                # will cause torch RPC to complain
                # since other processes may have not finished yet.
                # just for demonstration.
                exit(0)
        else:
            reward_fulfilled = 0

elif rank in (2, 3):
    # wait for enough samples
    while dqn_apex.replay_buffer.all_size() < 500:
        sleep(0.1)
    while True:
        dqn_apex.update()
Result:
[2020-08-01 12:51:04,323] <INFO>:default_logger:Process 1 Episode 756 total reward=192.42
[2020-08-01 12:51:04,335] <INFO>:default_logger:Process 0 Episode 738 total reward=187.58
[2020-08-01 12:51:04,557] <INFO>:default_logger:Process 1 Episode 757 total reward=193.17
[2020-08-01 12:51:04,603] <INFO>:default_logger:Process 0 Episode 739 total reward=188.72
[2020-08-01 12:51:04,789] <INFO>:default_logger:Process 1 Episode 758 total reward=193.86
[2020-08-01 12:51:04,789] <INFO>:default_logger:Environment solved!
The IMPALA algorithm has the same parallel architecture as DQNApex and DDPGApex;
the only difference is that the internal distributed buffer it uses is a simple distributed buffer, with no
distributed prioritized tree:
In order to initialize the IMPALA framework, we need to pass two accessors to two individual
PushPullModelServers to the framework. Take note that IMPALA does not support
storing a single step, since v-trace calculation requires sampling complete episodes; all transition
objects in the IMPALA buffer are episodes rather than steps. Therefore, the batch_size used here
is set to 2, which is much smaller than the 50 used in DQNApex:
if rank in (2, 3):
    # learner_group.group is the wrapped torch.distributed.ProcessGroup
    learner_group = world.create_collective_group(ranks=[2, 3])

    # wrap the model with DistributedDataParallel
    # if current process is learner process 2 or 3
    actor = DistributedDataParallel(module=Actor(observe_dim, action_num),
                                    process_group=learner_group.group)
    critic = DistributedDataParallel(module=Critic(observe_dim),
                                     process_group=learner_group.group)
else:
    actor = Actor(observe_dim, action_num)
    critic = Critic(observe_dim)

# we may use a smaller batch size to train if we are using
# DistributedDataParallel
# note: since the impala framework is storing a whole
# episode as a single sample, we should wait for a smaller number
impala = IMPALA(actor, critic,
                t.optim.Adam,
                nn.MSELoss(reduction='sum'),
                impala_group,
                servers,
                batch_size=2)
The main part of the training process is almost the same as that of DQNApex:
# synchronize all processes in the group, make sure
# distributed buffer has been created on all processes in apex_group
impala_group.barrier()

# manually control syncing to improve performance
impala.set_sync(False)
if rank in (0, 1):
    # Process 0 and 1 are workers(samplers)
    # begin training
    episode, step, reward_fulfilled = 0, 0, 0
    smoothed_total_reward = 0

    while episode < max_episodes:
        # sleep to wait for learners keep up
        sleep(0.1)
        episode += 1
        total_reward = 0
        terminal = False
        step = 0

        state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

        # manually pull the newest parameters
        impala.manual_sync()
        tmp_observations = []
        while not terminal and step <= max_steps:
            step += 1
            with t.no_grad():
                old_state = state
                # agent model inference
                action, action_log_prob, *_ = \
                    impala.act({"state": old_state})
                state, reward, terminal, _ = env.step(action.item())
                state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
                total_reward += reward

                tmp_observations.append({
                    "state": {"state": old_state},
                    "action": {"action": action},
                    "next_state": {"state": state},
                    "reward": reward,
                    "action_log_prob": action_log_prob.item(),
                    "terminal": terminal or step == max_steps
                })

        impala.store_episode(tmp_observations)

        smoothed_total_reward = (smoothed_total_reward * 0.9 +
                                 total_reward * 0.1)
        logger.info("Process {} Episode {} total reward={:.2f}"
                    .format(rank, episode, smoothed_total_reward))

        if smoothed_total_reward > solved_reward:
            reward_fulfilled += 1
            if reward_fulfilled >= solved_repeat:
                logger.info("Environment solved!")
                # will cause torch RPC to complain
                # since other processes may have not finished yet.
                # just for demonstration.
                exit(0)
        else:
            reward_fulfilled = 0

elif rank in (2, 3):
    # wait for enough samples
    # note: since the impala framework is storing a whole
    # episode as a single sample, we should wait for a smaller number
    while impala.replay_buffer.all_size() < 5:
        sleep(0.1)
    while True:
        impala.update()
IMPALA converges very fast, usually within 150 episodes:
[2020-08-01 23:25:34,861] <INFO>:default_logger:Process 1 Episode 72 total reward=185.32
[2020-08-01 23:25:35,057] <INFO>:default_logger:Process 1 Episode 73 total reward=186.79
[2020-08-01 23:25:35,060] <INFO>:default_logger:Process 0 Episode 70 total reward=193.28
[2020-08-01 23:25:35,257] <INFO>:default_logger:Process 1 Episode 74 total reward=188.11
[2020-08-01 23:25:35,261] <INFO>:default_logger:Process 0 Episode 71 total reward=193.95
[2020-08-01 23:25:35,261] <INFO>:default_logger:Environment solved!
In this tutorial, we are going to implement a recurrent architecture in the DQN and PPO frameworks.
The original “DRQN” architecture was described in Deep Recurrent Q-Learning for Partially Observable MDPs; for the sake of
simplicity, this tutorial discards the CNN part used to process Atari game screens and instead directly accesses the internal
128 bytes of RAM of the tested Atari games.
Now, in order to implement the recurrent architecture, we should have a solid grasp of
the following related aspects in advance:
Recurrent networks were introduced into the reinforcement learning field to deal with POMDPs, in which
agents are not able to observe the full state of the environment and have to rely on internal
memories of their past observations. The paper used Atari games as
the benchmark suite, compared DRQN with DQN in multiple scenarios, and showed that DRQN has a significant
advantage over DQN in the frostbite
game, while performing about as well as, or failing to compete with, DQN in many other Atari games.
For offline reinforcement learning frameworks relying on “replay memory”, like DQN and DDPG, the tricky bit
is that by the time of sampling, the trained models (online network and target network) are already different from the model
used to interact with the environment and produce the samples. The authors of the paper
suggested two ways of updating, both of which require providing a contiguous sequence of samples to the network to compute hidden
states, then back-propagating through time.
For online reinforcement learning frameworks such as A2C and PPO with no replaying mechanism, there is no need
to specifically recalculate hidden states, because by the time of training, the stored samples were generated by an actor network
equal to (when update iteration = 0) or very close to (when update iteration > 0) the trained network. Therefore, hidden states can be
stored along with other observations.
We are going to show the detailed recurrent implementations in the above two reinforcement learning categories, using DQNPer
and PPO respectively.
Compared to the implementation provided in this repo,
our implementation of DRQN is significantly less efficient and potentially produces different results, because:
Duplicate states are stored (history_depth - 1) times.
Only the last step of the bootstrapped random updates is performed; Q values evaluated in previous steps are not used.
You may implement your own framework to overcome these shortcomings using the utilities provided by Machin.
The authors of the original paper chose to train the LSTM layer along with the CNN layers;
in order to deal with the “hidden state” input of the LSTM layer, they proposed two methods:
Bootstrapped sequential updates
Bootstrapped random updates
“Sequential updates” use the recurrent Q network to train through a whole episode, then BPTT (back-propagate through time).
“Random updates” sample a random period of length “unrolled_time_steps” instead of a whole episode; the other details are the same.
In order to achieve this with the DQNPer framework, we will have to store the history observations for each transition, since
the internal replay buffer does not store episodic boundaries between transitions (the History class introduced below handles this).
Then we will also have to define two branches inside the forward function of our recurrent Q network, one branch
for normal action sampling and another branch for training:
def forward(self, mem=None, hidden=None, history_mem=None):
    if mem is not None:
        # use `mem`, `hidden`, in sampling
        ...
    else:
        # use `history_mem`, in updating
        ...
We will show the details in the implementation section of this tutorial.
However, not using BPTT loses most of the benefits of recurrence. If you would like to use this method, you need to implement
your own framework that samples entire episodes, not timesteps, from the replay buffer, then zero-pads the sampled episodes so they
are all the same length, and finally lets your recurrent network go through the sampled episodes and calculate log probs/actions/hidden states.
You may refer to this repo for more information.
We are going to design a History class which allows users to store new states with append() and returns a
fixed-length trajectory with get(); if there are not enough states to form a complete trajectory, zeros
will be used as padding:
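A minimal sketch; the exact return shape of get() is an assumption and should match what your recurrent network expects (here [1, history_depth, 128] for (1, 128) states):

class History:
    def __init__(self, history_depth, state_shape):
        # start with zero padding until enough real states are appended
        self.history = [t.zeros(state_shape) for _ in range(history_depth)]
        self.state_shape = state_shape

    def append(self, state):
        # drop the oldest state, append the newest
        assert state.shape == t.Size(self.state_shape)
        self.history.pop(0)
        self.history.append(state)
        return self

    def get(self):
        # concatenate into a single [1, history_depth, ...] trajectory tensor
        return t.cat(self.history, dim=0).unsqueeze(0)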
In order to provide sampled trajectories to the network, we just need to store “history” instead of “state”:
while not terminal:
    step += 1
    with t.no_grad():
        history.append(state)
        # agent model inference
        action = dqn.act_discrete_with_noise({"mem": history.get()})

        # info is {"ale.lives": self.ale.lives()}, not used here
        state, reward, terminal, _ = env.step(action.item())
        state = convert(state)
        total_reward += reward

        old_history = history.get()
        new_history = history.append(state).get()
        dqn.store_transition({
            "state": {"mem": old_history},
            "action": {"action": action},
            "next_state": {"mem": new_history},
            "reward": reward,
            "terminal": terminal
        })
class RecurrentQNet(nn.Module):
    def __init__(self, action_num):
        super(RecurrentQNet, self).__init__()
        self.gru = nn.GRU(128, 256, batch_first=True)
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, action_num)

    def forward(self, mem=None, hidden=None, history_mem=None):
        if mem is not None:
            # in sampling
            a, h = self.gru(mem.unsqueeze(1), hidden)
            return (self.fc2(t.relu(self.fc1(t.relu(a.flatten(start_dim=1))))),
                    h)
        else:
            # in updating
            batch_size = history_mem.shape[0]
            seq_length = history_mem.shape[1]
            hidden = t.zeros([1, batch_size, 256],
                             device=history_mem.device)
            for i in range(seq_length):
                _, hidden = self.gru(history_mem[:, i].unsqueeze(1), hidden)
            # a[:, -1] = h
            return self.fc2(t.relu(self.fc1(t.relu(
                hidden.transpose(0, 1).flatten(start_dim=1)
            ))))
As you can see, the forward method is divided into two parts. The first part is for normal acting,
where users pass hidden states to the network manually and get actions during sampling:
hidden = t.zeros([1, 1, 256])
state = convert(env.reset())
history = History(history_depth, (1, 128))

while not terminal:
    step += 1
    with t.no_grad():
        old_state = state
        history.append(state)
        # agent model inference
        action, hidden = drqn.act_discrete_with_noise(
            {"mem": old_state, "hidden": hidden}
        )
The second part is used during updating, where the DQNPer framework provides a batch of
trajectories to the network and gets the Q value tensor for the last state in each trajectory
(the else branch of the forward method shown above).
Note: These test results are shown here for pure demonstration purposes; they are not intended for
statistical comparison.
It seems that the DRQN implementation is extremely unstable, and DQN is not very stable either,
especially when history_depth > 1. PPO learns a little better than DQN when history_depth = 1,
but it is able to cross the 300 boundary when history_depth = 4;
RPPO is also able to overcome the 300 boundary after 6000 episodes. Since the learning
rate is fine-tuned, the performance of all frameworks drops considerably after some point.
Currently machin.env has two sub modules: machin.env.utils and
machin.env.wrappers.
The submodule machin.env.utils of the environment module provides
some convenient utility functions you might need in your own application,
such as disabling the rendering window while keeping the rendered result in OpenAI Gym.
The submodule machin.env.wrappers provides process-level parallel environment
wrappers for different environments.
machin.model is a collection of popular network models you might use in your own
program, for example, ResNet.
The model module also contains the basis of all network modules: NeuralNetworkModule.
This wrapper is built upon the regular torch.nn.Module and allows users to specify input/output
sub-modules.
We will use some basic symbols to simplify the description:
... means one or more dimensions, with non-zero sizes.
<> means optional results / arguments. <...> means any number of optional results / arguments.
Note: When an algorithm API returns one result, the result will not be wrapped in a tuple; when it returns multiple results, they will be wrapped in a tuple. This design is made to support:
# your Q network model only returns a Q value tensor
act = dqn.act({"state": some_state})

# your Q network model returns Q value tensor with some additional hidden states
act, h = dqn.act({"state": some_state})
Users invoke the “act*” API provided by the framework during sampling,
to let their models produce an action with respect to their state input.
“*” indicates additional extensions such as “_with_noise”, “_discrete”, etc.,
depending on the implementation and type of the RL framework.
Below is a list of supported acting APIs of different frameworks:
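A few illustrative calls (not the full list, which depends on each framework):

# continuous-action frameworks (e.g. DDPG)
action = ddpg.act({"state": some_state})
action = ddpg.act_with_noise({"state": some_state})

# discrete-action frameworks (e.g. DQN)
action = dqn.act_discrete({"state": some_state})
action = dqn.act_discrete_with_noise({"state": some_state})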
The store_transition API is now deprecated; please use store_episode
only.
Algorithms generally encapsulate a replay buffer inside; the replay buffer is not
necessarily a “real” replay buffer. For online algorithms such as A2C and PPO with
no replaying mechanism, the replay buffer is used as a place to put all of the
samples, and it is cleared after every training/update step:
# sample a batch
batch_size, (state, action, reward, next_state,
             terminal, target_value, advantage) = \
    self.replay_buffer.sample_batch(-1, sample_method="all", ...)
...
self.replay_buffer.clear()
All frameworks use the same store_episode API to store a full episode into
the replay buffer:
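A structural sketch: each element is a transition dictionary like the ones passed to store_transition above, and the episode list is collected exactly as in the A3C/IMPALA examples elsewhere in this document:

episode = [
    {
        "state": {"state": old_state},
        "action": {"action": action},
        "next_state": {"state": state},
        "reward": reward,
        "terminal": terminal,
    }
    # ... one dictionary per transition step of the episode
]
dqn.store_episode(episode)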
All frameworks support the update function, but the keyword arguments
of the update function might be slightly different. For example, DDPG
allows you to choose whether to update the actor/critic/their targets individually, while
DQN only lets you choose whether to update the Q network/its target individually.
Moreover, the update functions of offline algorithms such as DDPG and online
algorithms such as A2C and PPO are different. Because A2C and PPO will not
update on outdated samples, their update function contains an internal
update loop, therefore you should not call it many times:
# DDPG update:
if episode > 100:
    for i in range(step.get()):
        ddpg.update()

# PPO update:
# update() already contains a loop
ppo.store_episode(tmp_observations)
ppo.update()
and their update will also clear the internal replay buffer
every time. So you are recommended to read the implementation of your
selected algorithm before using it.
All frameworks provide this pair of APIs, for saving and loading models passed
to the algorithm. Internally, the models passed to the algorithm framework will
become a member of the framework instance, for example:
dqn = DQN(q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction='sum'))

# you may access q_net and q_net_t with:
print(dqn.qnet)
print(dqn.qnet_target)
You can print the _is_restorable attribute of the algorithm class to view the
models saved/loaded internally, and print the _is_top attribute of the algorithm
class to view the top level models, like the Q network, actor network, critic network, etc.:
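For example (the printed lists below are illustrative guesses for DQN; check the actual attribute values of your algorithm class):

# models that will be saved / loaded by save() and load()
print(DQN._is_restorable)   # e.g. ["qnet_target"]
# top level models of the framework
print(DQN._is_top)          # e.g. ["qnet", "qnet_target"]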
The saving/loading APIs require you to provide a directory to save/load the models,
an optional model name map to specify the mapping “model <-> saved model name”,
and an optional version number indicating the version of the save:
# Model dqn.qnet_target will be saved **as a whole** in "./qnt_1000.pt"
# **saved as whole** means saving like: torch.save(dqn.qnet_target, ...)
dqn.save("./", network_map={"qnet_target": "qnt"}, version=1000)

# If no name mapping is specified, the default "qnet_target" will be used
# as the saving name
dqn.save("./", version=1000)

# If no version is specified, the default saving version number is 0
dqn.save("./", network_map={"qnet_target": "qnt"})

# If no version number is specified, then the model with the largest version
# number will be loaded
dqn.load("./", network_map={"qnet_target": "qnt"})

# Or specify a specific version to load
dqn.load("./", network_map={"qnet_target": "qnt"}, version=1000)

# An invalid version will cause the framework to find the latest available version
dqn.load("./", network_map={"qnet_target": "qnt"}, version=10000)

# If you have a file named "qnt.pt", which has no valid version number, it
# will be ignored.
You may move the saved model files to a different machine with different devices;
there is no need to worry about device mapping, since the parameters of the saved
models will be loaded into the model(s) you passed to the algorithm framework.
Some frameworks need to save multiple models; for example, DDPG needs to
save a target critic network and a target actor network. In this case, each model
will be saved to a separate file, and the loading function will try to find the
maximum available version in the valid version intersection of all models:
# suppose there are these models in the target directory:
# actor_target_0.pt, actor_target_100.pt, actor_target_1000.pt
# critic_target_0.pt, critic_target_100.pt
# then version 100 will be loaded
ddpg.load("./")
Since algorithms are drastically different, it is hard to conform some of their
features to the same style and design; therefore, these features are exposed as-is
if you would like to interface with them, for example to use the critic network or
evaluate an action. Below is a list of these APIs supported by different frameworks:
Machin relies on the correct model implementation to function correctly,
different RL algorithms may need drastically dissimilar models. Therefore,
in this section, we are going to outline the detailed requirements on models
of different frameworks.
We will use some basic symbols to simplify the model signature:
abc_0[*] means a tensor named “abc” with index 0 among all argument tensors of the same meaning; “*” is a wildcard which accepts one or more non-zero dimensions. Valid examples are:
state_0[batch_size, 1]
state_1[1, 2, 3, 4, 5]
state_2[…]
... means one or more arguments (tensors/not tensors), or one or more dimensions, with non-zero sizes.
<> means optional results / arguments. <...> means any number of optional results / arguments.
Note: When an algorithm API returns one result, the result will not be wrapped in a tuple; when it returns multiple results, they will be wrapped in a tuple. This design is made to support:
# your Q network model only returns a Q value tensor
act = dqn.act({"state": some_state})
# your Q network model returns a Q value tensor with some additional hidden states
act, h = dqn.act({"state": some_state})
Note: the forward method signature
must conform to the following definitions exactly,
with no more or less arguments/keyword arguments.
Note: the requirements in this document do not apply when:
(1) you have made a custom implementation, or (2) you have inherited the frameworks
and customized their result adaptors, like DDPG.action_transform_function(),
etc.
Actor(state_0[batch_size, ...], ..., state_n[batch_size, ...])
 -> action[batch_size, ...], <...>           # if continuous
 -> action[batch_size, action_num], <...>    # if discrete
Critic(state_0[batch_size, ...], ..., state_n[batch_size, ...], action[batch_size, ... / action_num])
 -> q_value[batch_size, 1], <...>
where:
action_num is the number of available discrete actions
sum(action[i, :]) == 1 if discrete.
Example:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_range):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_dim)
        self.action_range = action_range

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        a = t.tanh(self.fc3(a)) * self.action_range
        return a


class ActorDiscrete(nn.Module):
    def __init__(self, state_dim, action_dim):
        # action_dim means action_num here
        super(ActorDiscrete, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_dim)

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        a = t.softmax(self.fc3(a), dim=1)
        return a


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, state, action):
        state_action = t.cat([state, action], 1)
        q = t.relu(self.fc1(state_action))
        q = t.relu(self.fc2(q))
        q = self.fc3(q)
        return q
action can be sampled from pytorch distributions using non-differentiable sample().
action_log_prob is the log likelihood of the sampled action, must be differentiable.
distribution_entropy is the entropy value of reparameterized distribution, must be differentiable.
Actor must calculate the log probability of the input action if it is not None, and return the input action as-is.
Example:
class Actor(nn.Module):
    def __init__(self, state_dim, action_num):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state, action=None):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        probs = t.softmax(self.fc3(a), dim=1)
        dist = Categorical(probs=probs)
        act = (action if action is not None else dist.sample())
        act_entropy = dist.entropy()
        act_log_prob = dist.log_prob(act.flatten())
        return act, act_log_prob, act_entropy


class ActorContinuous(nn.Module):
    def __init__(self, state_dim, action_dim, action_range):
        super(ActorContinuous, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.mu_head = nn.Linear(16, action_dim)
        self.sigma_head = nn.Linear(16, action_dim)
        self.action_range = action_range

    def forward(self, state, action=None):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        mu = self.mu_head(a)
        sigma = softplus(self.sigma_head(a))
        dist = Normal(mu, sigma)
        act = (action if action is not None else dist.sample())
        act_entropy = dist.entropy().sum(1, keepdim=True)

        # If your distribution is different from "Normal" then you may either:
        # 1. deduce the remapping function for your distribution and a clamping
        #    function such as tanh
        # 2. clamp your action, but please take care:
        #    1. do not clamp actions before calculating their log probability,
        #       because the log probability of clamped actions might be
        #       extremely small, and will cause nan
        #    2. do not clamp actions after sampling and before storing them in
        #       the replay buffer, because during update, the log probability
        #       will be re-evaluated, might also be extremely small, and the
        #       network will "nan". (might happen in PPO, not in SAC because
        #       there is no re-evaluation)
        # Only clamp actions sent to the environment; this is equivalent to
        # changing the action reward distribution and will not cause "nan",
        # but it makes your training environment further differ from your
        # real environment.

        # the suggested way to confine your actions within a valid range
        # is not clamping, but remapping the distribution
        # from the SAC paper: https://arxiv.org/abs/1801.01290
        act_log_prob = dist.log_prob(act)
        act_tanh = t.tanh(act)
        act = act_tanh * self.action_range

        # the distribution remapping process used in the original paper.
        act_log_prob -= t.log(self.action_range * (1 - act_tanh.pow(2)) + 1e-6)
        act_log_prob = act_log_prob.sum(1, keepdim=True)
        return act, act_log_prob, act_entropy


class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, state):
        v = t.relu(self.fc1(state))
        v = t.relu(self.fc2(v))
        v = self.fc3(v)
        return v
For SAC, Machin expects an actor similar to the actors in stochastic
policy gradient methods such as A2C, and multiple critics similar to critics
used in DDPG:
action can only be sampled from pytorch distributions using the differentiable rsample().
action_log_prob is the log likelihood of the sampled action, must be differentiable.
distribution_entropy is the entropy value of reparameterized distribution, must be differentiable.
Actor must calculate the log probability of the input action if it is not None, and return the input action as-is.
Example:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_range):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.mu_head = nn.Linear(16, action_dim)
        self.sigma_head = nn.Linear(16, action_dim)
        self.action_range = action_range

    def forward(self, state, action=None):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        mu = self.mu_head(a)
        sigma = softplus(self.sigma_head(a))
        dist = Normal(mu, sigma)
        act = (action if action is not None else dist.rsample())
        act_entropy = dist.entropy().sum(1, keepdim=True)

        # the suggested way to confine your actions within a valid range
        # is not clamping, but remapping the distribution
        act_log_prob = dist.log_prob(act)
        act_tanh = t.tanh(act)
        act = act_tanh * self.action_range

        # the distribution remapping process used in the original paper.
        act_log_prob -= t.log(self.action_range * (1 - act_tanh.pow(2)) + 1e-6)
        act_log_prob = act_log_prob.sum(1, keepdim=True)
        return act, act_log_prob, act_entropy


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, state, action):
        state_action = t.cat([state, action], 1)
        q = t.relu(self.fc1(state_action))
        q = t.relu(self.fc2(q))
        q = self.fc3(q)
        return q
machin.env.wrappers provides parallel execution wrappers for various
environments.
class machin.env.wrappers.base.ParallelWrapperBase(*_, **__)
Bases: abc.ABC
Note
Parallel wrappers are designed to wrap the same kind of environments:
they may have different parameters, but must have the same action
and observation space.
Dummy parallel wrapper for gym environments, implemented using a for-loop.
For debugging purposes only.
Parameters
env_creators (List[Callable[[int], gym.core.Env]]) – List of gym environment creators, used to create
environments; accepts an index as your environment id.
Let specified environment(s) run one time step. Specified environments
must be active and have not reached terminal states before.
Parameters
action (Union[numpy.ndarray, List[Any]]) – Actions sent to each specified environment, the size of the
first dimension must match the number of selected environments.
idx (Union[int, List[int]]) – Indexes of selected environments, default is all.
Returns
Observation, reward, terminal, and diagnostic info.
seed (Union[int, List[int]]) – If seed is int, the same seed will be used for all
environments.
If seed is List[int], it must have the same size as
the number of all environments.
If seed is None, all environments will use the default
seed.
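A minimal usage sketch, assuming the dummy wrapper is importable from machin.env.wrappers.openai_gym and wraps CartPole environments; treat the import path and return unpacking as assumptions and check the API reference:

import gym
from machin.env.wrappers.openai_gym import ParallelWrapperDummy  # import path assumed

# four environments sharing the same action and observation space
env = ParallelWrapperDummy([lambda _idx: gym.make("CartPole-v0") for _ in range(4)])
env.seed(0)              # the same seed for every environment
states = env.reset()     # one observation per environment
# one action per active environment; CartPole actions are 0 or 1
observations, rewards, terminals, infos = env.step([0, 1, 0, 1])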
env_creators (List[Callable[[int], gym.core.Env]]) – List of gym environment creators, used to create
environments on subprocess workers; accepts an index as your
environment id.
DDPG supports two ways of updating the target network. The first
is polyak update (soft update), which updates the target network
in every training step by mixing its weights with the online network
using update_rate.
The other is hard update, which copies the weights of the online
network after every update_steps training steps.
You can specify either update_rate or update_steps to select
one update scheme; if both are specified, an error will be raised.
These two update schemes may result in different training
stability.
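Conceptually, the two schemes behave like the following sketch (an illustration only, not Machin's internal implementation):

def soft_update(target_net, online_net, update_rate):
    # polyak update: target = update_rate * online + (1 - update_rate) * target
    for target_param, online_param in zip(target_net.parameters(), online_net.parameters()):
        target_param.data.copy_(
            update_rate * online_param.data + (1.0 - update_rate) * target_param.data
        )

def hard_update(target_net, online_net):
    # hard update: copy the online weights verbatim every update_steps training steps
    target_net.load_state_dict(online_net.state_dict())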
Parameters
actor (Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]) – Actor network module.
actor_target (Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]) – Target actor network module.
Use the actor network to produce a discrete action for the current state.
Notes
The actor network must output a probability tensor of shape
(batch_size, action_dims), with each row summing to 1 in
dimension 1.
Parameters
state (Dict[str, Any]) – Current state.
use_target (bool) – Whether to use the target network.
Returns
Action of shape [batch_size,1].
Action probability tensor of shape [batch_size,action_num],
produced by your actor.
Any other things returned by your Q network, if they exist.
Use the actor network to produce a noisy discrete action for
the current state.
Notes
The actor network must output a probability tensor of shape
(batch_size, action_dims), with each row summing to 1 in
dimension 1.
Parameters
state (Dict[str, Any]) – Current state.
use_target (bool) – Whether to use the target network.
choose_max_prob (float) – Probability of choosing the largest component when the actor
outputs an extreme probability vector like [0,1,0,0].
Returns
Noisy action of shape [batch_size,1].
Action probability tensor of shape [batch_size,action_num].
Any other things returned by your Q network, if they exist.
Your criterion must return a tensor of shape [batch_size,1]
when given two tensors of shape [batch_size,1], since we
need to multiply the loss with importance sampling weight
element-wise.
If you are using loss modules provided by pytorch, it is always
safe to use them without any modification.
DQN supports two ways of updating the target network. The first
is polyak update (soft update), which updates the target network
in every training step by mixing its weights with the online network
using update_rate.
The other is hard update, which copies the weights of the online
network after every update_steps training steps.
You can specify either update_rate or update_steps to select
one update scheme; if both are specified, an error will be raised.
These two update schemes may result in different training
stability.
optimizer (Callable) – Optimizer used to optimize qnet.
criterion (Callable) – Criterion used to evaluate the value loss.
learning_rate (float) – Learning rate of the optimizer, not compatible with
lr_scheduler.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during training.
epsilon_decay (float) – Epsilon decay rate per noisy acting step.
The epsilon attribute is multiplied by this value every time
act_discrete_with_noise is called.
update_rate (Optional[float]) –
\(\tau\) used to update target networks.
Target parameters are updated as:
\(\theta_t \leftarrow \theta \tau + \theta_t (1 - \tau)\)
DQN with prioritized replay. It is based on Double DQN.
Warning
Your criterion must return a tensor of shape [batch_size,1]
when given two tensors of shape [batch_size,1], since we
need to multiply the loss with importance sampling weight
element-wise.
If you are using loss modules provided by pytorch, it is always
safe to use them without any modification.
Note
DQN is only available for discrete environments.
Note
Dueling DQN is a network structure rather than a framework, so
it could be applied to all three modes.
If mode="vanilla", implements the simplest online DQN,
with replay buffer.
If mode="fixed_target", implements DQN with a target network,
and replay buffer. Described in this essay.
If mode="double", implements Double DQN described in
this essay.
Note
Vanilla DQN only needs one network, so internally, qnet
is assigned to qnet_target.
Note
In order to implement dueling DQN, you should create two dense
output layers.
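A hedged sketch of such a Q network (the class name and layer sizes are illustrative, not part of Machin's API):

import torch as t
from torch import nn

class DuelingQNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super(DuelingQNet, self).__init__()
        self.fc = nn.Linear(state_dim, 16)
        self.advantage = nn.Linear(16, action_num)  # dense output layer for advantages
        self.value = nn.Linear(16, 1)               # dense output layer for the state value

    def forward(self, state):
        x = t.relu(self.fc(state))
        adv = self.advantage(x)
        val = self.value(x)
        # combine both heads into Q values of shape [batch_size, action_num]
        return val + adv - adv.mean(dim=1, keepdim=True)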
In the RAINBOW framework, the output shape of your q network
must be [batch_size,action_num,atom_num] when given a
state of shape [batch_size,action_dim]. And the last
dimension must be soft-maxed. Atom number is the number of
segments of your q value domain.
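A hedged sketch of a Q network with this output shape (atom_num and layer sizes are illustrative):

import torch as t
from torch import nn

class DistributionalQNet(nn.Module):
    def __init__(self, state_dim, action_num, atom_num=10):
        super(DistributionalQNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, action_num * atom_num)
        self.action_num = action_num
        self.atom_num = atom_num

    def forward(self, state):
        x = t.relu(self.fc1(state))
        x = self.fc2(x).view(-1, self.action_num, self.atom_num)
        # the last dimension must be soft-maxed into a distribution over atoms
        return t.softmax(x, dim=-1)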
optimizer – Optimizer used to optimize actor and critic.
value_min – Minimum of value domain.
value_max – Maximum of value domain.
learning_rate (float) – Learning rate of the optimizer, not compatible with
lr_scheduler.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during training.
epsilon_decay (float) – Epsilon decay rate per noisy acting step.
The epsilon attribute is multiplied by this value every time
act_discrete_with_noise is called.
update_rate (float) –
\(\tau\) used to update target networks.
Target parameters are updated as:
\(\theta_t \leftarrow \theta \tau + \theta_t (1 - \tau)\)
When given a state, and an optional action, actor must
at least return two values:
1. Action
For continuous environments, action must be of shape
[batch_size,action_dim] and clamped by the action space.
For discrete environments, action could be of shape
[batch_size,action_dim] if it is a one hot vector, or
[batch_size,1] or [batch_size] if it is a categorically
encoded integer.
When the given action is not None, actor must return the given
action.
2. Log likelihood of action (action probability)
For either type of environment, log likelihood is of shape
[batch_size,1] or [batch_size].
Action probability must be differentiable; the gradient of the actor
is calculated from the gradient of the action probability.
When the given action is not None, actor must return the log
likelihood of the given action.
The third entropy value is optional:
3. Entropy of action distribution
Entropy is usually calculated using dist.entropy(), its shape
is [batch_size,1] or [batch_size]. You must specify
entropy_weight to make it effective.
Hint
For contiguous environments, action’s are not directly output by
your actor, otherwise it would be rather inconvenient to calculate
the log probability of action. Instead, your actor network should
output parameters for a certain distribution
(eg: Normal)
and then draw action from it.
For discrete environments,
Categorical is sufficient,
since differentiable rsample() is not needed.
This trick is also known as reparameterization.
Hint
Actions are drawn from samples during training in the actor-critic
family (A2C, A3C, PPO, TRPO, IMPALA).
When your actor model is given a batch of actions and states, it
must evaluate the states and return the log likelihood of the
given actions instead of re-sampling actions.
An example of your actor in continuous environments:
class ActorNet(nn.Module):
    def __init__(self):
        super(ActorNet, self).__init__()
        self.fc = nn.Linear(3, 100)
        self.mu_head = nn.Linear(100, 1)
        self.sigma_head = nn.Linear(100, 1)

    def forward(self, state, action=None):
        x = t.relu(self.fc(state))
        mu = 2.0 * t.tanh(self.mu_head(x))
        sigma = F.softplus(self.sigma_head(x))
        dist = Normal(mu, sigma)
        action = (action if action is not None else dist.sample())
        action_entropy = dist.entropy()
        action = action.clamp(-2.0, 2.0)

        # Since we are representing a multivariate gaussian
        # distribution in terms of independent univariate gaussians:
        action_log_prob = dist.log_prob(action).sum(dim=1, keepdim=True)
        return action, action_log_prob, action_entropy
Hint
Entropy weight is usually negative, to increase exploration.
Value weight is usually 0.5, so that the critic network converges
faster than the actor network and learns more conditions.
optimizer (Callable) – Optimizer used to optimize actor and critic.
criterion (Callable) – Criterion used to evaluate the value loss.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple, Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict, Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during training.
actor_update_times (int) – Times to update actor in update().
critic_update_times (int) – Times to update critic in update().
actor_learning_rate (float) – Learning rate of the actor optimizer,
not compatible with lr_scheduler.
critic_learning_rate (float) – Learning rate of the critic optimizer,
not compatible with lr_scheduler.
entropy_weight (float) – Weight of entropy in your loss function, a positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight (float) – Weight of critic value loss.
gradient_max (float) – Maximum gradient.
gae_lambda (float) – \(\lambda\) used in generalized advantage
estimation.
discount (float) – \(\gamma\) used in the bellman function.
normalize_advantage (bool) – Whether to normalize sampled advantage values in
the batch.
replay_size (int) – Replay buffer size. Not compatible with
replay_buffer.
replay_device (Union[str, torch.device]) – Device the replay buffer is located on; not
compatible with replay_buffer.
The A3C algorithm relies on parameter servers to synchronize
the parameters of the actor and critic models across samplers
(which interact with the environment) and trainers (which use
samples to train).
The parameter server type PushPullGradServer
used here utilizes gradients calculated by trainers:
1. perform a “sum” reduction on the collected
gradients, then apply this reduced gradient to the model
managed by its primary reducer;
2. push the parameters of this updated managed model to
an ordered key-value server so that all processes,
including samplers and trainers, can access the updated
parameters.
criterion (Callable) – Criterion used to evaluate the value loss.
grad_server (Tuple[machin.parallel.server.param_server.PushPullGradServer, machin.parallel.server.param_server.PushPullGradServer]) – Custom gradient sync server accessors, the first
server accessor is for actor, and the second one is for critic.
batch_size (int) – Batch size used during training.
actor_update_times (int) – Times to update actor in update().
critic_update_times (int) – Times to update critic in update().
entropy_weight (float) – Weight of entropy in your loss function, a positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight (float) – Weight of critic value loss.
gradient_max (float) – Maximum gradient.
gae_lambda (float) – \(\lambda\) used in generalized advantage
estimation.
discount (float) – \(\gamma\) used in the bellman function.
normalize_advantage (bool) – Whether to normalize sampled advantage values in
the batch.
replay_size (int) – Replay buffer size. Not compatible with
replay_buffer.
replay_device (Union[str, torch.device]) – Device the replay buffer is located on; not
compatible with replay_buffer.
optimizer (Callable) – Optimizer used to optimize actor and critic.
criterion (Callable) – Criterion used to evaluate the value loss.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple, Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict, Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during training.
actor_update_times (int) – Times to update actor in update().
critic_update_times (int) – Times to update critic in update().
actor_learning_rate (float) – Learning rate of the actor optimizer,
not compatible with lr_scheduler.
critic_learning_rate (float) – Learning rate of the critic optimizer,
not compatible with lr_scheduler.
entropy_weight (float) – Weight of entropy in your loss function, a positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight (float) – Weight of critic value loss.
surrogate_loss_clip (float) – Surrogate loss clipping parameter in PPO.
gradient_max (float) – Maximum gradient.
gae_lambda (float) – \(\lambda\) used in generalized advantage
estimation.
discount (float) – \(\gamma\) used in the bellman function.
normalize_advantage (bool) – Whether to normalize sampled advantage values in
the batch.
replay_size (int) – Replay buffer size. Not compatible with
replay_buffer.
replay_device (Union[str, torch.device]) – Device the replay buffer is located on; not
compatible with replay_buffer.
When given a state, and an optional action, actor must
at least return two values, similar to the actor structure
described in A2C. However, when actor is asked to
select an action based on the current state, you must make
sure that the sampling process is differentiable. E.g.
use the rsample method of torch distributions instead
of the sample method.
Compared to other actor-critic methods, SAC embeds the
entropy term into its reward function directly, rather than adding
the entropy term to actor’s loss function. Therefore, we do not use
the entropy output of your actor network.
The SAC algorithm uses Q network as critics, so please reference
DDPG for the requirements and the definition of
action_trans_func.
Parameters
actor (Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]) – Actor network module.
constrained_policy_optimization (Union[machin.frame.algorithms.ppo.PPO, machin.frame.algorithms.trpo.TRPO]) – A constrained policy optimization
framework, currently can be a PPO or TRPO framework.
optimizer (Callable) – Optimizer used to optimize discriminator.
discriminator_learning_rate (float) – Learning rate of the discriminator optimizer,
not compatible with lr_scheduler.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during discriminator training.
gradient_max (float) – Maximum gradient.
expert_replay_size (int) – Expert trajectory buffer size. Not compatible with
expert_replay_buffer.
expert_replay_device (Union[str, torch.device]) – Device the expert replay buffer is located on; not
compatible with expert_replay_buffer.
The Apex framework supports multiple workers (samplers) and only
one trainer; you may use DistributedDataParallel in the trainer.
If you use DistributedDataParallel, you must call update()
in all member processes of DistributedDataParallel.
Parameters
actor (Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]) – Actor network module.
actor_target (Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]) – Target actor network module.
Use the actor network to produce a discrete action for the current state.
Notes
The actor network must output a probability tensor of shape
(batch_size, action_dims), with each row summing to 1 in
dimension 1.
Parameters
state (Dict[str, Any]) – Current state.
use_target (bool) – Whether to use the target network.
Returns
Action of shape [batch_size,1].
Action probability tensor of shape [batch_size,action_num],
produced by your actor.
Any other things returned by your Q network, if they exist.
Use the actor network to produce a noisy discrete action for
the current state.
Notes
The actor network must output a probability tensor of shape
(batch_size, action_dims), with each row summing to 1 in
dimension 1.
Parameters
state (Dict[str, Any]) – Current state.
use_target (bool) – Whether to use the target network.
choose_max_prob – Probability of choosing the largest component when the actor
outputs an extreme probability vector like [0,1,0,0].
Returns
Noisy action of shape [batch_size,1].
Action probability tensor of shape [batch_size,action_num].
Any other things returned by your Q network, if they exist.
The Apex framework supports multiple workers (samplers) and only
one trainer; you may use DistributedDataParallel in the trainer.
If you use DistributedDataParallel, you must call update()
in all member processes of DistributedDataParallel.
optimizer (Callable) – Optimizer used to optimize qnet.
criterion (Callable) – Criterion used to evaluate the value loss.
apex_group (machin.parallel.distributed._world.RpcGroup) – Group of all processes using the apex-DQN framework,
including all samplers and trainers.
model_server (Tuple[machin.parallel.server.param_server.PushPullModelServer]) – Custom model sync server accessor for qnet.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during training.
epsilon_decay (float) – Epsilon decay rate per noisy acting step.
The epsilon attribute is multiplied by this value every time
act_discrete_with_noise is called.
update_rate (float) –
\(\tau\) used to update target networks.
Target parameters are updated as:
\(\theta_t \leftarrow \theta \tau + \theta_t (1 - \tau)\)
optimizer (Callable) – Optimizer used to optimize actor and critic.
criterion (Callable) – Criterion used to evaluate the value loss.
impala_group (machin.parallel.distributed._world.RpcGroup) – Group of all processes using the IMPALA framework,
including all samplers and trainers.
model_server (Tuple[machin.parallel.server.param_server.PushPullModelServer]) – Custom model sync server accessor for actor.
lr_scheduler (Callable) – Learning rate scheduler of optimizer.
lr_scheduler_args (Tuple[Tuple, Tuple]) – Arguments of the learning rate scheduler.
lr_scheduler_kwargs (Tuple[Dict, Dict]) – Keyword arguments of the learning
rate scheduler.
batch_size (int) – Batch size used during training.
learning_rate (float) – Learning rate of the optimizer, not compatible with
lr_scheduler.
isw_clip_c (float) – \(c\) used in importance weight clipping.
isw_clip_rho (float) –
entropy_weight (float) – Weight of entropy in your loss function, a positive
entropy weight will minimize entropy, while a negative one will
maximize entropy.
value_weight (float) – Weight of critic value loss.
gradient_max (float) – Maximum gradient.
discount (float) – \(\gamma\) used in the bellman function.
replay_size (int) – Size of the local replay buffer.
Samples full episodes for batch_size instead of steps.
Create a distributed replay buffer instance.
To avoid issues caused by tensor device difference, all transition
objects are stored in device “cpu”.
The distributed replay buffer consists of many local buffers held per
process; transmissions between processes only happen during sampling.
During sampling, the tensors in the “state”, “action” and “next_state”
dictionaries, along with “reward”, will be concatenated in dimension 0.
Any other custom keys specified in **kwargs will not be
concatenated.
DistributedBuffer does not support customizing the storage device when using
the default storage, since it is safer to pass cpu tensors between RPC callers
and callees.
Note
Since append() operates on the local buffer, in order to
append to the distributed buffer correctly, please make sure
that your actor is also the local buffer holder, i.e. a member
of the group
Parameters
buffer_name (str) – A unique name of your buffer for registration in the group.
group (machin.parallel.distributed._world.RpcGroup) – Process group which holds this buffer.
buffer_size (int) – Maximum local buffer size.
storage (machin.frame.buffers.storage.TransitionStorageBase) – Custom storage, not compatible with buffer_size and
buffer_device.
“Concatenation” means torch.cat([list of tensors], dim=0) for tensors,
and torch.tensor([list of scalars]).view(batch_size, 1) for scalars.
By default, only major and sub attributes will be concatenated, in order to
concatenate custom attributes, specify their names in
additional_concat_custom_attrs.
Warning
Custom attributes must not contain tensors. And only scalar custom
attributes can be concatenated, such as int, float,
bool.
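As an illustration of the two concatenation rules described above:

import torch as t

tensor_items = [t.zeros([1, 3]), t.ones([1, 3])]
concatenated = t.cat(tensor_items, dim=0)                      # shape [2, 3]

scalar_items = [1.0, 0.0]
as_tensor = t.tensor(scalar_items).view(len(scalar_items), 1)  # shape [2, 1]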
Parameters
batch_size (int) – A hint size of the result sample. actual sample size
depends on your sample method.
sample_method – Sample method, could be one of:
"random","random_unique","all",
or a function:
func(buffer,batch_size)->(list,result_size)
concatenate (bool) – Whether to perform concatenation on major, sub and custom
attributes.
If True, for each value in dictionaries of major
attributes, and each value of sub attributes, returns
a concatenated tensor. Custom Attributes specified in
additional_concat_custom_attrs will also be concatenated.
If False, performs no concatenation.
device (Union[str, torch.device]) – Device to move tensors in the batch to.
sample_attrs (List[str]) – If sample_attrs is specified, then only the specified keys
of the transition object will be sampled. You may use
"*" as a wildcard to collect remaining
custom keys as a dict, you cannot collect major
and sub attributes using this.
Invalid sample attributes will be ignored.
additional_concat_custom_attrs (List[str]) – additional custom keys needed to be
concatenated, will only work if concatenate is
True.
Returns
Batch size, and sampled attribute values in the same order as
sample_attrs.
Sampled attribute values are a tuple, or None if the sampled
batch size is zero (e.g. if the buffer is empty, or your sample
size is 0 and you are not sampling using the “all” method).
For major attributes, results are dictionaries of tensors with
the same keys as in your transition objects.
For sub attributes, results are tensors.
For custom attributes, if they are not in
additional_concat_custom_attrs, then lists, otherwise tensors.
For wildcard selector, result is a dictionary containing unused custom
attributes, if they are not in additional_concat_custom_attrs,
the values are lists, otherwise values are tensors.
MADDPG is a centralized multi-agent training framework. It alleviates the
unstable reward problem caused by the disturbance of other agents by
gathering all agents' observations and training a global critic. This global
critic observes all actions and all states from all agents.
In order to parallelize agent inference, a process pool is used
internally. However, in order to minimize memory copy / CUDA memory
copy, the location of all of your models must be either “cpu”, or
“cuda” (Using multiple CUDA devices is supported).
Note
The MADDPG framework does not require all of your actors to be
homogeneous. Each pair of your actors and critics could be
heterogeneous.
Note
Suppose you have three pairs of actors and critics, with indices 0, 1,
and 2. If critic 0 can observe the actions of actors 0 and 1, critic 1 can
observe the actions of actors 1 and 2, and critic 2 can observe the actions
of actors 2 and 0, then critic_visible_actors should be:
[[0,1],[1,2],[2,0]]
Note
Learning rate scheduler args and kwargs for each actor and critic,
the first list is for actors, and the second list is for critics.
Note
This implementation contains:
Ensemble Training
This implementation does not contain:
Inferring other agents’ policies
Mixed continuous/discrete action spaces
Parameters
actors (List[Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]]) – Actor network modules.
actor_targets (List[Union[machin.model.nets.base.NeuralNetworkModule, torch.nn.modules.module.Module]]) – Target actor network modules.
visualize (bool) – Whether to visualize the network flow in the first pass.
visualize_dir (str) – Visualized graph save directory.
use_jit (bool) – Whether to use torch jit to perform the forward pass
in parallel instead of using the internal pool. Provides a
significant speed and efficiency advantage, but requires
actors and critics to be convertible to TorchScript.
pool_type (str) – Type of the internal execution pool, either “process”
or “thread”.
pool_size (int) – Size of the internal execution pool.
Post-process concatenated attribute items. Values are the processed results from
the method Buffer.make_tensor_from_batch(), either a list of
non-concatenated values, or a concatenated tensor.
Parameters
attribute (Any) – Attribute key, such as “state”, “next_state”, etc.
sub_key (Any) – Sub key in attribute if attribute is a major attribute,
set to None if attribute is a sub attribute or a custom attribute.
Pre-process attribute items; the method Buffer.make_tensor_from_batch()
will use the result from this function and assumes the processed attribute items
to be one of:
A list of tensors that are concatenable in dimension 0.
A list of values that are transformable to a tensor.
In case you want to implement custom padding for each item of an
attribute, or other custom preprocess, please override this method.
“Concatenation” means torch.cat([list of tensors], dim=0) for tensors,
and torch.tensor([list of scalars]).view(batch_size, 1) for scalars.
By default, only major and sub attributes will be concatenated, in order to
concatenate custom attributes, specify their names in
additional_concat_custom_attrs.
Warning
Custom attributes must not contain tensors. And only scalar custom
attributes can be concatenated, such as int, float,
bool.
Parameters
batch_size (int) – A hint size of the result sample. actual sample size
depends on your sample method.
sample_method (Union[Callable[[Buffer, int], Tuple[List[Any], int]], str]) – Sample method, could be one of:
"random","random_unique","all",
or a function:
func(buffer,batch_size)->(list,result_size)
concatenate (bool) – Whether to perform concatenation on major, sub and custom
attributes.
If True, for each value in dictionaries of major
attributes, and each value of sub attributes, returns
a concatenated tensor. Custom Attributes specified in
additional_concat_custom_attrs will also be concatenated.
If False, performs no concatenation.
device (Union[str, torch.device]) – Device to move tensors in the batch to.
sample_attrs (List[str]) – If sample_attrs is specified, then only the specified keys
of the transition object will be sampled. You may use
"*" as a wildcard to collect remaining
custom keys as a dict, you cannot collect major
and sub attributes using this.
Invalid sample attributes will be ignored.
additional_concat_custom_attrs (List[str]) – additional custom keys needed to be
concatenated, will only work if concatenate is
True.
Returns
Batch size, and sampled attribute values in the same order as
sample_attrs.
Sampled attribute values are a tuple, or None if the sampled
batch size is zero (e.g. if the buffer is empty, or your sample
size is 0 and you are not sampling using the “all” method).
For major attributes, results are dictionaries of tensors with
the same keys as in your transition objects.
For sub attributes, results are tensors.
For custom attributes, if they are not in
additional_concat_custom_attrs, then lists, otherwise tensors.
For wildcard selector, result is a dictionary containing unused custom
attributes, if they are not in additional_concat_custom_attrs,
the values are lists, otherwise values are tensors.
If you pass in a dict type transition object, it will be automatically
converted to Transition, which requires the attributes “state”, “action”,
“next_state”, “reward” and “terminal” to be present in the dict keys.
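For example, an appended dict transition could look like the following sketch; the inner dictionary keys must match your model signatures, and the custom key is purely hypothetical:

buffer.append({
    "state": {"state": old_state},
    "action": {"action": action},
    "next_state": {"state": new_state},
    "reward": 1.0,
    "terminal": False,
    "my_custom_key": 42,  # hypothetical custom attribute, kept as-is unless concatenated
})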
PrioritizedBuffer does not support customizing storage as it
requires a linear storage.
Parameters
buffer_size (int) – Maximum buffer size.
buffer_device (Union[str, torch.device]) – Device where buffer is stored.
epsilon (float) – A small positive constant used to prevent edge-case
zero weight transitions from never being visited.
alpha (float) – Prioritization weight. Used during transition sampling:
\(j \sim P(j)=p_{j}^{\alpha} / \sum_i p_{i}^{\alpha}\).
When alpha=0, all samples have the same probability
to be sampled.
When alpha=1, all samples are drawn according
to their weight.
beta (float) – Bias correcting weight. When beta=1, bias introduced
by prioritized replay will be corrected. Used during
importance weight calculation:
\(w_j=(N \cdot P(j))^{-\beta}/max_i w_i\)
beta_increment_per_sampling (float) – Beta increase step size, will gradually increase beta to 1.
batch_size (int) – A hint size of the result sample.
concatenate (bool) – Whether to perform concatenation on major, sub and custom
attributes.
If True, for each value in dictionaries of major
attributes, and each value of sub attributes, returns
a concatenated tensor. Custom Attributes specified in
additional_concat_custom_attrs will also be concatenated.
If False, performs no concatenation.
device (Union[str, torch.device]) – Device to move tensors in the batch to.
sample_attrs (List[str]) – If sample_attrs is specified, then only the specified keys
of the transition object will be sampled. You may use
"*" as a wildcard to collect remaining
custom keys as a dict, you cannot collect major
and sub attributes using this.
Invalid sample attributes will be ignored.
additional_concat_custom_attrs (List[str]) – additional custom keys needed to be
concatenated, will only work if concatenate is
True.
Returns
Batch size.
Sampled attribute values in the same order as sample_attrs.
Sampled attribute values are a tuple, or None if the sampled
batch size is zero (e.g. if the buffer is empty or your sample
size is 0).
class machin.frame.buffers.prioritized_buffer.WeightTree(size)
Bases: object
Sum weight tree data structure.
Initialize a weight tree.
Note
Weights must be positive.
Note
Weight tree is stored as a flattened, full binary tree in a
np.ndarray. The lowest level of leaves comes first, the
highest root node is stored at last.
Example:
Tree with weights: [[1,2,3,4],[3,7],[11]]
will be stored as: [1,2,3,4,3,7,11]
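The flattened layout can be reproduced with plain Python (an illustration of the storage order only, not the WeightTree API):

leaves = [1, 2, 3, 4]
level_2 = [leaves[0] + leaves[1], leaves[2] + leaves[3]]  # [3, 7]
root = [level_2[0] + level_2[1]]                          # [11]
flat = leaves + level_2 + root                            # [1, 2, 3, 4, 3, 7, 11]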
Note
Performance On i7-6700HQ (M: Million):
90ms for building a tree with 10M elements.
230ms for looking up 10M elements in a tree with 10M elements.
20ms for 1M element batched update in a tree with 10M elements.
240ms for 1M element single update in a tree with 10M elements.
Create a distributed prioritized replay buffer instance.
To avoid issues caused by tensor device difference, all transition
objects are stored in device “cpu”.
The distributed prioritized replay buffer consists of many local buffers
held per process. Since it is very inefficient to maintain a weight
tree across processes, each process holds a weight tree of the records in
its local buffer, plus a local buffer (same as DistributedBuffer).
The sampling process(es) will first use rpc to acquire the wr_lock,
signalling “stop” to appends performed by actor processes,
then perform a sum over all local weight trees, and finally perform
sampling; after sampling and updating the importance weights,
the lock is released.
During sampling, the tensors in “state”, “action” and “next_state”
dictionaries, along with “reward”, will be concatenated in dimension 0.
any other custom keys specified in **kwargs will not be
concatenated.
See also
PrioritizedBuffer
Note
DistributedPrioritizedBuffer does not support customizing storage as it
requires a linear storage.
Note
DistributedPrioritizedBuffer is not split into an
accessor and an implementation, because we would like to operate
on the buffer directly, when calling “size()” or “append()”, to
increase efficiency (since rpc layer is bypassed).
Parameters
buffer_name (str) – A unique name of your buffer for registration in the group.
group (machin.parallel.distributed._world.RpcGroup) – Process group which holds this buffer.
buffer_size (int) – Maximum local buffer size.
epsilon (float) – A small positive constant used to prevent edge-case
zero weight transitions from never being visited.
alpha (float) – Prioritization weight. Used during transition sampling:
\(j \sim P(j)=p_{j}^{\alpha} / \sum_i p_{i}^{\alpha}\).
When alpha=0, all samples have the same probability
to be sampled.
When alpha=1, all samples are drawn according
to their weight.
beta (float) – Bias correcting weight. When beta=1, bias introduced
by prioritized replay will be corrected. Used during
importance weight calculation:
\(w_j=(N \cdot P(j))^{-\beta}/max_i w_i\)
beta_increment_per_sampling (float) – Beta increase step size, will gradually increase beta to 1.
batch_size (int) – A hint size of the result sample.
concatenate (bool) – Whether to perform concatenation on major, sub and custom
attributes.
If True, for each value in dictionaries of major
attributes, and each value of sub attributes, returns
a concatenated tensor. Custom Attributes specified in
additional_concat_custom_attrs will also be concatenated.
If False, performs no concatenation.
device (Union[str, torch.device]) – Device to move tensors in the batch to.
sample_attrs (List[str]) – If sample_attrs is specified, then only the specified keys
of the transition object will be sampled. You may use
"*" as a wildcard to collect remaining
custom keys as a dict, you cannot collect major
and sub attributes using this.
Invalid sample attributes will be ignored.
additional_concat_custom_attrs (List[str]) – additional custom keys needed to be
concatenated, will only work if concatenate is
True.
Returns
Batch size.
Sampled attribute values in the same order as sample_attrs.
Sampled attribute values are a tuple, or None if the sampled
batch size is zero (e.g. if the buffer is empty or your sample
size is 0).
The innermost tuple contains:
(normal_mean,normal_sigma,clip_min,clip_max)
If noise_param is Tuple[float,float,float,float],
then the same clipped normal noise will be added to action[*,:].
If noise_param is Iterable[Tuple[float,float,float,float]],
then for each action[*,i] slice i, clipped normal noise with
noise_param[i] will be applied respectively.
Parameters
action (torch.Tensor) – Raw action
noise_param (Union[Iterable[Tuple], Tuple]) – Param of the normal noise.
ratio – Sampled noise is multiplied with this ratio.
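A minimal sketch of what adding clipped normal noise amounts to; the helper name and default values below are illustrative only and do not reflect the library's internal implementation.
import torch as t

def add_clipped_normal_noise(action, noise_param=(0.0, 0.1, -0.2, 0.2), ratio=1.0):
    # noise_param = (normal_mean, normal_sigma, clip_min, clip_max)
    mean, sigma, clip_min, clip_max = noise_param
    noise = (t.randn_like(action) * sigma + mean).clamp(clip_min, clip_max)
    return action + noise * ratio

action = t.zeros([1, 3])
print(add_clipped_normal_noise(action))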
The innermost tuple contains:
(normal_mean,normal_sigma)
If noise_param is Tuple[float,float],
then the same normal noise will be added to action[*,:].
If noise_param is Iterable[Tuple[float,float]],
then for each action[*,i] slice i, normal noise with
noise_param[i] will be applied respectively.
Parameters
action (torch.Tensor) – Raw action
noise_param – Param of the normal noise.
ratio – Sampled noise is multiplied with this ratio.
The Ornstein-Uhlenbeck noise generator is shared, and you cannot
specify OU noise of different distributions
for each slice of the last dimension of your action.
Parameters
action (torch.Tensor) – Raw action
noise_param (Dict[str, Any]) – OrnsteinUhlenbeckGen params. Used as
keyword arguments of the generator. Will only be effective if
reset is True.
ratio – Sampled noise is multiplied with this ratio.
reset – Whether to reset the default Ornstein-Uhlenbeck noise generator.
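For reference, the Ornstein-Uhlenbeck process can be sketched as below; this is an illustration of the OU update rule with arbitrary coefficients, not the library's own generator.
import torch as t

class SimpleOUGen:
    # dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)
    def __init__(self, shape, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2):
        self.shape, self.mu = shape, mu
        self.theta, self.sigma, self.dt = theta, sigma, dt
        self.x = t.zeros(shape)

    def __call__(self, device=None):
        dx = self.theta * (self.mu - self.x) * self.dt \
             + self.sigma * (self.dt ** 0.5) * t.randn(self.shape)
        self.x = self.x + dx
        return self.x if device is None else self.x.to(device)

gen = SimpleOUGen([1, 3])
noisy_action = t.zeros([1, 3]) + gen() * 1.0  # ratio = 1.0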
Only parameters of type t.Tensor and gettable from
model.named_parameters() will be perturbed.
Original parameters will be automatically swapped in during the
backward pass, and you can safely call optimizers afterwards.
Hint
1. noise_generator must accept (shape, *args) in its __init__
function, where shape is the required shape. It also needs to have
__call__(device=None), which produces a noise tensor on the specified
device when invoked.
2. noise_generate_function must accept (shape, device, std:float)
and return a noise tensor on the specified device.
Example
In order to use this function to perturb your model, you need to:
from machin.utils.helper_classes import Switch
from machin.frame.noise.param_space_noise import perturb_model
from machin.utils.visualize import visualize_graph
import torch as t

dims = 5

t.manual_seed(0)
model = t.nn.Linear(dims, dims)
optim = t.optim.Adam(model.parameters(), 1e-3)
p_switch, r_switch = Switch(), Switch()
cancel = perturb_model(model, p_switch, r_switch)

# you should keep this switch on if you do one training step after
# every sampling step. otherwise you may turn it off in one episode
# and turn it on in the next to speed up training.
r_switch.on()

# turn off/on the perturbation switch to see the difference
p_switch.on()

# do some sampling
action = model(t.ones([dims]))

# in order to let parameter noise adapt to generate noisy actions
# within ``desired_action_stddev``, you must periodically
# use the original model to generate some actions:
p_switch.off()
action = model(t.ones([dims]))

# visualize will not show any leaf noise tensors
# because they are created in t.no_grad() context
# and added in-place.
visualize_graph(action, exit_after_vis=False)

# do some training
loss = (action - t.ones([dims])).sum()
loss.backward()
optim.step()
print(model.weight)

# clear hooks
cancel()
Parameters
model (torch.nn.modules.module.Module) – Neural network model.
perturb_switch (machin.utils.helper_classes.Switch) – The switch used to enable perturbation. If switch is
set to False (off), then during the forward process, original
parameters are used.
reset_switch (machin.utils.helper_classes.Switch) – The switch used to reset perturbation noise. If switch is
set to True (on), and perturb_switch is also on, then during
every forward process, a new set of noise is applied to each param.
If only perturb_switch is on, then the same set of noisy
parameters is used in the forward process and they will not be
updated.
distance_func (Callable) – Distance function; accepts two tensors produced by
the model (one is noisy) and returns the distance as a float. Used
to compare the distance between actions generated by
noisy parameters and original parameters.
desired_action_stddev (float) – Desired action standard deviation.
noise_generator (Any) – Noise generator class.
noise_generator_args (Tuple) – Additional args other than shape of the noise
generator.
noise_generator_kwargs (Dict) – Additional kwargs other than shape of the noise
generator.
noise_generate_function (Callable) – Noise generation function, mutually exclusive
with noise_generator and noise_generator_args.
debug_backward – Print a message if the backward hook is correctly
executed.
Returns
A reset function with no arguments; calling it will swap in the original parameters.
A deregister function with no arguments; calling it will deregister all hooks.
class machin.frame.transition.TransitionBase(major_attr, sub_attr, custom_attr, major_data, sub_data, custom_data)[source]¶
Bases: object
Base class for all transitions
Note
Major attributes store things like state, action, next_states, etc.
They are usually concatenated by their dictionary keys during
sampling, and passed as keyword arguments to actors, critics, etc.
Sub attributes store things like terminal states, reward, etc.
They are usually concatenated directly during sampling, and used
in different algorithms.
Custom attributes store non-concatenatable values, usually user
specified states, used in models or as special arguments in
different algorithms. They will be collected together as a list
during sampling, no further concatenation is performed.
Parameters
major_attr (Iterable[str]) – A list of major attribute names.
sub_attr (Iterable[str]) – A list of sub attribute names.
custom_attr (Iterable[str]) – A list of custom attribute names.
major_data (Iterable[Dict[str, torch.Tensor]]) – Data of major attributes.
sub_data (Iterable[Union[Scalar, torch.Tensor]]) – Data of sub attributes (scalars or tensors).
custom_data (Iterable[Any]) – Data of custom attributes.
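A minimal construction sketch, relying only on the parameter list above; the attribute names and values are illustrative, and in practice you would normally use a concrete Transition subclass instead of the base class.
import torch as t
from machin.frame.transition import TransitionBase

trans = TransitionBase(
    major_attr=["state"],
    sub_attr=["reward"],
    custom_attr=["info"],
    major_data=[{"obs": t.zeros([1, 4])}],   # dict of batch-size-1 tensors
    sub_data=[t.tensor([[1.0]])],            # scalar-like tensor
    custom_data=[{"anything": "goes"}],      # kept as-is, never concatenated
)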
machin.model.nets provides implementations for various popular network
architectures.
class machin.model.nets.NeuralNetworkModule[source]¶
Bases: torch.nn.modules.module.Module, abc.ABC
Note: input device and output device are determined by module parameters;
your input submodule / output submodule should not store parameters on
more than one device, and you also should not move your output to
devices other than your parameter storage device in forward().
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Check model input, output and parameters using hooks. All three kinds of
check hooks (input, output and parameter) are executed in the forward pass.
An example:
from tensorboardX import SummaryWriter

writer = SummaryWriter()
model = nn.Linear(100, 100)
check_model(writer, model)

# Continue to do whatever you like.
model(t.zeros([100]))
Note
Only leaf modules will be checked (such as nn.Linear and not some
complex neural network modules made of several sub-modules). But you
can manually control granularity.
Warning
Do not output a tuple in your forward() function if you have
output check hooks; otherwise you must specify names for each output.
You may specify a list of names for your module outputs, so that the
names given to your output check hooks will not be numbers,
by using mark_module_output().
Hint
For all three kinds of hooks, your hook need to have the following
signature:
hook(counter, writer, model, module, name, value)
where:
counter is the Counter, you can use
Counter.get() to get the current pass number.
writer is SummaryWriter from tensorboardx.
model is your model.
module is the module currently being checked.
name is the input/output/parameter name string. For inputs,
detail names will be extracted from the module forward signature.
Output detail names will be numbers or names you have specified.
value is input/output/parameter value.
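As an illustration, a parameter check hook matching this signature could simply log the norm of the checked value; this is a sketch, and the metric and tag names are arbitrary.
def param_norm_hook(counter, writer, model, module, name, value):
    # Log the L2 norm of the checked parameter, using the current
    # pass number from the counter as the global step.
    writer.add_scalar(f"check/{name}_norm", value.norm().item(), counter.get())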
Parameters
writer (tensorboardX.writer.SummaryWriter) – Tensorboard SummaryWriter used to log.
model (torch.nn.modules.module.Module) – Model to be checked.
input_check_hooks – A series of input check hooks.
output_check_hooks – A series of output check hooks.
param_check_hooks – A series of parameter check hooks.
input_check_interval – Interval (number of forward passes)
of input checking.
output_check_interval – Interval (number of forward passes)
of output checking.
param_check_interval – Interval (number of backward passes)
of parameter checking.
name – Your model name.
Returns
A function f(), calling f() will deregister all check hooks.
from machin.utils.conf import Config
from machin.utils.save_env import SaveEnv

# set some config attributes
c = Config(
    model_save_int=100,
    root_dir="some_directory",
    restart_from_trial="2020_05_09_15_00_31"
)
load_config_cmd(c)

# restart_from_trial specifies the trial name in your root
# directory.
# If it is set, then SaveEnv constructor will
# load arguments from that trial record, will overwrite.
# If not, then SaveEnv constructor will save configurations
# as: ``<c.some_root_dir>/<trial_start_time>/config/config.json``
save_env = SaveEnv(c)
image (numpy.ndarray) – A numpy array of shape (H, W, C) or (H, W), and with
dtype = any float or any int.
When a frame is float type, its value range should be [0, 1].
When a frame is integer type, its value range should be [0, 255].
If daemon is true, then this function cannot be used in a
daemonic subprocess.
Parameters
image (numpy.array) – A numpy array of shape (H, W, C) or (H, W), and with
dtype = any float or any int.
When a frame is float type, its value range should be [0, 1].
When a frame is integer type, its value range should be [0, 255].
path (str) – Directory to save the image.
filename (str) – File name.
extension (str) – File extension.
daemon (bool) – Whether launching the saving process as a daemonic process.
Returns
A wait function; once called, it blocks until creation has finished.
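A usage sketch, assuming the function documented above is machin.utils.media.create_image_subproc (the import path and name are assumptions; path, filename, extension and the wait handle follow the parameter list above).
import numpy as np
from machin.utils.media import create_image_subproc  # assumed location and name

frame = np.random.rand(64, 64, 3)  # float image, values in [0, 1]
wait = create_image_subproc(frame, path="./", filename="frame_0", extension=".png")
wait()  # block until the image file has been written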
frames (List[numpy.ndarray]) – A list of numpy arrays of shape (H, W, C) or (H, W), and with
dtype = any float or any int.
When a frame is float type, its value range should be [0, 1].
When a frame is integer type, its value range should be [0, 255].
If daemon is true, then this function cannot be used in a
daemonic subprocess.
Parameters
frames (List[numpy.ndarray]) – A list of numpy arrays of shape (H, W, C) or (H, W), and with
dtype = any float or any int.
When a frame is float type, its value range should be [0, 1].
When a frame is integer type, its value range should be [0, 255].
path (str) – Directory to save the video.
filename (str) – File name.
extension (str) – File extension.
fps (int) – frames per second.
daemon (bool) – Whether launching the saving process as a daemonic process.
Returns
A wait function; once called, it blocks until creation has finished.
Use matplotlib to show a single image. You may repeatedly call this method
with the same title argument to show a video or a dynamically changing
image.
Parameters
image (numpy.ndarray) – A numpy array of shape (H, W, C) or (H, W), and with dtype
= any float or any int.
When a frame is float type, its value range should be [0, 1].
When a frame is integer type, its value range should be [0, 255].
show_normalized (bool) – Show normalized image alongside the original one.
pause_time (float) – Pause time between displaying current image and the next
one.
env_root (str) – Root directory for all trials of the environment.
restart_from_trial (Optional[str]) – Instead of creating a new save environment
for a new trial, use an existing save environment of an older
trial; the old trial name should be in the format time_format.
time_format – Time formatter; setting it to an empty string will cause
the save environment to use env_root directly instead of using
sub directories with a datetime name.
Each tensor in tensor_list should reside on a separate GPU
Only nccl backend is currently supported
tensors should only be GPU tensors
Complex tensors are supported.
Parameters
output_tensor_lists (List[List[Tensor]]) –
Output lists. It should
contain correctly-sized tensors on each GPU to be used for output
of the collective, e.g. output_tensor_lists[i] contains the
all_gather result that resides on the GPU of
input_tensor_list[i].
Note that each element of output_tensor_lists has the size of
world_size * len(input_tensor_list), since the function all gathers
results from every single GPU in the group. To interpret
each element of output_tensor_lists[i], note that
input_tensor_list[j] of rank k will appear in
output_tensor_lists[i][k * world_size + j].
Also note that len(output_tensor_lists), and the size of each
element in output_tensor_lists (each element is a list,
therefore len(output_tensor_lists[i])) need to be the same
for all the distributed processes calling this function.
input_tensor_list (List[Tensor]) – List of tensors(on different GPUs) to
be broadcast from current process.
Note that len(input_tensor_list) needs to be the same for
all the distributed processes calling this function.
async_op (bool, optional) – Whether this op should be an async op
Reduces the tensor data across all machines in such a way that all get
the final result. This function reduces a number of tensors on every node,
while each tensor resides on different GPUs.
Therefore, the input tensors in the tensor list need to be GPU tensors.
Also, each tensor in the tensor list needs to reside on a different GPU.
After the call, all tensors in tensor_list are going to be bitwise
identical in all processes.
Complex tensors are supported.
Only the nccl and gloo backends are currently supported
tensors should only be GPU tensors
Parameters
tensor_list (List[Tensor]) – List of input and output tensors of
the collective. The function operates in-place and requires
each tensor to be a GPU tensor residing on a different GPU.
You also need to make sure that len(tensor_list) is the same for
all the distributed processes calling this function.
op (optional) – One of the values from
torch.distributed.ReduceOp
enum. Specifies an operation used for element-wise reductions.
async_op (bool, optional) – Whether this op should be an async op
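A plain torch.distributed sketch of the call pattern (not the machin wrapper), assuming a single node with two GPUs and an nccl process group of world size 1 that has already been initialized with dist.init_process_group.
import torch as t
import torch.distributed as dist

# dist.init_process_group("nccl", ...) is assumed to have been called.
tensor_list = [t.ones(4, device=f"cuda:{i}") * (i + 1) for i in range(2)]
dist.all_reduce_multigpu(tensor_list, op=dist.ReduceOp.SUM)
# Every tensor in tensor_list now holds the element-wise sum
# of all tensors from all participating processes.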
tensor must have the same number of elements in all the GPUs from
all processes participating in the collective. Each tensor in the list must
be on a different GPU.
Only the nccl and gloo backends are currently supported
tensors should only be GPU tensors
Parameters
tensor_list (List[Tensor]) – Tensors that participate in the collective
operation. If src is the rank, then the specified src_tensor
element of tensor_list (tensor_list[src_tensor]) will be
broadcast to all other tensors (on different GPUs) in the src process
and all tensors in tensor_list of other non-src processes.
You also need to make sure that len(tensor_list) is the same
for all the distributed processes calling this function.
src (int) – Source rank.
async_op (bool, optional) – Whether this op should be an async op
src_tensor (int, optional) – Source tensor rank within tensor_list
gather_list (list[Tensor], optional) – List of appropriately-sized
tensors to use for gathered data (default is None, must be specified
on the destination rank)
dst (int, optional) – Destination rank (default is 0)
async_op (bool, optional) – Whether this op should be an async op
Reduces the tensor data on multiple GPUs across all machines. Each tensor
in tensor_list should reside on a separate GPU
Only the GPU of tensor_list[dst_tensor] on the process with rank dst
is going to receive the final result.
Only nccl backend is currently supported
tensors should only be GPU tensors
Parameters
tensor_list (List[Tensor]) – Input and output GPU tensors of the
collective. The function operates in-place.
You also need to make sure that len(tensor_list) is the same for
all the distributed processes calling this function.
dst (int) – Destination rank
op (optional) – One of the values from
torch.distributed.ReduceOp
enum. Specifies an operation used for element-wise reductions.
async_op (bool, optional) – Whether this op should be an async op
dst_tensor (int, optional) – Destination tensor rank within
tensor_list
Returns
Async work handle, if async_op is set to True.
None, otherwise
key (Any) – Key of the registered service, in this group.
args – Service arguments.
kwargs – Service keyword arguments.
Returns
A future object; you can call wait() on it. wait() will block the thread until execution is completed,
and will return the result returned by the service.
Make a remote call to run func on worker to and return an
RRef to the result value immediately.
Worker to will be the owner of the returned
RRef, and the worker calling remote is
a user. The owner manages the global reference count of its
RRef, and the owner
RRef is only destructed when globally there
are no living references to it.
Parameters
to (str or WorkerInfo or int) – name/rank/WorkerInfo of the destination worker.
func (callable) – a callable function, such as Python callables, builtin
operators (e.g. add()) and annotated
TorchScript functions.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – is a dictionary of keyword arguments for the func
invocation.
timeout (float, optional) – timeout in seconds for this remote call. If the
creation of this
RRef on worker
to is not successfully processed on this
worker within this timeout, then the next time
there is an attempt to use the RRef (such as
to_here()), a timeout will be raised
indicating this failure. A value of 0 indicates
an infinite timeout, i.e. a timeout error will
never be raised. If not provided, the default
value set during initialization or with
_set_rpc_timeout is used.
Returns
A user RRef instance to the result
value. Use the blocking API torch.distributed.rpc.RRef.to_here()
to retrieve the result value locally.
Warning
The remote API does not copy storages of argument tensors until
sending them over the wire, which could be done by a different thread
depending on the RPC backend type. The caller should make sure that the
contents of those tensors stay intact until the returned RRef is
confirmed by the owner, which can be checked using the
torch.distributed.rpc.RRef.confirmed_by_owner() API.
Warning
Errors such as timeouts for the remote API are handled on a
best-effort basis. This means that when remote calls initiated by
remote fail, such as with a timeout error, we take a best-effort
approach to error handling. This means that errors are handled and set
on the resulting RRef on an asynchronous basis. If the RRef has not been
used by the application before this handling (such as to_here or
fork call), then future uses of the RRef will appropriately raise
errors. However, it is possible that the user application will use the
RRef before the errors are handled. In this case, errors may not be
raised as they have not yet been handled.
Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers.
Refer to the init_process_group() API for more details. For example, you
could export MASTER_ADDR=localhost and MASTER_PORT=5678 before launching
both workers.
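On the caller side, a minimal sketch of the call pattern (assuming rpc.init_rpc has already been called on two workers named "worker0" and "worker1"; the function and values are illustrative):
import torch
import torch.distributed.rpc as rpc

# On "worker0", after rpc.init_rpc("worker0", rank=0, world_size=2):
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
print(rref.to_here())  # tensor([4., 4.])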
Make a non-blocking RPC call to run function func on worker to. RPC
messages are sent and received in parallel to execution of Python code. This
method is thread-safe. This method will immediately return a
Future that can be awaited on.
Parameters
to (str or WorkerInfo or int) – name/rank/WorkerInfo of the destination worker.
func (callable) – a callable function, such as Python callables, builtin
operators (e.g. add()) and annotated
TorchScript functions.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – is a dictionary of keyword arguments for the func
invocation.
timeout (float, optional) – timeout in seconds to use for this RPC. If
the RPC does not complete in this amount of
time, an exception indicating it has
timed out will be raised. A value of 0
indicates an infinite timeout, i.e. a timeout
error will never be raised. If not provided,
the default value set during initialization
or with _set_rpc_timeout is used.
Returns
Returns a Future object that can be waited
on. When completed, the return value of func on args and
kwargs can be retrieved from the Future
object.
Warning
Using GPU tensors as arguments or return values of func is not
supported since we don’t support sending GPU tensors over the wire. You
need to explicitly copy GPU tensors to CPU before using them as
arguments or return values of func.
Warning
The rpc_async API does not copy storages of argument tensors until
sending them over the wire, which could be done by a different thread
depending on the RPC backend type. The caller should make sure that the
contents of those tensors stay intact until the returned
Future completes.
Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers.
Refer to the init_process_group() API for more details. For example, you
could export MASTER_ADDR=localhost and MASTER_PORT=5678 before launching
both workers.
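A minimal sketch, under the same assumption that rpc has been initialized on "worker0" and "worker1":
import torch
import torch.distributed.rpc as rpc

# On "worker0":
fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
print(fut.wait())  # tensor([4., 4.])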
Make a blocking RPC call to run function func on worker to. RPC
messages are sent and received in parallel to execution of Python code. This
method is thread-safe.
Parameters
to (str or WorkerInfo or int) – name/rank/WorkerInfo of the destination worker.
func (callable) – a callable function, such as Python callables, builtin
operators (e.g. add()) and annotated
TorchScript functions.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – is a dictionary of keyword arguments for the func
invocation.
timeout (float, optional) – timeout in seconds to use for this RPC. If
the RPC does not complete in this amount of
time, an exception indicating it has
timed out will be raised. A value of 0
indicates an infinite timeout, i.e. a timeout
error will never be raised. If not provided,
the default value set during initialization
or with _set_rpc_timeout is used.
Returns
Returns the result of running func with args and kwargs.
Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers.
Refer to the init_process_group() API for more details. For example, you
could export MASTER_ADDR=localhost and MASTER_PORT=5678 before launching
both workers.
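And the blocking variant, under the same assumptions:
import torch
import torch.distributed.rpc as rpc

# On "worker0":
ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
print(ret)  # tensor([4., 4.])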
class machin.parallel.server.OrderedServerBase[source]¶
Bases: abc.ABC
Descendant classes of OrderedServerBase do not have to guarantee strong
consistency; that is, even if OrderedServerBase.push_service() has
returned True, it is still possible that the acknowledged
pushes are discarded.
A simple parameter server, which synchronizes model parameters
by pushing gradients and pulling back new parameters; no strict
order is guaranteed.
Warning
DistributedDataParallel is not supported, since we cannot
load a state dictionary after creation.
Note
You should initialize PushPullGradServer on all members of
secondary_reducers, and primary_reducer. Both of them
should be members of the group.
Note
Internally the primary reducer will push updated versions
to the ordered server.
Hint
Reduction is performed in a tree fashion:
In the first step, clients will push new gradients to a
random secondary reducer, and the secondary reducer will perform
the first reduction pass, then secondary reducers will push
their results to the primary reducer.
In the second step, the primary reducer will reduce results
from the secondary reducer to get the final reduced gradient
dictionary (has the same structure as state_dict), and assign
gradients to its managed model, and perform the
optimization.
In the final step, the primary reducer will push the final
model to the model server group, then clients can pull the
newest model.
Parameters
server_name (str) – Name of this server, used to register
the server as a paired class of group.
group (machin.parallel.distributed._world.RpcGroup) – Server group.
model_name (str) – Name of the managed model in the ordered server,
only needed if the server needs such an identifier. The default
ordered server does not require this.
primary_reducer (str) – Name of the process serving as the primary reducer,
which collects reduced gradients from secondary reducers and
performs the final reduction.
secondary_reducers (List[str]) – Names of the processes serving as secondary
reducers.
o_server (machin.parallel.server.ordered_server.OrderedServerBase) – Custom ordered server accessor. By default, the ordered
server is a OrderedServerSimple hosted on the primary
reducer.
reduce_method (str) – “mean” or “sum”
reduce_device (Union[torch.device, str]) – Device to perform reduction, by default it is “cpu”.
reduce_batch_size (int) – Size of a single reduction batch; the server will
wait until the number of requests in the reduction queue has
reached this size.
max_queue_size (int) – Maximum reduction request queue size.
model_name (str) – Name of the managed model in the ordered server,
only needed if the server needs such an identifier. The default
ordered server does not require this.
o_server (machin.parallel.server.ordered_server.OrderedServerBase) – Ordered server accessor.
Try to push a model to the ordered server; if the push fails, the newest
model will be automatically pulled and its parameters will be
assigned to model. Gradients will not be cleared.
Parameters
model (torch.nn.modules.module.Module) – Model to push.
pull_on_fail – Pull the newest parameters if push failed.
Returns
True if push succeeded, else False.
class machin.parallel.server.PushPullModelServerImpl(server_name, group, model_name='model', o_server=None)[source]¶
Bases: object
A simple parameter server, which synchronizes model parameters
by pushing and pulling all parameters and maintaining a strictly
ordered version chain.
Warning
Only one model is supported.
This init function must only be invoked on the runner process,
and the runner process must be a member process of group.
Parameters
server_name (str) – Name of this server, used to register
the server as a paired class of group.
group (machin.parallel.distributed._world.RpcGroup) – RpcGroup on which the default
OrderedServerSimple server is created; mutually exclusive with o_server.
model_name (str) – Name of the managed model in the ordered server,
only needed if the server needs such an identifier. The default
ordered server does not require this.
o_server (machin.parallel.server.ordered_server.OrderedServerBase) – Custom ordered server accessor.
Assign models to different devices, within the scope of a single process.
The assigner assumes all GPUs have the same processing power.
Assignment is based on four aspects:
Distance and model connections. Connection is usually indicated
by the amount of data transmitted between two models.
Compute complexity.
Model size.
Entropy.
Four aspects are controlled by four weights:
connection_weight, assigner will try to reduce the total
distance*connection if this weight is larger.
size_match_weight, this weight controls the total memory
space used on a single device; it only works if the total assigned
memory of models exceeds the allowed device memory size
(internally it uses a relu activation). The larger it is,
the tighter and more restricted the fit.
complexity_match_weight, this weight balances the model
computation cost across devices; the assigner will try to even
the computation cost / compute power ratio for each device
if this weight is larger.
entropy_weight, this weight minimizes the uncertainty of
the model placement probability, so model i will have a close to 1
probability of locating on some device j if this weight is
larger.
Assignment uses gradient descent to compute the probability matrix
of each model i locating on each available device j.
When the sum of your model sizes is very close to the capacity of
your device memory, ModelAssigner does not respond very well
to size_match_weight; therefore, please consider
increasing model_size_multiplier or decreasing
max_mem_ratio.
Parameters
models (List[torch.nn.modules.module.Module]) – Models to assign.
model_connection (Dict[Tuple[int, int], int]) – Connection weight between modules.
Must be positive
devices (List[Union[torch.device, str]]) – Available devices.
model_size_multiplier – Size multiplier of models, used to reserve
enough space for models.
max_mem_ratio – Maximum percent of memory allowed.
cpu_weight – Weight of cpu. Relative to the computing power of one
GPU. By default it is 0 so no computation will be performed on
CPU. Must be positive
connection_weight – Weight of connection between models.
size_match_weight – Weight of size match.
complexity_match_weight – Weight of complexity match.
entropy_weight – Weight of entropy.
iterations – Number of optimization iterations.
update_rate – Learning rate of the adam optimizer.
gpu_gpu_distance – Estimated distance cost between gpu-gpu.
Must be positive
cpu_gpu_distance – Estimated distance cost between cpu-gpu.
Must be positive
move_models – Whether to automatically move the models after
assignment.
List[t.device]:
Assigned devices for each model in your model list.
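A rough usage sketch under stated assumptions: the keyword names follow the parameter list above, cpu_weight is raised from its default so the single "cpu" device is treated as usable, and the assignment accessor (returning the List[t.device] above) is an assumed property name.
import torch as t
from machin.parallel.assigner import ModelAssigner

models = [t.nn.Linear(10, 10), t.nn.Linear(10, 10)]
# Model 0 feeds model 1, so give their connection a positive weight.
assigner = ModelAssigner(
    models,
    model_connection={(0, 1): 1},
    devices=["cpu"],
    cpu_weight=1.0,
)
print(assigner.assignment)  # assumed property: one assigned device per model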
class machin.parallel.assigner.ModelSizeEstimator(model, size_multiplier=2)[source]¶
Bases: object
Size estimator for pytorch modules.
Estimates the size of PyTorch models in memory.
Note
This estimator can only estimate the total size of parameters and
buffers. Therefore we need to multiply the raw estimated size with
a correction coefficient to reserve enough space for models.
Parameters
model (torch.nn.modules.module.Module) – Model to be estimated.
size_multiplier – Model estimated size will be
multiplied with this value, to ensure enough space
will be reserved to contain your model and inputs.
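The note above can be made concrete with a small, generic sketch of parameter-and-buffer byte counting; this is the idea behind the estimator, not its own code.
import torch as t

def rough_model_size_mib(model: t.nn.Module, size_multiplier: int = 2) -> float:
    # Sum the bytes of all parameters and buffers, then apply the correction
    # coefficient to reserve extra room for inputs and intermediate tensors.
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return n_bytes * size_multiplier / 1024 ** 2

print(rough_model_size_mib(t.nn.Linear(100, 100)))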
class machin.parallel.pickle.Pickler(file, recurse=False, copy_tensor=False)[source]¶
Bases: dill._dill.Pickler
Note
Picklers share “.dispatch” among instances, and own
“dispatch_table” per instance.
The base Pickler (not dill, from builtin pickle library),
will first look up the default dump method in “.dispatch”, if
no valid method is found, it will try to find a custom dump
method in “.dispatch_table”.
This takes a binary file for writing a pickle data stream.
The optional protocol argument tells the pickler to use the
given protocol; supported protocols are 0, 1, 2, 3 and 4. The
default protocol is 3; a backward-incompatible protocol designed
for Python 3.
Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the
more recent the version of Python needed to read the pickle
produced.
The file argument must have a write() method that accepts a
single bytes argument. It can thus be a file object opened for
binary writing, an io.BytesIO instance, or any other custom
object that meets this interface.
If fix_imports is True and protocol is less than 3, pickle
will try to map the new Python 3 names to the old module names
used in Python 2, so that the pickle data stream is readable
with Python 2.
Convert objects to bytes. Works for cpu and gpu tensors.
Warning
Up to pytorch 1.5.0, there is a bug for referenced gpu tensors,
which requires users to keep shared gpu tensors alive during
the whole process life and not reassign / delete them;
however, you may refill them with different values.
recurse – Enable recursive dumping, enable this to dump local
functions and lambdas.
copy_tensor – Whether to dump tensors, storage as a full copy.
If it is set to “False”, then dumped tensors must be located
either on GPUs or in shared memory.
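To illustrate what recursive dumping of local functions buys you, here is a plain dill sketch (not the machin Pickler itself); recurse=True tells dill to recursively trace referenced globals instead of storing the whole global dictionary.
import dill

def make_adder(n):
    # A local function (closure) - the standard pickle module cannot
    # serialize this by reference, but dill can serialize it by value.
    def add(x):
        return x + n
    return add

payload = dill.dumps(make_adder(3), recurse=True)
restored = dill.loads(payload)
print(restored(4))  # 7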
Some modules are static, which means they are stateless
and will remain the same whether you import them in process A
or process B.
If your module contains references to functions, objects
or anything inside a CDLL (usually the reference is a
pointer), it is not picklable by dill and will cause
nasty errors. However, by marking this module as "static",
dill will recognize it as a builtin module and
will not save its state; dill will only save
a reference to it in this situation.
Parameters
module (Any) – Some module which imports CDLLs by hand and
not using pybind11.
Softly close the pool and handler threads, then
shut down workers by sending signals. The pool will
be closed after all jobs are finished and all results
returned.
Remember to call join() to wait for a full shutdown.
Watch workers for exceptions, raise them, and then terminate the pool.
Clean up any retired workers that have reached the maximum task number, and
start replacements for them.
Read a result item from the output queue on the pool side.
The method should block for timeout seconds, and then throw
a TimeoutError if no result is available. It should also
throw OSError or EOFError to indicate that it is
improperly closed and cannot be used.
Bring the number of pool workers up to the specified number;
this also creates new workers to replace old workers which have
exited after executing maxtasksperchild tasks.
Like map() method but the elements of the iterable are expected to
be iterables as well and will be unpacked as arguments. Hence
func and (a, b) becomes func(a, b).
Parameters
func (Callable[[Any], Any]) – Function to call.
iterable (Collection[Tuple]) – A collection of tuples of arguments provided to the function call.
chunksize (int) – Size of iterable chunk assigned to each worker.
Returns
A list of results from applying the function to each tuple in the iterable.
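A short usage sketch of starmap on the pool documented here; the import path is assumed, and the worker function is defined at module level so it can be serialized.
from machin.parallel.pool import Pool  # assumed import path

def power(base, exp):
    return base ** exp

if __name__ == "__main__":
    pool = Pool(processes=2)
    # Each tuple is unpacked as positional arguments: power(2, 3), power(3, 2), ...
    print(pool.starmap(power, [(2, 3), (3, 2), (4, 2)]))  # [8, 9, 16]
    pool.close()
    pool.join()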
Pool with a context for each worker. Your function must accept a ctx
object as its first non-keyword argument.
If worker_contexts is not specified, then ctx will be None.
The length of worker_contexts must be the same as processes.
Note
To share “cpu” tensors in shared memory, you must set:
is_copy_tensor=False,share_method="cpu"
To share “cuda” tensors, you must set:
is_copy_tensor=False,share_method="cuda"
Note
The default context used in pool is “spawn”, to avoid any issues
brought by “fork”. “fork” will only be used if you want to pass
cpu tensors in shared memory.
Parameters
processes (int) – Number of processes in the pool.
initializer – Initializer function executed by the pool.
initargs – Args passed to the init function.
maxtasksperchild – Maximum number of tasks per worker process.
is_recursive – Set to True to support local functions
and lambdas.
is_daemon – Whether worker processes in the pool are started as
daemon processes.
is_copy_tensor – Whether to copy tensors or pass tensors by
reference to worker processes.
share_method – If is_copy_tensor is False, you must
specify this argument. “cpu” means you may use cpu tensors
in the shared memory, “cuda” means cuda tensors, you can only
specify one share method.
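A rough usage sketch of the ctx-first calling convention described above; the import path and the exact constructor keywords beyond those in the parameter list are assumptions.
from machin.parallel.pool import CtxPool  # assumed import path

def work(ctx, x):
    # ctx is this worker's private context object (here a plain string tag).
    return f"{ctx}:{x * x}"

if __name__ == "__main__":
    pool = CtxPool(processes=2, worker_contexts=["worker-a", "worker-b"])
    print(pool.map(work, [1, 2, 3]))
    pool.close()
    pool.join()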
Bring the number of pool processes up to the specified number,
for use after reaping workers which have exited.
class machin.parallel.pool.CtxPoolStorage[source]¶
Bases: object
This storage class is used by all CtxPool instances.
However, since each worker process has its own
memory space, the storage is unique per worker.
To share “cpu” tensors in shared memory, you must set:
is_copy_tensor=False,share_method="cpu"
To share “cuda” tensors, you must set:
is_copy_tensor=False,share_method="cuda"
Note
The default context used in pool is “spawn”, to avoid any issues
brought by “fork”. “fork” will only be used if you want to pass
cpu tensors in shared memory.
Parameters
processes (int) – Number of processes in the pool.
initializer – Initializer function executed by the pool.
initargs – Args passed to the init function.
maxtasksperchild – Maximum number of tasks per worker process.
is_recursive – Set to True to support local functions
and lambdas.
is_daemon – Whether worker processes in the pool are started as
daemon processes.
is_copy_tensor – Whether to copy tensors or pass tensors by
reference to worker processes.
share_method – If is_copy_tensor is False, you must
specify this argument. “cpu” means you may use cpu tensors
in the shared memory, “cuda” means cuda tensors, you can only
specify one share method.
Like map() method but the elements of the iterable are expected to
be iterables as well and will be unpacked as arguments. Hence
func and (a, b) becomes func(a, b).
Parameters
func – Function to call.
iterable – A collection of tuples of arguments provided to the function call.
chunksize – Size of iterable chunk assigned to each worker.
Returns
A list of results from applying the function to each tuple in the iterable.
If timeout is reached, no new item has been returned by the worker,
and the total number of returned items is smaller than the job size,
then a TimeoutError is raised.
If the total item number is equal to the job size (all jobs finished and
returned), then a StopIteration is raised.
To share “cpu” tensors in shared memory, you must set:
is_copy_tensor=False,share_method="cpu"
To share “cuda” tensors, you must set:
is_copy_tensor=False,share_method="cuda"
Note
The default context used in pool is “spawn”, to avoid any issues
brought by “fork”. “fork” will only be used if you want to pass
cpu tensors in shared memory.
Parameters
processes – Number of processes in the pool.
initializer – Initializer function executed by the pool.
initargs – Args passed to the init function.
maxtasksperchild – Maximum number of tasks per worker process.
is_recursive – Set to True to support local functions
and lambdas.
is_daemon – Whether worker processes in the pool are started as
daemon processes.
is_copy_tensor – Whether to copy tensors or pass tensors by
reference to worker processes.
share_method – If is_copy_tensor is False, you must
specify this argument. “cpu” means you may use cpu tensors
in the shared memory, “cuda” means cuda tensors, you can only
specify one share method.
Softly close the pool and handler threads, then
shut down workers by sending signals. The pool will
be closed after all jobs are finished and all results
returned.
Remember to call join() to wait for a full shutdown.
Bring the number of pool workers up to the specified number;
this also creates new workers to replace old workers which have
exited after executing maxtasksperchild tasks.
Enhanced multiprocessing pool for pytorch, which provides:
Support for lambdas and local functions.
Ability to select the tensor serialization scheme.
Note
To share “cpu” tensors in shared memory, you must set:
is_copy_tensor=False,share_method="cpu"
To share “cuda” tensors, you must set:
is_copy_tensor=False,share_method="cuda"
Note
The default context used in pool is “spawn”, to avoid any issues
brought by “fork”. “fork” will only be used if you want to pass
cpu tensors in shared memory.
Parameters
processes – Number of processes in the pool.
initializer – Initializer function executed by the pool.
initargs – Args passed to the init function.
maxtasksperchild – Maximum number of tasks per worker process.
is_recursive – Set to True to support local functions
and lambdas.
is_daemon – Whether worker processes in the pool are started as
daemon processes.
is_copy_tensor – Whether to copy tensors or pass tensors by
reference to worker processes.
share_method – If is_copy_tensor is False, you must
specify this argument. “cpu” means you may use cpu tensors
in the shared memory, “cuda” means cuda tensors, you can only
specify one share method.
Like map() method but the elements of the iterable are expected to
be iterables as well and will be unpacked as arguments. Hence
func and (a, b) becomes func(a, b).
Parameters
func – Function to call.
iterable – A collection of tuples of arguments provided to the function call.
chunksize – Size of iterable chunk assigned to each worker.
Returns
A list of results from applying the function to each tuple in the iterable.
To share “cpu” tensors in shared memory, you must set:
is_copy_tensor=False,share_method="cpu"
To share “cuda” tensors, you must set:
is_copy_tensor=False,share_method="cuda"
Note
The default context used in pool is “spawn”, to avoid any issues
brought by “fork”. “fork” will only be used if you want to pass
cpu tensors in shared memory.
Parameters
processes – Number of processes in the pool.
initializer – Initializer function executed by the pool.
initargs – Args passed to the init function.
maxtasksperchild – Maximum number of tasks per worker process.
is_recursive – Set to True to support local functions
and lambdas.
is_daemon – Whether worker processes in the pool are started as
daemon processes.
is_copy_tensor – Whether to copy tensors or pass tensors by
reference to worker processes.
share_method – If is_copy_tensor is False, you must
specify this argument. “cpu” means you may use cpu tensors
in the shared memory, “cuda” means cuda tensors, you can only
specify one share method.
Watch workers for exceptions, raise them, and then terminate the pool.
Clean up any retired workers that have reached the maximum task number, and
start replacements for them.
Read a result item from the output queue on the pool side.
The method should block for timeout seconds, and then throw
a TimeoutError if no result is available. It should also
throw OSError or EOFError to indicate that it is
improperly closed and cannot be used.
Bring the number of pool workers up to the specified number;
this also creates new workers to replace old workers which have
exited after executing maxtasksperchild tasks.
This api is used by the result manager (Pool._result_handler)
thread to get results from the queue. Since it is single threaded,
there is no need to use locks, which makes it quicker.