# pylint: disable=wildcard-import, unused-wildcard-import
from .ddpg import *
class TD3(DDPG):
    """
    TD3 framework, which adds an additional pair of critic and target critic
    networks to DDPG.
    """

    _is_top = [
        "actor",
        "critic",
        "critic2",
        "actor_target",
        "critic_target",
        "critic2_target",
    ]
    _is_restorable = ["actor_target", "critic_target", "critic2_target"]

    def __init__(
        self,
        actor: Union[NeuralNetworkModule, nn.Module],
        actor_target: Union[NeuralNetworkModule, nn.Module],
        critic: Union[NeuralNetworkModule, nn.Module],
        critic_target: Union[NeuralNetworkModule, nn.Module],
        critic2: Union[NeuralNetworkModule, nn.Module],
        critic2_target: Union[NeuralNetworkModule, nn.Module],
        optimizer: Callable,
        criterion: Callable,
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple, Tuple, Tuple] = None,
        lr_scheduler_kwargs: Tuple[Dict, Dict, Dict] = None,
        batch_size: int = 100,
        update_rate: float = 0.001,
        update_steps: Union[int, None] = None,
        actor_learning_rate: float = 0.0005,
        critic_learning_rate: float = 0.001,
        discount: float = 0.99,
        gradient_max: float = np.inf,
        replay_size: int = 500000,
        replay_device: Union[str, t.device] = "cpu",
        replay_buffer: Buffer = None,
        visualize: bool = False,
        visualize_dir: str = "",
        **__
    ):
        """
        See Also:
            :class:`.DDPG`

        Args:
            actor: Actor network module.
            actor_target: Target actor network module.
            critic: Critic network module.
            critic_target: Target critic network module.
            critic2: The second critic network module.
            critic2_target: The second target critic network module.
            optimizer: Optimizer used to optimize ``actor``, ``critic``
                and ``critic2``.
            criterion: Criterion used to evaluate the value loss.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
                The first two entries are forwarded to :class:`.DDPG`,
                the third one is used for the ``critic2`` scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler, laid out like ``lr_scheduler_args``.
            batch_size: Batch size used during training.
            update_rate: :math:`\\tau` used to update target networks.
                Target parameters are updated as:
                :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
            update_steps: Training step number used to update target networks.
            actor_learning_rate: Learning rate of the actor optimizer,
                not compatible with ``lr_scheduler``.
            critic_learning_rate: Learning rate of the critic optimizer,
                not compatible with ``lr_scheduler``. Also used for the
                ``critic2`` optimizer.
            discount: :math:`\\gamma` used in the bellman function.
            gradient_max: Maximum gradient norm, gradients are clipped
                to it in :meth:`update`.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer locates on, Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            visualize: Whether visualize the network flow in the first pass.
            visualize_dir: Visualized graph save directory.
        """
        # Normalize scheduler arguments once, so that slicing / indexing
        # below never has to deal with ``None``.
        if lr_scheduler_args is None:
            lr_scheduler_args = ((), (), ())
        if lr_scheduler_kwargs is None:
            lr_scheduler_kwargs = ({}, {}, {})

        # The DDPG base class only knows about the actor and the first
        # critic, so it only receives the first two scheduler entries.
        super().__init__(
            actor,
            actor_target,
            critic,
            critic_target,
            optimizer,
            criterion,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args[:2],
            lr_scheduler_kwargs=lr_scheduler_kwargs[:2],
            batch_size=batch_size,
            update_rate=update_rate,
            update_steps=update_steps,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            discount=discount,
            gradient_max=gradient_max,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=replay_buffer,
            visualize=visualize,
            visualize_dir=visualize_dir,
        )
        self.critic2 = critic2
        self.critic2_target = critic2_target
        self.critic2_optim = optimizer(
            self.critic2.parameters(), lr=critic_learning_rate
        )

        # Make sure target and online networks have the same weight
        with t.no_grad():
            hard_update(self.critic2, self.critic2_target)

        if lr_scheduler is not None:
            self.critic2_lr_sch = lr_scheduler(
                self.critic2_optim, *lr_scheduler_args[2], **lr_scheduler_kwargs[2]
            )

    @property
    def optimizers(self):
        """List of the actor, critic and second critic optimizers."""
        return [self.actor_optim, self.critic_optim, self.critic2_optim]

    @optimizers.setter
    def optimizers(self, optimizers):
        # Expects exactly three optimizers, in the same order as the getter.
        self.actor_optim, self.critic_optim, self.critic2_optim = optimizers

    @property
    def lr_schedulers(self):
        """
        List of the three learning rate schedulers, or an empty list if
        any of them has not been created.
        """
        if (
            hasattr(self, "actor_lr_sch")
            and hasattr(self, "critic_lr_sch")
            and hasattr(self, "critic2_lr_sch")
        ):
            return [self.actor_lr_sch, self.critic_lr_sch, self.critic2_lr_sch]
        return []

    def _criticize2(
        self, state: Dict[str, Any], action: Dict[str, Any], use_target=False, **__
    ):
        """
        Use the second critic network to evaluate current value.

        Args:
            state: Current state.
            action: Current action.
            use_target: Whether to use the target network.

        Returns:
            Q Value of shape ``[batch_size, 1]``.
        """
        if use_target:
            return safe_call(self.critic2_target, state, action)[0]
        else:
            return safe_call(self.critic2, state, action)[0]

    def update(
        self,
        update_value=True,
        update_policy=True,
        update_target=True,
        concatenate_samples=True,
        **__
    ):
        # DOC INHERITED
        # Returns a tuple of (negated actor loss, mean of the two critic
        # losses), both as python floats.
        self.actor.train()
        self.critic.train()
        self.critic2.train()
        (
            batch_size,
            (state, action, reward, next_state, terminal, others,),
        ) = self.replay_buffer.sample_batch(
            self.batch_size,
            concatenate_samples,
            sample_method="random_unique",
            sample_attrs=["state", "action", "reward", "next_state", "terminal", "*"],
        )

        # Update critic network first.
        # Generate value reference :math: `y_i` using target actor and
        # target critic.
        with t.no_grad():
            next_action = self.action_transform_function(
                self.policy_noise_function(self._act(next_state, True)),
                next_state,
                others,
            )
            next_value = self._criticize(next_state, next_action, True)
            next_value2 = self._criticize2(next_state, next_action, True)
            # Use the element-wise smaller estimate of the two target critics.
            next_value = t.min(next_value, next_value2)
            next_value = next_value.view(batch_size, -1)
            y_i = self.reward_function(
                reward, self.discount, next_value, terminal, others
            )

        cur_value = self._criticize(state, action)
        cur_value2 = self._criticize2(state, action)
        value_loss = self.criterion(cur_value, y_i.type_as(cur_value))
        # Cast the target to the second critic's own output type, the two
        # critics are not guaranteed to share dtype / device.
        value_loss2 = self.criterion(cur_value2, y_i.type_as(cur_value2))

        if self.visualize:
            self.visualize_model(value_loss, "critic", self.visualize_dir)

        if update_value:
            self.critic.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.gradient_max)
            self.critic_optim.step()
            self.critic2.zero_grad()
            value_loss2.backward()
            nn.utils.clip_grad_norm_(self.critic2.parameters(), self.gradient_max)
            self.critic2_optim.step()

        # Update actor network
        cur_action = self.action_transform_function(self._act(state), state, others)
        # Only the first critic is used to evaluate the policy.
        act_value = self._criticize(state, cur_action)

        # "-" is applied because we want to maximize J_b(u),
        # but optimizer works by minimizing the target
        act_policy_loss = -act_value.mean()

        if self.visualize:
            self.visualize_model(act_policy_loss, "actor", self.visualize_dir)

        if update_policy:
            self.actor.zero_grad()
            act_policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.gradient_max)
            self.actor_optim.step()

        # Update target networks
        if update_target:
            if self.update_rate is not None:
                soft_update(self.actor_target, self.actor, self.update_rate)
                soft_update(self.critic_target, self.critic, self.update_rate)
                soft_update(self.critic2_target, self.critic2, self.update_rate)
            else:
                # No update rate: copy the online networks into the targets
                # once every ``update_steps`` calls.
                self._update_counter += 1
                if self._update_counter % self.update_steps == 0:
                    hard_update(self.actor_target, self.actor)
                    hard_update(self.critic_target, self.critic)
                    hard_update(self.critic2_target, self.critic2)

        self.actor.eval()
        self.critic.eval()
        self.critic2.eval()
        # use .item() to prevent memory leakage
        return -act_policy_loss.item(), (value_loss.item() + value_loss2.item()) / 2

    @staticmethod
    def policy_noise_function(actions, *_):
        """
        Function used to add noise to actions, mentioned in TD3
        training tricks. This default implementation returns ``actions``
        unchanged.
        """
        return actions

    def update_lr_scheduler(self):
        """
        Update learning rate schedulers.
        """
        # Step the second critic's scheduler here, the remaining ones
        # are handled by the parent implementation.
        if hasattr(self, "critic2_lr_sch"):
            self.critic2_lr_sch.step()
        super().update_lr_scheduler()

    def load(
        self, model_dir: str, network_map: Dict[str, str] = None, version: int = -1
    ):
        # DOC INHERITED
        # Calls TorchFramework.load directly (not DDPG.load), then syncs
        # all three online / target network pairs.
        TorchFramework.load(self, model_dir, network_map, version)
        with t.no_grad():
            hard_update(self.actor, self.actor_target)
            hard_update(self.critic, self.critic_target)
            hard_update(self.critic2, self.critic2_target)

    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        """
        Generate a TD3 config on top of the config produced by
        :meth:`.DDPG.generate_config`, with six model entries
        (two actors followed by four critics).
        """
        config = DDPG.generate_config(config)
        config["frame"] = "TD3"
        config["frame_config"]["models"] = [
            "Actor",
            "Actor",
            "Critic",
            "Critic",
            "Critic",
            "Critic",
        ]
        config["frame_config"]["model_args"] = ((), (), (), (), (), ())
        config["frame_config"]["model_kwargs"] = ({}, {}, {}, {}, {}, {})
        return config