from typing import Union, Dict, List, Tuple, Callable, Any
from torch.distributions import Categorical
import torch as t
import torch.nn as nn
import numpy as np
from machin.frame.buffers.buffer import Transition, Buffer
from machin.frame.noise.action_space_noise import \
add_normal_noise_to_action, \
add_clipped_normal_noise_to_action, \
add_uniform_noise_to_action, \
add_ou_noise_to_action
from machin.model.nets.base import NeuralNetworkModule
from .base import TorchFramework
from .utils import \
hard_update, \
soft_update, \
safe_call, \
safe_return, \
assert_output_is_probs
class DDPG(TorchFramework):
"""
DDPG framework.
"""
_is_top = ["actor", "critic", "actor_target", "critic_target"]
_is_restorable = ["actor_target", "critic_target"]
def __init__(self,
actor: Union[NeuralNetworkModule, nn.Module],
actor_target: Union[NeuralNetworkModule, nn.Module],
critic: Union[NeuralNetworkModule, nn.Module],
critic_target: Union[NeuralNetworkModule, nn.Module],
optimizer: Callable,
criterion: Callable,
*_,
lr_scheduler: Callable = None,
lr_scheduler_args: Tuple[Tuple, Tuple] = None,
lr_scheduler_kwargs: Tuple[Dict, Dict] = None,
batch_size: int = 100,
update_rate: float = 0.001,
actor_learning_rate: float = 0.0005,
critic_learning_rate: float = 0.001,
discount: float = 0.99,
gradient_max: float = np.inf,
replay_size: int = 500000,
replay_device: Union[str, t.device] = "cpu",
replay_buffer: Buffer = None,
visualize: bool = False,
visualize_dir: str = "",
**__):
"""
Note:
Your optimizer will be called as::
optimizer(network.parameters(), lr=learning_rate)
Your lr_scheduler will be called as::
lr_scheduler(
optimizer,
*lr_scheduler_args[0],
**lr_scheduler_kwargs[0],
)
Your criterion will be called as::
criterion(
predicted_value.view(batch_size, 1),
target_value.view(batch_size, 1)
)
Args:
actor: Actor network module.
actor_target: Target actor network module.
critic: Critic network module.
critic_target: Target critic network module.
optimizer: Optimizer used to optimize ``actor`` and ``critic``.
criterion: Criterion used to evaluate the value loss.
lr_scheduler: Learning rate scheduler of ``optimizer``.
lr_scheduler_args: Arguments of the learning rate scheduler.
lr_scheduler_kwargs: Keyword arguments of the learning
rate scheduler.
batch_size: Batch size used during training.
update_rate: :math:`\\tau` used to update target networks.
Target parameters are updated as:
:math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
actor_learning_rate: Learning rate of the actor optimizer,
not compatible with ``lr_scheduler``.
critic_learning_rate: Learning rate of the critic optimizer,
not compatible with ``lr_scheduler``.
discount: :math:`\\gamma` used in the Bellman equation.
gradient_max: Maximum gradient norm, enforced by gradient clipping.
replay_size: Replay buffer size. Not compatible with
``replay_buffer``.
replay_device: Device where the replay buffer is located. Not
compatible with ``replay_buffer``.
replay_buffer: Custom replay buffer.
visualize: Whether to visualize the network flow in the first pass.
visualize_dir: Visualized graph save directory.
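Example:
    A minimal construction sketch; ``actor``, ``actor_target``,
    ``critic`` and ``critic_target`` stand for user-defined
    ``nn.Module`` instances (hypothetical, not shown here)::

        ddpg = DDPG(actor, actor_target, critic, critic_target,
                    t.optim.Adam, nn.MSELoss(reduction='sum'),
                    batch_size=100, update_rate=0.005)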
"""
self.batch_size = batch_size
self.update_rate = update_rate
self.discount = discount
self.grad_max = gradient_max
self.visualize = visualize
self.visualize_dir = visualize_dir
self.actor = actor
self.actor_target = actor_target
self.critic = critic
self.critic_target = critic_target
self.actor_optim = optimizer(self.actor.parameters(),
lr=actor_learning_rate)
self.critic_optim = optimizer(self.critic.parameters(),
lr=critic_learning_rate)
self.replay_buffer = (Buffer(replay_size, replay_device)
if replay_buffer is None
else replay_buffer)
# Make sure the target and online networks start with the same weights
with t.no_grad():
hard_update(self.actor, self.actor_target)
hard_update(self.critic, self.critic_target)
if lr_scheduler is not None:
if lr_scheduler_args is None:
lr_scheduler_args = ((), ())
if lr_scheduler_kwargs is None:
lr_scheduler_kwargs = ({}, {})
self.actor_lr_sch = lr_scheduler(
self.actor_optim,
*lr_scheduler_args[0],
**lr_scheduler_kwargs[0]
)
self.critic_lr_sch = lr_scheduler(
self.critic_optim,
*lr_scheduler_args[1],
**lr_scheduler_kwargs[1]
)
self.criterion = criterion
super(DDPG, self).__init__()
def act(self,
state: Dict[str, Any],
use_target: bool = False,
**__):
"""
Use actor network to produce an action for the current state.
Args:
state: Current state.
use_target: Whether to use the target network.
Returns:
Anything returned by your actor network.
"""
if use_target:
return safe_return(safe_call(self.actor_target, state))
else:
return safe_return(safe_call(self.actor, state))
def act_with_noise(self,
state: Dict[str, Any],
noise_param: Any = (0.0, 1.0),
ratio: float = 1.0,
mode: str = "uniform",
use_target: bool = False,
**__):
"""
Use actor network to produce a noisy action for the current state.
See Also:
:mod:`machin.frame.noise.action_space_noise`
Args:
state: Current state.
noise_param: Noise params.
ratio: Noise ratio.
mode: Noise mode. Supported are:
``"uniform", "normal", "clipped_normal", "ou"``
use_target: Whether to use the target network.
Returns:
Noisy action of shape ``[batch_size, action_dim]``.
Any other things returned by your actor network, if they exist.
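Example:
    A usage sketch; ``ddpg`` is a constructed :class:`DDPG` instance,
    ``some_state`` is a hypothetical state tensor, the ``"state"`` key
    must match your actor's ``forward`` signature, and ``noise_param``
    is assumed to be ``(mean, std)`` for the ``"normal"`` mode::

        action = ddpg.act_with_noise({"state": some_state},
                                     noise_param=(0.0, 0.1),
                                     mode="normal")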
"""
if use_target:
action, *others = safe_call(self.actor_target, state)
else:
action, *others = safe_call(self.actor, state)
if mode == "uniform":
noisy_action = add_uniform_noise_to_action(
action, noise_param, ratio
)
elif mode == "normal":
noisy_action = add_normal_noise_to_action(
action, noise_param, ratio
)
elif mode == "clipped_normal":
noisy_action = add_clipped_normal_noise_to_action(
action, noise_param, ratio
)
elif mode == "ou":
noisy_action = add_ou_noise_to_action(
action, noise_param, ratio
)
else:
raise ValueError("Unknown noise type: " + str(mode))
if len(others) == 0:
return noisy_action
else:
return (noisy_action, *others)
def act_discrete(self,
state: Dict[str, Any],
use_target: bool = False,
**__):
"""
Use actor network to produce a discrete action for the current state.
Notes:
The actor network must output a probability tensor of shape
``(batch_size, action_num)`` whose rows each sum to 1 along
dimension 1.
Args:
state: Current state.
use_target: Whether to use the target network.
Returns:
Action of shape ``[batch_size, 1]``.
Action probability tensor of shape ``[batch_size, action_num]``,
produced by your actor.
Any other things returned by your actor network, if they exist.
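Example:
    One way to satisfy the probability requirement in the note above
    (a sketch; ``logits`` is a hypothetical intermediate tensor) is to
    end your actor's ``forward`` with a softmax::

        return t.softmax(logits, dim=1)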
"""
if use_target:
action, *others = safe_call(self.actor_target, state)
else:
action, *others = safe_call(self.actor, state)
assert_output_is_probs(action)
batch_size = action.shape[0]
result = t.argmax(action, dim=1).view(batch_size, 1)
return (result, action, *others)
def act_discrete_with_noise(self,
state: Dict[str, Any],
use_target: bool = False,
**__):
"""
Use actor network to produce a noisy discrete action for
the current state.
Notes:
The actor network must output a probability tensor of shape
``(batch_size, action_num)`` whose rows each sum to 1 along
dimension 1.
Args:
state: Current state.
use_target: Whether to use the target network.
Returns:
Noisy action of shape ``[batch_size, 1]``.
Action probability tensor of shape ``[batch_size, action_num]``.
Any other things returned by your actor network, if they exist.
"""
if use_target:
action, *others = safe_call(self.actor_target, state)
else:
action, *others = safe_call(self.actor, state)
assert_output_is_probs(action)
dist = Categorical(action)
batch_size = action.shape[0]
# Sample one discrete action per batch element; Categorical already
# treats dimension 0 as the batch dimension.
noisy_action = dist.sample().view(batch_size, 1)
return (noisy_action, action, *others)
def _act(self,
state: Dict[str, Any],
use_target: bool = False,
**__):
"""
Use actor network to produce an action for the current state.
Args:
state: Current state.
use_target: Whether to use the target network.
Returns:
Action of shape ``[batch_size, action_dim]``.
"""
if use_target:
return safe_call(self.actor_target, state)[0]
else:
return safe_call(self.actor, state)[0]
def _criticize(self,
state: Dict[str, Any],
action: Dict[str, Any],
use_target: bool = False,
**__):
"""
Use critic network to evaluate current value.
Args:
state: Current state.
action: Current action.
use_target: Whether to use the target network.
Returns:
Q Value of shape ``[batch_size, 1]``.
"""
if use_target:
return safe_call(self.critic_target, state, action)[0]
else:
return safe_call(self.critic, state, action)[0]
def store_transition(self, transition: Union[Transition, Dict]):
"""
Add a transition sample to the replay buffer.
"""
self.replay_buffer.append(transition, required_attrs=(
"state", "action", "reward", "next_state", "terminal"
))
def store_episode(self, episode: List[Union[Transition, Dict]]):
"""
Add a full episode of transition samples to the replay buffer.
"""
for trans in episode:
self.replay_buffer.append(trans, required_attrs=(
"state", "action", "reward", "next_state", "terminal"
))
def update(self,
update_value=True,
update_policy=True,
update_target=True,
concatenate_samples=True,
**__):
"""
Update network weights by sampling from replay buffer.
Args:
update_value: Whether to update the Q network.
update_policy: Whether to update the actor network.
update_target: Whether to update targets.
concatenate_samples: Whether to concatenate the samples.
Returns:
The mean of the estimated policy value and the value loss.
"""
self.actor.train()
self.critic.train()
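# Sample a random batch; the trailing "*" in ``sample_attrs`` gathers any
# additional transition attributes into ``others``.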
batch_size, (state, action, reward, next_state, terminal, others) = \
self.replay_buffer.sample_batch(self.batch_size,
concatenate_samples,
sample_method="random_unique",
sample_attrs=[
"state", "action",
"reward", "next_state",
"terminal", "*"
])
# Update critic network first.
# Generate value reference :math: `y_i` using target actor and
# target critic.
with t.no_grad():
next_action = self.action_transform_function(
self._act(next_state, True), next_state, others
)
next_value = self._criticize(next_state, next_action, True)
next_value = next_value.view(batch_size, -1)
y_i = self.reward_function(
reward, self.discount, next_value, terminal, others
)
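# The current Q estimate below keeps gradients (it is evaluated outside
# the ``no_grad`` block above) so the critic can be optimised against y_i.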
cur_value = self._criticize(state, action)
value_loss = self.criterion(cur_value, y_i.to(cur_value.device))
if self.visualize:
self.visualize_model(value_loss, "critic", self.visualize_dir)
if update_value:
self.critic.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(
self.critic.parameters(), self.grad_max
)
self.critic_optim.step()
# Update actor network
cur_action = self.action_transform_function(
self._act(state), state, others
)
act_value = self._criticize(state, cur_action)
# "-" is applied because we want to maximize J_b(u),
# but the optimizer works by minimizing the target
act_policy_loss = -act_value.mean()
if self.visualize:
self.visualize_model(act_policy_loss, "actor", self.visualize_dir)
if update_policy:
self.actor.zero_grad()
act_policy_loss.backward()
nn.utils.clip_grad_norm_(
self.actor.parameters(), self.grad_max
)
self.actor_optim.step()
# Update target networks
if update_target:
soft_update(self.actor_target, self.actor, self.update_rate)
soft_update(self.critic_target, self.critic, self.update_rate)
self.actor.eval()
self.critic.eval()
# use .item() to prevent memory leakage
return -act_policy_loss.item(), value_loss.item()
def update_lr_scheduler(self):
"""
Update learning rate schedulers.
"""
if hasattr(self, "actor_lr_sch"):
self.actor_lr_sch.step()
if hasattr(self, "critic_lr_sch"):
self.critic_lr_sch.step()
def load(self, model_dir: str, network_map: Dict[str, str] = None,
version: int = -1):
# DOC INHERITED
super(DDPG, self).load(model_dir, network_map, version)
with t.no_grad():
hard_update(self.actor, self.actor_target)
hard_update(self.critic, self.critic_target)
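# ``action_transform_function`` is referenced in ``update`` above but not
# defined in this listing; the following is an assumed minimal default.
@staticmethod
def action_transform_function(raw_output_action: Any, *_):
    # Wrap the raw actor output into the keyword-argument dict fed to the
    # critic; the "action" key is an assumption matching a critic whose
    # ``forward`` takes an ``action`` argument.
    return {"action": raw_output_action}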
@staticmethod
def reward_function(reward, discount, next_value, terminal, _):
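# Bellman target: y = reward + discount * (1 - terminal) * next_value;
# ``~terminal`` inverts the boolean terminal mask so terminal transitions
# receive no bootstrapped value.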
next_value = next_value.to(reward.device)
terminal = terminal.to(reward.device)
return reward + discount * ~terminal * next_value