Source code for machin.frame.algorithms.gail

from typing import Union, Dict, List, Tuple, Callable, Any
from copy import deepcopy
import os
import torch as t
import torch.nn as nn
import numpy as np

from machin.model.nets.base import NeuralNetworkModule
from machin.frame.buffers.buffer import TransitionBase, Transition, Buffer
from .ppo import PPO
from .trpo import TRPO
from .base import TorchFramework, Config
from .utils import (
    safe_call,
    assert_and_get_valid_models,
    assert_and_get_valid_optimizer,
    assert_and_get_valid_lr_scheduler,
)


class ExpertTransition(TransitionBase):
    """
    The ExpertTransition class for expert steps. Has two main attributes:
    ``state`` and ``action``.
    """

    # for auto suggestion in IDEs
    state = None  # type: Dict[str, t.Tensor]
    action = None  # type: Dict[str, t.Tensor]

    def __init__(
        self,
        state: Dict[str, t.Tensor],
        action: Dict[str, t.Tensor],
    ):
        """
        Args:
            state: Previous observed state.
            action: Action of the expert.
        """
        super().__init__(
            major_attr=["state", "action"],
            sub_attr=[],
            custom_attr=[],
            major_data=[state, action],
            sub_data=[],
            custom_data=[],
        )

    def _check_validity(self):
        # fix batch size to 1
        super()._check_validity()
        if self._batch_size != 1:
            raise ValueError(
                "Batch size of the expert transition "
                f"implementation must be 1, is {self._batch_size}"
            )

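# Illustrative sketch (not part of the original source): constructing a single
# expert step. The dict keys ("state", "action") and tensor shapes below are
# assumptions and must match the inputs of your own actor and discriminator;
# the batch dimension must be 1, as enforced by _check_validity above.
def _example_expert_transition():
    return ExpertTransition(
        state={"state": t.zeros(1, 4)},
        action={"action": t.zeros(1, 2)},
    )
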
class GAIL(TorchFramework):
    """
    GAIL framework.
    """

    _is_top = ["actor", "critic", "discriminator"]
    _is_restorable = ["actor", "critic", "discriminator"]

    def __init__(
        self,
        discriminator: Union[NeuralNetworkModule, nn.Module],
        constrained_policy_optimization: Union[PPO, TRPO],
        optimizer: Callable,
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple] = None,
        lr_scheduler_kwargs: Tuple[Dict] = None,
        batch_size: int = 100,
        discriminator_update_times: int = 1,
        discriminator_learning_rate: float = 0.001,
        gradient_max: float = np.inf,
        expert_replay_size: int = 500000,
        expert_replay_device: Union[str, t.device] = "cpu",
        expert_replay_buffer: Buffer = None,
        visualize: bool = False,
        visualize_dir: str = "",
        **__,
    ):
        """
        Note:
            The forward method of your discriminator network must take two
            arguments::

                def forward(self,
                            state: Dict[str, t.Tensor],
                            action: Dict[str, t.Tensor])

            and return a tag vector (float type) of shape ``[batch_size, 1]``;
            usually you can achieve this with a sigmoid output layer. If you
            set ``concatenate_samples`` to ``False`` during the ``update()``
            call, then you should expect ``Dict[str, List[t.Tensor]]`` instead.

        Note:
            You can access the following attributes:

            1. ``actor``
            2. ``critic``
            3. ``actor_optim``
            4. ``critic_optim``
            5. ``actor_lr_sch`` (may not exist if you are not using a lr scheduler)
            6. ``critic_lr_sch`` (may not exist if you are not using a lr scheduler)
            7. ``replay_buffer``

            of the input PPO or TRPO framework directly from a GAIL instance,
            since they are forwarded here. Other attributes must be accessed
            manually through the ``constrained_policy_optimization`` attribute.

        Args:
            discriminator: Discriminator network module.
            constrained_policy_optimization: A constrained policy optimization
                framework, currently either a :class:`.PPO` or a :class:`.TRPO`
                framework.
            optimizer: Optimizer used to optimize ``discriminator``.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning rate
                scheduler.
            batch_size: Batch size used during discriminator training.
            discriminator_update_times: Number of discriminator updates
                performed per ``update()`` call.
            discriminator_learning_rate: Learning rate of the discriminator
                optimizer, not compatible with ``lr_scheduler``.
            gradient_max: Maximum gradient norm.
            expert_replay_size: Expert trajectory buffer size. Not compatible
                with ``expert_replay_buffer``.
            expert_replay_device: Device where the expert replay buffer is
                located. Not compatible with ``expert_replay_buffer``.
            expert_replay_buffer: Custom expert replay buffer.
            visualize: Whether to visualize the network flow in the first pass.
            visualize_dir: Visualized graph save directory.
        """
        self.batch_size = batch_size
        self.gradient_max = gradient_max
        self.visualize = visualize
        self.visualize_dir = visualize_dir

        self.constrained_policy_optimization = constrained_policy_optimization
        self.actor = constrained_policy_optimization.actor
        self.critic = constrained_policy_optimization.critic
        self.actor_optim = constrained_policy_optimization.actor_optim
        self.critic_optim = constrained_policy_optimization.critic_optim
        if hasattr(constrained_policy_optimization, "actor_lr_sch"):
            self.actor_lr_sch = constrained_policy_optimization.actor_lr_sch
        if hasattr(constrained_policy_optimization, "critic_lr_sch"):
            self.critic_lr_sch = constrained_policy_optimization.critic_lr_sch
        self.replay_buffer = constrained_policy_optimization.replay_buffer

        self.discriminator = discriminator
        # By default it is BCELoss, you can modify this attribute.
        self.discriminator_criterion = nn.BCELoss()
        self.discriminator_update_times = discriminator_update_times
        self.discriminator_optim = optimizer(
            self.discriminator.parameters(), lr=discriminator_learning_rate
        )
        self.expert_replay_buffer = (
            Buffer(expert_replay_size, expert_replay_device)
            if expert_replay_buffer is None
            else expert_replay_buffer
        )

        if lr_scheduler is not None:
            if lr_scheduler_args is None:
                lr_scheduler_args = ((),)
            if lr_scheduler_kwargs is None:
                lr_scheduler_kwargs = ({},)
            self.discriminator_lr_sch = lr_scheduler(
                self.discriminator_optim,
                *lr_scheduler_args[0],
                **lr_scheduler_kwargs[0],
            )

        self.bce_criterion = nn.BCELoss()
        super().__init__()

    @property
    def optimizers(self):
        return [self.actor_optim, self.critic_optim, self.discriminator_optim]

    @optimizers.setter
    def optimizers(self, optimizers):
        self.actor_optim, self.critic_optim, self.discriminator_optim = optimizers

    @property
    def lr_schedulers(self):
        lr_schs = []
        if hasattr(self, "actor_lr_sch") and hasattr(self, "critic_lr_sch"):
            lr_schs += [self.actor_lr_sch, self.critic_lr_sch]
        if hasattr(self, "discriminator_lr_sch"):
            lr_schs += [self.discriminator_lr_sch]
        return lr_schs
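
    # Illustrative sketch (assumption, not part of the machin source): a minimal
    # discriminator satisfying the forward signature documented in ``__init__``,
    # assuming the transition dicts use the keys "state" and "action" so that
    # plain tensors reach ``forward``, plus a GAIL frame built around an
    # already-constructed (hypothetical) PPO frame ``some_ppo_frame``::
    #
    #     class SomeDiscriminator(nn.Module):
    #         def __init__(self, state_dim, action_dim):
    #             super().__init__()
    #             self.net = nn.Sequential(
    #                 nn.Linear(state_dim + action_dim, 64),
    #                 nn.ReLU(),
    #                 nn.Linear(64, 1),
    #                 nn.Sigmoid(),
    #             )
    #
    #         def forward(self, state, action):
    #             # float tag of shape [batch_size, 1]
    #             return self.net(t.cat([state, action], dim=1))
    #
    #     gail = GAIL(
    #         discriminator=SomeDiscriminator(4, 2),
    #         constrained_policy_optimization=some_ppo_frame,
    #         optimizer=t.optim.Adam,
    #         discriminator_learning_rate=1e-3,
    #     )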
    def act(self, state: Dict[str, Any], *_, **__):
        """
        Use the actor network to produce a policy (action) for the current
        state.

        Returns:
            Anything produced by the actor.
        """
        # No need to safe_return because the number of
        # returned values is always more than one.
        return safe_call(self.actor, state)
    def _discriminate(self, state: Dict[str, Any], action: Dict[str, Any], *_, **__):
        """
        Use discriminator network to assign a real (0) / fake (1) tag to
        state-action pairs.

        Returns:
            Tags of shape ``[batch_size, 1]``.
        """
        return safe_call(self.discriminator, state, action)[0]
    def store_episode(self, episode: List[Union[Transition, Dict]]):
        """
        Add a full episode of transition samples to the replay buffer.

        "value" and "gae" are automatically calculated.
        """
        # Replace the environment reward with the discriminator-based reward.
        # Don't pre-concatenate, since state sizes may be uneven.
        for trans in episode:
            trans["reward"] = -np.log(self._discriminate(**trans).item())
        self.constrained_policy_optimization.store_episode(episode)
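
    # Illustrative collection sketch (assumption, not part of the machin
    # source): gathering one episode from a gym-like environment ``env`` and
    # storing it. ``gail``, ``env``, the dict keys, and taking ``act(...)[0]``
    # as the sampled action are all hypothetical and depend on your own actor
    # and environment; the environment reward recorded below is discarded and
    # replaced by ``-log(D(s, a))`` in ``store_episode``::
    #
    #     def collect_and_store(gail, env, max_steps=200):
    #         episode, state, terminal, step = [], env.reset(), False, 0
    #         while not terminal and step < max_steps:
    #             step += 1
    #             with t.no_grad():
    #                 old_state = state
    #                 action = gail.act({"state": old_state})[0]
    #                 state, reward, terminal, _ = env.step(action)
    #                 episode.append({
    #                     "state": {"state": old_state},
    #                     "action": {"action": action},
    #                     "next_state": {"state": state},
    #                     "reward": float(reward),
    #                     "terminal": terminal or step == max_steps,
    #                 })
    #         gail.store_episode(episode)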
    def store_expert_episode(self, episode: List[Union[ExpertTransition, Dict]]):
        """
        Add a full episode of transition samples from the expert trajectory
        to the expert replay buffer.

        Only states and actions are required.
        """
        episode = [
            ExpertTransition(**trans) if isinstance(trans, dict) else trans
            for trans in episode
        ]
        self.expert_replay_buffer.store_episode(
            episode, required_attrs=("state", "action")
        )
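
    # Illustrative sketch (assumption): converting recorded expert
    # demonstrations (hypothetical tensors ``expert_states``/``expert_actions``)
    # into the dict format accepted above; every step must keep batch size 1::
    #
    #     gail.store_expert_episode([
    #         {
    #             "state": {"state": s.unsqueeze(0)},
    #             "action": {"action": a.unsqueeze(0)},
    #         }
    #         for s, a in zip(expert_states, expert_actions)
    #     ])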
    def update(
        self,
        update_value=True,
        update_policy=True,
        update_discriminator=True,
        concatenate_samples=True,
        **__,
    ):
        """
        Update network weights by sampling from the buffers. The replay buffer
        will be cleared after the update is finished.

        Args:
            update_value: Whether to update the value (critic) network.
            update_policy: Whether to update the actor network.
            update_discriminator: Whether to update the discriminator network.
            concatenate_samples: Whether to concatenate the samples.

        Returns:
            Mean estimated policy value, value loss, and mean discriminator
            loss.
        """
        self.actor.train()
        self.critic.train()
        self.discriminator.train()

        sum_discrim_loss = 0
        for _ in range(self.discriminator_update_times):
            # Sample a batch from the expert buffer and a batch from the
            # training buffer.
            e_batch_size, (e_state, e_action) = self.expert_replay_buffer.sample_batch(
                self.batch_size,
                sample_method="random_unique",
                concatenate=concatenate_samples,
                sample_attrs=["state", "action"],
            )
            exp_out = self._discriminate(e_state, e_action)

            batch_size, (state, action) = self.replay_buffer.sample_batch(
                self.batch_size,
                sample_method="random_unique",
                concatenate=concatenate_samples,
                sample_attrs=["state", "action"],
            )
            gen_out = self._discriminate(state, action)

            # Expert samples are labeled 0 (real), generated samples 1 (fake).
            discrim_loss = self.discriminator_criterion(
                gen_out, t.ones_like(gen_out)
            ) + self.discriminator_criterion(exp_out, t.zeros_like(exp_out))

            # Update discriminator network.
            if update_discriminator:
                self.discriminator.zero_grad()
                self._backward(discrim_loss)
                nn.utils.clip_grad_norm_(
                    self.discriminator.parameters(), self.gradient_max
                )
                self.discriminator_optim.step()

            if self.visualize:
                self.visualize_model(discrim_loss, "discriminator", self.visualize_dir)

            sum_discrim_loss += discrim_loss.item()

        # Perform a mini-batch PPO or TRPO update.
        act_loss, value_loss = self.constrained_policy_optimization.update(
            update_value=update_value,
            update_policy=update_policy,
            concatenate_samples=concatenate_samples,
        )

        self.actor.eval()
        self.critic.eval()
        self.discriminator.eval()

        return act_loss, value_loss, sum_discrim_loss / self.discriminator_update_times
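
    # Illustrative outer-loop sketch (assumption): alternating collection and
    # updates, reusing the hypothetical ``collect_and_store`` helper sketched
    # after ``store_episode`` above::
    #
    #     for episode_idx in range(1000):
    #         collect_and_store(gail, env)
    #         act_loss, value_loss, discrim_loss = gail.update()
    #         print(episode_idx, act_loss, value_loss, discrim_loss)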
    def update_lr_scheduler(self):
        """
        Update learning rate schedulers.
        """
        if hasattr(self, "actor_lr_sch"):
            self.actor_lr_sch.step()
        if hasattr(self, "critic_lr_sch"):
            self.critic_lr_sch.step()
        if hasattr(self, "discriminator_lr_sch"):
            self.discriminator_lr_sch.step()
    @classmethod
    def generate_config(cls, config: Union[Dict[str, Any], Config]):
        default_values = {
            "models": ["Discriminator"],
            "model_args": ((),),
            "model_kwargs": ({},),
            "constrained_policy_optimization": "PPO",
            "optimizer": "Adam",
            "PPO_config": PPO.generate_config({}),
            "TRPO_config": TRPO.generate_config({}),
            "lr_scheduler": None,
            "lr_scheduler_args": None,
            "lr_scheduler_kwargs": None,
            "batch_size": 100,
            "discriminator_update_times": 1,
            "discriminator_learning_rate": 0.001,
            "gradient_max": np.inf,
            "expert_trajectory_path": "trajectory.data",
            "expert_replay_size": 500000,
            "expert_replay_device": "cpu",
            "expert_replay_buffer": None,
            "visualize": False,
            "visualize_dir": "",
        }
        config = deepcopy(config)
        config["frame"] = "GAIL"
        if "frame_config" not in config:
            config["frame_config"] = default_values
        else:
            config["frame_config"] = {**config["frame_config"], **default_values}
        return config
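
    # Illustrative sketch (assumption): filling in a generated config before
    # calling ``init_from_config``. "SomeDiscriminator" and its kwargs are
    # hypothetical placeholders; the nested ``PPO_config`` must likewise be
    # filled with your own actor and critic model entries::
    #
    #     config = GAIL.generate_config({})
    #     config["frame_config"]["models"] = ["SomeDiscriminator"]
    #     config["frame_config"]["model_kwargs"] = (
    #         {"state_dim": 4, "action_dim": 2},
    #     )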
    @classmethod
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        f_config = deepcopy(config["frame_config"])

        # Initialize the PPO or TRPO sub-framework first.
        if f_config["constrained_policy_optimization"] == "PPO":
            cpo = PPO.init_from_config(f_config["PPO_config"], model_device)
        elif f_config["constrained_policy_optimization"] == "TRPO":
            cpo = TRPO.init_from_config(f_config["TRPO_config"], model_device)
        else:
            raise ValueError("constrained_policy_optimization must be PPO or TRPO.")

        discrim_model = assert_and_get_valid_models(f_config["models"])[0]
        discrim_model_args = f_config["model_args"][0]
        discrim_model_kwargs = f_config["model_kwargs"][0]
        discrim = discrim_model(*discrim_model_args, **discrim_model_kwargs)

        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )
        f_config["constrained_policy_optimization"] = cpo
        f_config["optimizer"] = optimizer
        f_config["lr_scheduler"] = lr_scheduler
        frame = cls(discrim, **f_config)

        if os.path.isfile(f_config["expert_trajectory_path"]):
            trajectory_list = t.load(f_config["expert_trajectory_path"])
            for trajectory in trajectory_list:
                frame.store_expert_episode(trajectory)

        return frame

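
# Illustrative sketch (assumption, not part of the machin source): saving expert
# demonstrations in the format init_from_config expects at
# "expert_trajectory_path": a list of episodes, each being a list of
# {"state": ..., "action": ...} dicts (or ExpertTransition objects) with batch
# size 1. The argument names and the default path are hypothetical.
def _example_save_expert_trajectories(trajectory_list, path="trajectory.data"):
    # trajectory_list: List[List[Dict[str, Dict[str, t.Tensor]]]]
    t.save(trajectory_list, path)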