Callbacks

A callback is a set of functions that will be called at given stages of the training procedure. You can use callbacks to access the internal state of the RL model during training. This allows for monitoring, auto saving, model manipulation, progress bars, and more.

Custom Callback

To build a custom callback, you need to create a class that derives from BaseCallback. This will give you access to events (_on_training_start, _on_step) and useful variables (like self.model for the RL model).

You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see Examples), and one for logging additional values with Tensorboard (see Tensorboard section).

from stable_baselines.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    A custom callback that derives from ``BaseCallback``.

    :param verbose: (int) Verbosity level: 0 for no output, 1 for info, 2 for debug
    """
    def __init__(self, verbose=0):
        super(CustomCallback, self).__init__(verbose)
        # These variables will be accessible in the callback
        # (they are defined in the base class)
        # The RL model
        # self.model = None  # type: BaseRLModel
        # An alias for self.model.get_env(), the environment used for training
        # self.training_env = None  # type: Union[gym.Env, VecEnv, None]
        # Number of times the callback was called
        # self.n_calls = 0  # type: int
        # Total number of steps taken in the environment
        # self.num_timesteps = 0  # type: int
        # local and global variables
        # self.locals = None  # type: Dict[str, Any]
        # self.globals = None  # type: Dict[str, Any]
        # The logger object, used to report things in the terminal
        # self.logger = None  # type: logger.Logger
        # # Sometimes, for event callback, it is useful
        # # to have access to the parent object
        # self.parent = None  # type: Optional[BaseCallback]

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        pass

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass
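
To make the order of these events concrete, here is a minimal, library-free sketch. ToyCallback and toy_learn are hypothetical stand-ins (they are not part of Stable Baselines) that mimic how a training loop could fire the hooks described above: _on_training_start once, then _on_rollout_start, one _on_step per environment step and _on_rollout_end for each rollout, and finally _on_training_end. How the toy loop handles an early stop is an assumption made for illustration.

```python
from typing import List


class ToyCallback:
    """Hypothetical stand-in for a callback; records each event in order."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def _on_training_start(self) -> None:
        self.events.append("training_start")

    def _on_rollout_start(self) -> None:
        self.events.append("rollout_start")

    def _on_step(self) -> bool:
        self.events.append("step")
        return True  # returning False would abort training early

    def _on_rollout_end(self) -> None:
        self.events.append("rollout_end")

    def _on_training_end(self) -> None:
        self.events.append("training_end")


def toy_learn(callback: ToyCallback, n_rollouts: int, steps_per_rollout: int) -> None:
    """Toy training loop (an assumption, not the library's code)."""
    callback._on_training_start()
    stop = False
    for _ in range(n_rollouts):
        callback._on_rollout_start()
        for _ in range(steps_per_rollout):
            # In real training, env.step() would be called just before this hook
            if not callback._on_step():
                stop = True
                break
        callback._on_rollout_end()
        if stop:
            break
    callback._on_training_end()


cb = ToyCallback()
toy_learn(cb, n_rollouts=2, steps_per_rollout=2)
print(cb.events)
```

The recorded list starts with the training start event, contains two rollouts of two steps each, and ends with the training end event.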

Note

self.num_timesteps corresponds to the total number of steps taken in the environment, i.e., it is the number of environments multiplied by the number of times env.step() was called.

Note that PPO1 and TRPO update self.num_timesteps after each rollout (and not after each step) because they rely on MPI.

For the other algorithms, self.num_timesteps is incremented by n_envs (number of environments) after each call to env.step().
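
As a quick worked example of this relation (for the algorithms that increment self.num_timesteps at every step), the helper below computes the expected value from the number of environments and the number of env.step() calls. The function name expected_num_timesteps is hypothetical and only serves the illustration.

```python
def expected_num_timesteps(n_envs: int, n_step_calls: int) -> int:
    """Each env.step() call advances every environment by one step."""
    return n_envs * n_step_calls


# 4 parallel environments, env.step() called 250 times
print(expected_num_timesteps(n_envs=4, n_step_calls=250))  # 1000
```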

Note

For off-policy algorithms like SAC, DDPG, TD3 or DQN, the notion of rollout corresponds to the steps taken in the environment between two updates.

Event Callback

Compared to Keras, Stable Baselines provides a second type of BaseCallback, named EventCallback, that is meant to trigger events. When an event is triggered, a child callback is called.

As an example, EvalCallback is an EventCallback that will trigger its child callback when there is a new best model. A child callback is, for instance, StopTrainingOnRewardThreshold, which stops the training if the mean reward achieved by the RL model is above a threshold.

Note

We recommend taking a look at the source code of EvalCallback and StopTrainingOnRewardThreshold to get a better overview of what can be achieved with this kind of callback.

from typing import Optional

from stable_baselines.common.callbacks import BaseCallback


class EventCallback(BaseCallback):
    """
    Base class for triggering callback on event.

    :param callback: (Optional[BaseCallback]) Callback that will be called
        when an event is triggered.
    :param verbose: (int)
    """
    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super(EventCallback, self).__init__(verbose=verbose)
        self.callback = callback
        # Give access to the parent
        if callback is not None:
            self.callback.parent = self
    ...

    def _on_event(self) -> bool:
        if self.callback is not None:
            return self.callback.on_step()
        return True
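
The parent/child relationship can be sketched without the library. In the toy code below, ToyBase and ToyEvent are hypothetical stand-ins for BaseCallback and an EventCallback subclass, and the rule "fire every third call" is an arbitrary event chosen for illustration. The point is only to show that the child runs when the parent's event fires, and that the child receives a reference to its parent.

```python
from typing import Optional


class ToyBase:
    """Hypothetical minimal base class (stand-in for BaseCallback)."""

    def __init__(self) -> None:
        self.n_calls = 0
        self.parent: Optional["ToyBase"] = None

    def on_step(self) -> bool:
        self.n_calls += 1
        return self._on_step()

    def _on_step(self) -> bool:
        return True


class ToyEvent(ToyBase):
    """Triggers its child every `every` calls (an illustrative event)."""

    def __init__(self, child: Optional[ToyBase] = None, every: int = 3) -> None:
        super().__init__()
        self.child = child
        self.every = every
        if child is not None:
            child.parent = self  # give the child access to its parent

    def _on_event(self) -> bool:
        if self.child is not None:
            return self.child.on_step()
        return True

    def _on_step(self) -> bool:
        if self.n_calls % self.every == 0:
            return self._on_event()
        return True


child = ToyBase()
event = ToyEvent(child, every=3)
results = [event.on_step() for _ in range(7)]
print(event.n_calls, child.n_calls)  # 7 2
```

Over seven calls the event fires on calls 3 and 6, so the child is called twice.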

Callback Collection

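Stable Baselines provides you with a set of common callbacks for:

  • saving the model periodically (CheckpointCallback)
  • evaluating the model periodically and saving the best one (EvalCallback)
  • chaining callbacks (CallbackList)
  • triggering a callback on events (EveryNTimesteps)
  • stopping the training early based on a reward threshold (StopTrainingOnRewardThreshold)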

CheckpointCallback

Callback for saving a model every save_freq steps. You must specify a log folder (save_path) and, optionally, a prefix for the checkpoints (rl_model by default).

from stable_baselines import SAC
from stable_baselines.common.callbacks import CheckpointCallback
# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/',
                                         name_prefix='rl_model')

model = SAC('MlpPolicy', 'Pendulum-v0')
model.learn(2000, callback=checkpoint_callback)
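
As a rough illustration of the save schedule, the sketch below assumes that a checkpoint is written whenever the number of callback calls is a multiple of save_freq. Both that rule and the helper name checkpoint_calls are assumptions made for illustration; check the source of CheckpointCallback for the exact behavior.

```python
from typing import List


def checkpoint_calls(total_calls: int, save_freq: int) -> List[int]:
    """Call counts at which a checkpoint would be written (assumed rule)."""
    return [n for n in range(1, total_calls + 1) if n % save_freq == 0]


# Same numbers as the example above: 2000 steps, one checkpoint every 1000
print(checkpoint_calls(total_calls=2000, save_freq=1000))  # [1000, 2000]
```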

EvalCallback

Periodically evaluates the performance of an agent, using a separate test environment. It saves the best model if the best_model_save_path folder is specified, and saves the evaluation results in a NumPy archive (evaluations.npz) if the log_path folder is specified.

Note

You can pass a child callback via the callback_on_new_best argument. It will be triggered each time there is a new best model.

import gym

from stable_baselines import SAC
from stable_baselines.common.callbacks import EvalCallback

# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
                             log_path='./logs/', eval_freq=500,
                             deterministic=True, render=False)

model = SAC('MlpPolicy', 'Pendulum-v0')
model.learn(5000, callback=eval_callback)

CallbackList

Class for chaining callbacks; they will be called sequentially. Alternatively, you can pass a list of callbacks directly to the learn() method, and it will be converted automatically to a CallbackList.

import gym

from stable_baselines import SAC
from stable_baselines.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
                             log_path='./logs/results', eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])

model = SAC('MlpPolicy', 'Pendulum-v0')
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
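
The chaining behavior can be sketched in plain Python. The code below assumes that every callback in the list is called at each step, in order, and that training continues only if all of them return True. The names run_chain and make are hypothetical helpers, and this is a simplified model rather than the library's implementation.

```python
from typing import Callable, List


def run_chain(callbacks: List[Callable[[], bool]]) -> bool:
    """Call every callback in order; continue only if all of them return True."""
    continue_training = True
    for callback in callbacks:
        # The callback is evaluated first, so later callbacks still run
        # even after an earlier one has returned False
        continue_training = callback() and continue_training
    return continue_training


calls: List[str] = []


def make(name: str, result: bool) -> Callable[[], bool]:
    """Build a toy callback that records its name and returns `result`."""
    def _callback() -> bool:
        calls.append(name)
        return result
    return _callback


print(run_chain([make("checkpoint", True), make("stop", False), make("eval", True)]))  # False
print(calls)  # ['checkpoint', 'stop', 'eval']
```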

StopTrainingOnRewardThreshold

Stops the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough). It must be used with the EvalCallback and uses the event triggered by a new best model.

import gym

from stable_baselines import SAC
from stable_baselines.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
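
The interaction between the two callbacks can be sketched as follows. The code assumes that the child is only consulted when an evaluation yields a new best mean reward, and that training stops once that best value reaches the threshold. The function first_stop_index and the reward values are hypothetical; they only illustrate the logic.

```python
from typing import List, Optional


def first_stop_index(mean_rewards: List[float], reward_threshold: float) -> Optional[int]:
    """Index of the first evaluation that would stop training, or None."""
    best = float("-inf")
    for i, mean_reward in enumerate(mean_rewards):
        if mean_reward > best:  # new best model: the child callback is triggered
            best = mean_reward
            if best >= reward_threshold:  # threshold reached: stop training
                return i
    return None


# Made-up mean rewards from successive evaluations
print(first_stop_index([-1200.0, -800.0, -900.0, -150.0], reward_threshold=-200))  # 3
print(first_stop_index([-1200.0, -800.0], reward_threshold=-200))  # None
```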

EveryNTimesteps

An Event Callback that will trigger its child callback every n_steps timesteps.

Note

Because of the way PPO1 and TRPO work (they rely on MPI), n_steps is a lower bound on the number of timesteps between two events.

from stable_baselines import PPO2
from stable_baselines.common.callbacks import CheckpointCallback, EveryNTimesteps

# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

model = PPO2('MlpPolicy', 'Pendulum-v0', verbose=1)

model.learn(int(2e4), callback=event_callback)
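
To see why n_steps can be a lower bound rather than an exact interval, consider the sketch below. It assumes that the event fires whenever at least n_steps timesteps have elapsed since the last trigger. The helper trigger_points and the rollout length of 256 are hypothetical values chosen for illustration.

```python
from typing import List


def trigger_points(num_timesteps_seen: List[int], n_steps: int) -> List[int]:
    """Values of num_timesteps at which the event would fire (assumed rule)."""
    last_trigger = 0
    fired = []
    for num_timesteps in num_timesteps_seen:
        if num_timesteps - last_trigger >= n_steps:
            last_trigger = num_timesteps
            fired.append(num_timesteps)
    return fired


# num_timesteps updated after every step: events are exactly 500 apart
print(trigger_points(list(range(1, 2001)), n_steps=500))  # [500, 1000, 1500, 2000]
# num_timesteps updated once per rollout of 256 steps: events are 512 apart
print(trigger_points(list(range(256, 2049, 256)), n_steps=500))  # [512, 1024, 1536, 2048]
```

In the second case the counter only changes in jumps of 256, so the first value that satisfies the 500-step requirement is 512.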

Legacy: A functional approach

Warning

This way of doing callbacks is deprecated in favor of the object-oriented approach.

A callback function takes the locals() variables and the globals() variables from the model, then returns a boolean value for whether or not the training should continue.

Thanks to the access to the model’s variables, in particular _locals["self"], we can even change the parameters of the model without halting the training or changing the model’s code.

from typing import Dict, Any

from stable_baselines import PPO2


def simple_callback(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> bool:
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2).
    This callback will save the model and stop the training after the first call.

    :param _locals: (Dict[str, Any])
    :param _globals: (Dict[str, Any])
    :return: (bool) If your callback returns False, training is aborted early.
    """
    print("callback called")
    # Save the model
    _locals["self"].save("saved_model")
    # If you want to continue training, the callback must return True.
    # return True # returns True, training continues.
    print("stop training")
    return False # returns False, training stops.

model = PPO2('MlpPolicy', 'CartPole-v1')
model.learn(2000, callback=simple_callback)
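
An old-style function can be adapted to the object-oriented interface by wrapping it, which is the role of ConvertCallback in the library. The sketch below is a simplified, library-free illustration of that idea: ToyConvertCallback and stop_after_three are hypothetical names, and the loop stands in for a training loop that refreshes the local variables before each call.

```python
from typing import Any, Callable, Dict


class ToyConvertCallback:
    """Hypothetical wrapper exposing a functional callback through on_step()."""

    def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool]) -> None:
        self.callback = callback
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}

    def on_step(self) -> bool:
        # Forward the stored variables to the old-style function
        return self.callback(self.locals, self.globals)


def stop_after_three(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> bool:
    """Old-style callback: continue while fewer than 3 steps have been taken."""
    return _locals["step"] < 3


wrapped = ToyConvertCallback(stop_after_three)
steps_run = 0
for step in range(1, 10):
    wrapped.locals = {"step": step}  # the toy loop updates the locals each step
    steps_run = step
    if not wrapped.on_step():
        break
print(steps_run)  # 3
```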

class stable_baselines.common.callbacks.BaseCallback(verbose: int = 0)

Base class for callback.

Parameters: verbose – (int)

init_callback(model: BaseRLModel) → None

Initialize the callback by saving references to the RL model and the training environment for convenience.

on_step() → bool

This method will be called by the model after each call to env.step().

For child callback (of an EventCallback), this will be called when the event is triggered.

Returns: (bool) If the callback returns False, training is aborted early.

update_locals(locals_: Dict[str, Any]) → None

Updates the local variables of the training process.

For reference to which variables are accessible, check each individual algorithm’s documentation.

Parameters: locals_ – (Dict[str, Any]) current local variables

class stable_baselines.common.callbacks.CallbackList(callbacks: List[stable_baselines.common.callbacks.BaseCallback])

Class for chaining callbacks.

Parameters: callbacks – (List[BaseCallback]) A list of callbacks that will be called sequentially.

class stable_baselines.common.callbacks.CheckpointCallback(save_freq: int, save_path: str, name_prefix='rl_model', verbose=0)

Callback for saving a model every save_freq steps

Parameters:
  • save_freq – (int)
  • save_path – (str) Path to the folder where the model will be saved.
  • name_prefix – (str) Common prefix to the saved models

class stable_baselines.common.callbacks.ConvertCallback(callback, verbose=0)

Convert functional callback (old-style) to object.

Parameters:
  • callback – (Callable)
  • verbose – (int)

class stable_baselines.common.callbacks.EvalCallback(eval_env: Union[gym.core.Env, stable_baselines.common.vec_env.base_vec_env.VecEnv], callback_on_new_best: Optional[stable_baselines.common.callbacks.BaseCallback] = None, n_eval_episodes: int = 5, eval_freq: int = 10000, log_path: str = None, best_model_save_path: str = None, deterministic: bool = True, render: bool = False, verbose: int = 1)

Callback for evaluating an agent.

Parameters:
  • eval_env – (Union[gym.Env, VecEnv]) The environment used for initialization
  • callback_on_new_best – (Optional[BaseCallback]) Callback to trigger when there is a new best model according to the mean_reward
  • n_eval_episodes – (int) The number of episodes to test the agent
  • eval_freq – (int) Evaluate the agent every eval_freq calls of the callback.
  • log_path – (str) Path to a folder where the evaluations (evaluations.npz) will be saved. It will be updated at each evaluation.
  • best_model_save_path – (str) Path to a folder where the best model according to performance on the eval env will be saved.
  • deterministic – (bool) Whether the evaluation should use stochastic or deterministic actions.
  • render – (bool) Whether or not to render the environment during evaluation
  • verbose – (int)

class stable_baselines.common.callbacks.EventCallback(callback: Optional[stable_baselines.common.callbacks.BaseCallback] = None, verbose: int = 0)

Base class for triggering callback on event.

Parameters:
  • callback – (Optional[BaseCallback]) Callback that will be called when an event is triggered.
  • verbose – (int)

init_callback(model: BaseRLModel) → None

Initialize the callback by saving references to the RL model and the training environment for convenience.

class stable_baselines.common.callbacks.EveryNTimesteps(n_steps: int, callback: stable_baselines.common.callbacks.BaseCallback)

Trigger a callback every n_steps timesteps

Parameters:
  • n_steps – (int) Number of timesteps between two triggers.
  • callback – (BaseCallback) Callback that will be called when the event is triggered.

class stable_baselines.common.callbacks.StopTrainingOnRewardThreshold(reward_threshold: float, verbose: int = 0)

Stop the training once a threshold in episodic reward has been reached (i.e. when the model is good enough).

It must be used with the EvalCallback.

Parameters:
  • reward_threshold – (float) Minimum expected reward per episode to stop training.
  • verbose – (int)