This package is in maintenance mode; please use Stable-Baselines3 (SB3) for an up-to-date version. A migration guide is available in the SB3 documentation.

Pre-Training (Behavior Cloning)

With the .pretrain() method, you can pre-train RL policies using trajectories from an expert, and therefore accelerate training.

Behavior Cloning (BC) treats the problem of imitation learning (i.e., learning from expert demonstrations) as a supervised learning problem. That is, given expert trajectories (observation-action pairs), the policy network is trained to reproduce the expert behavior: for a given observation, the action taken by the policy must be the one taken by the expert.
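To make the supervised-learning view concrete, here is a minimal sketch for a discrete action space using only NumPy. A linear softmax policy is trained with a cross-entropy loss to reproduce the expert's actions. This is an illustration under those assumptions, not how .pretrain() is implemented (.pretrain() trains the model's own policy network). The helper names behavior_cloning and predict are hypothetical, and the toy "expert" data is synthetic.

```python
import numpy as np


def behavior_cloning(obs, actions, n_actions, lr=0.5, n_epochs=500):
    """Fit a linear softmax policy to expert (observation, action) pairs.

    obs: (N, obs_dim) float array; actions: (N,) int array of expert actions.
    Returns weights of shape (obs_dim + 1, n_actions); the last row is the bias.
    """
    # Append a constant feature so the bias is learned together with the weights
    x = np.hstack([obs, np.ones((len(obs), 1))])
    targets = np.eye(n_actions)[actions]  # one-hot encoding of the expert actions
    w = np.zeros((x.shape[1], n_actions))
    for _ in range(n_epochs):
        logits = x @ w
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        # Gradient of the mean cross-entropy between policy and expert actions
        grad = x.T @ (probs - targets) / len(x)
        w -= lr * grad
    return w


def predict(w, obs):
    """Greedy action of the cloned policy for a batch of observations."""
    x = np.hstack([obs, np.ones((len(obs), 1))])
    return np.argmax(x @ w, axis=1)


# Synthetic expert: takes action 1 when the first observation feature is positive
rng = np.random.default_rng(0)
expert_obs = rng.normal(size=(200, 2))
expert_actions = (expert_obs[:, 0] > 0).astype(int)

w = behavior_cloning(expert_obs, expert_actions, n_actions=2)
accuracy = np.mean(predict(w, expert_obs) == expert_actions)
print(accuracy)
```

For a continuous (Box) action space, the cross-entropy loss would typically be replaced by a regression loss such as the mean squared error between predicted and expert actions.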

Expert trajectories can be human demonstrations, trajectories from another controller (e.g. a PID controller) or trajectories from a trained RL agent.


Only Box and Discrete spaces are currently supported for pre-training a model.


Image datasets are treated slightly differently from other datasets to avoid memory issues. The images from the expert demonstrations must be stored in a folder, not in the expert numpy archive.
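Here is a minimal sketch of this idea. It assumes that the images are stored as individual files in a folder and that the observation array only holds their relative paths, so only one minibatch of images is in memory at a time. The library writes image files to the folder given by image_folder. This sketch uses .npy files instead to stay dependency-free. The helpers save_image_observations and load_minibatch are hypothetical.

```python
import os
import tempfile

import numpy as np


def save_image_observations(images, root, image_folder="recorded_images"):
    """Write each image to root/image_folder and return their relative paths."""
    os.makedirs(os.path.join(root, image_folder), exist_ok=True)
    paths = []
    for i, image in enumerate(images):
        rel_path = os.path.join(image_folder, "{}.npy".format(i))
        np.save(os.path.join(root, rel_path), image)
        paths.append(rel_path)
    # This array of paths plays the role of "obs" in the expert dataset
    return np.array(paths)


def load_minibatch(paths, root):
    """Read only the images of the requested minibatch into memory."""
    return np.stack([np.load(os.path.join(root, p)) for p in paths])


with tempfile.TemporaryDirectory() as root:
    frames = np.zeros((5, 8, 8, 3), dtype=np.uint8)
    frames[2] += 7  # make one frame distinguishable
    obs = save_image_observations(frames, root)
    batch = load_minibatch(obs[1:3], root)

print(obs[2], batch.shape)
```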

Generate Expert Trajectories

Here, we are going to train an RL model and then generate expert trajectories using this agent.

Note that in practice, generating expert trajectories usually does not require training an RL agent.

The following example is only meant to demonstrate the pretrain() feature.

However, we recommend that users look at the code of the generate_expert_traj() function (located in the gail/dataset/ folder) to learn about the data structure of the expert dataset (see below for an overview) and how to record trajectories.

from stable_baselines import DQN
from stable_baselines.gail import generate_expert_traj

model = DQN('MlpPolicy', 'CartPole-v1', verbose=1)
# Train a DQN agent for 1e5 timesteps and generate 10 trajectories
# data will be saved in a numpy archive named `expert_cartpole.npz`
generate_expert_traj(model, 'expert_cartpole', n_timesteps=int(1e5), n_episodes=10)
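The recording step can be pictured with a simplified, self-contained sketch. It assumes a Gym-style reset()/step() API and a discrete action space. It builds a dictionary with the keys of the expert dataset: actions, obs, rewards, episode_returns, and episode_starts. ToyEnv and record_trajectories are hypothetical names. This is not the code of generate_expert_traj(), which also handles training, images, and saving.

```python
import numpy as np


class ToyEnv:
    """Hypothetical 1-D environment with a Gym-like reset()/step() API."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return np.array([0.0])

    def step(self, action):
        self.t += 1
        obs = np.array([float(self.t)])
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.horizon
        return obs, reward, done, {}


def record_trajectories(expert, env, n_episodes):
    """Roll out `expert` (a callable obs -> action) and build the dataset dict."""
    obs_list, actions, rewards, episode_starts = [], [], [], []
    episode_returns = np.zeros(n_episodes)
    for ep in range(n_episodes):
        obs, done, start = env.reset(), False, True
        while not done:
            action = expert(obs)
            # Store the observation the action was chosen from
            obs_list.append(obs)
            actions.append(action)
            episode_starts.append(start)  # True only for the first step of an episode
            obs, reward, done, _ = env.step(action)
            rewards.append(reward)
            episode_returns[ep] += reward
            start = False
    return {
        "actions": np.array(actions).reshape(-1, 1),  # discrete actions: S = (1,)
        "obs": np.array(obs_list),
        "rewards": np.array(rewards),
        "episode_returns": episode_returns,
        "episode_starts": np.array(episode_starts),
    }


traj = record_trajectories(lambda obs: 1, ToyEnv(horizon=5), n_episodes=3)
print(traj["obs"].shape, traj["actions"].shape, traj["episode_returns"])
```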

Here is an additional example where the expert controller is a callable that is passed to the function instead of an RL model. The idea is that this callable can be a PID controller, a prompt asking a human player, …

import gym

from stable_baselines.gail import generate_expert_traj

env = gym.make("CartPole-v1")
# Here the expert is a random agent
# but it can be any python function, e.g. a PID controller
def dummy_expert(_obs):
    """
    Random agent. It samples actions randomly
    from the action space of the environment.

    :param _obs: (np.ndarray) Current observation
    :return: (np.ndarray) action taken by the expert
    """
    return env.action_space.sample()

# Data will be saved in a numpy archive named `dummy_expert_cartpole.npz`.
# When using something other than an RL expert,
# you must pass the environment object explicitly.
generate_expert_traj(dummy_expert, 'dummy_expert_cartpole', env, n_episodes=10)
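Following the PID idea above, a more useful expert than a random agent could be a hand-written feedback controller. The sketch below is a hypothetical proportional-derivative rule. It assumes the Gym CartPole observation layout [cart position, cart velocity, pole angle, pole angular velocity] and that action 1 pushes the cart right and action 0 pushes it left. Its gains are illustrative and have not been tuned.

```python
import numpy as np


def pd_expert(obs, k_p=1.0, k_d=0.5):
    """Hypothetical proportional-derivative controller for CartPole.

    Assumes obs = [cart position, cart velocity, pole angle, pole angular velocity]
    and actions 0 = push left, 1 = push right. Gains are illustrative, not tuned.
    """
    pole_angle, pole_velocity = obs[2], obs[3]
    # Push the cart toward the side the pole is falling to
    return int(k_p * pole_angle + k_d * pole_velocity > 0)


print(pd_expert(np.array([0.0, 0.0, 0.05, 0.0])))   # positive pole angle
print(pd_expert(np.array([0.0, 0.0, -0.05, 0.0])))  # negative pole angle
```

With gym installed, it can replace dummy_expert in the call above, e.g. generate_expert_traj(pd_expert, 'pd_expert_cartpole', env, n_episodes=10).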

Pre-Train a Model using Behavior Cloning

This example uses the expert_cartpole.npz dataset generated by the previous script.

from stable_baselines import PPO2
from stable_baselines.gail import ExpertDataset
# Using only one expert trajectory
# you can specify `traj_limitation=-1` for using the whole dataset
dataset = ExpertDataset(expert_path='expert_cartpole.npz',
                        traj_limitation=1, batch_size=128)

model = PPO2('MlpPolicy', 'CartPole-v1', verbose=1)
# Pretrain the PPO2 model
model.pretrain(dataset, n_epochs=1000)

# As an option, you can train the RL agent
# model.learn(int(1e5))

# Test the pre-trained model
env = model.get_env()
obs = env.reset()

reward_sum = 0.0
for _ in range(1000):
    action, _ = model.predict(obs)
    obs, reward, done, _ = env.step(action)
    reward_sum += reward
    if done:
        # Report the return of the finished episode, then start a new one
        print(reward_sum)
        reward_sum = 0.0
        obs = env.reset()


Data Structure of the Expert Dataset

The expert dataset is a .npz archive. The data is saved in Python dictionary format with the keys: actions, episode_returns, rewards, obs, episode_starts.

In case of images, obs contains the relative path to the images.

obs, actions: shape (N * L, ) + S

where N is the number of episodes, L is the episode length, and S is the shape of the environment observation/action space.

S = (1, ) for a discrete space.
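Here is a small sketch of arrays with these shapes. It assumes N = 2 episodes of L = 3 steps, a 4-dimensional Box observation space, and a Discrete action space (so S = (1, ) for actions). The shapes of rewards, episode_starts, and episode_returns are not given above. The sketch assumes one value per timestep for the first two and one value per episode for the last. Real episodes usually differ in length, so N * L is best read as the total number of recorded timesteps.

```python
import numpy as np

# Hypothetical dataset: N = 2 episodes of L = 3 steps, 4-D observations
n_episodes, ep_len, obs_shape = 2, 3, (4,)
n_steps = n_episodes * ep_len

dataset = {
    "obs": np.zeros((n_steps,) + obs_shape, dtype=np.float32),  # (N * L, ) + S
    "actions": np.zeros((n_steps, 1), dtype=np.int64),          # discrete: S = (1, )
    "rewards": np.ones(n_steps, dtype=np.float32),              # assumed: one per timestep
    "episode_returns": np.full(n_episodes, float(ep_len)),      # assumed: one per episode
    "episode_starts": np.array([True, False, False] * n_episodes),  # assumed: one per timestep
}

for key, value in dataset.items():
    print(key, value.shape)
```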

class stable_baselines.gail.ExpertDataset(expert_path=None, traj_data=None, train_fraction=0.7, batch_size=64, traj_limitation=-1, randomize=True, verbose=1, sequential_preprocessing=False)

Dataset for using behavior cloning or GAIL.

The structure of the expert dataset is a dict, saved as an “.npz” archive. The dictionary contains the keys ‘actions’, ‘episode_returns’, ‘rewards’, ‘obs’ and ‘episode_starts’. The corresponding values have data concatenated across episodes: the first axis is the timestep, the remaining axes index into the data. In case of images, ‘obs’ contains the relative path to the images, to save space through image compression.

Parameters:

  • expert_path – (str) The path to trajectory data (.npz file). Mutually exclusive with traj_data.
  • traj_data – (dict) Trajectory data, in the format described above. Mutually exclusive with expert_path.
  • train_fraction – (float) the fraction of the data used for training in the train/validation split (0 to 1) for pre-training using behavior cloning (BC)
  • batch_size – (int) the minibatch size for behavior cloning
  • traj_limitation – (int) the number of trajectories to use (if -1, load all)
  • randomize – (bool) if the dataset should be shuffled
  • verbose – (int) Verbosity
  • sequential_preprocessing – (bool) Do not use a subprocess to preprocess the data (slower, but uses less memory; used for the CI)
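The sketch below shows one way traj_limitation, train_fraction, and randomize could act on the concatenated data. It is a hypothetical re-implementation, not the library's code. It uses episode_starts to cut the data after the first traj_limitation episodes. It then optionally shuffles the remaining timestep indices and splits them into training and validation indices.

```python
import numpy as np


def split_expert_data(traj, traj_limitation=-1, train_fraction=0.7, randomize=True, seed=0):
    """Hypothetical illustration of traj_limitation / train_fraction / randomize."""
    # Index of the first timestep of each episode in the concatenated arrays
    starts = np.flatnonzero(traj["episode_starts"])
    n_steps = len(traj["obs"])
    if 0 < traj_limitation < len(starts):
        # Keep only the timesteps of the first `traj_limitation` episodes
        n_steps = starts[traj_limitation]
    indices = np.arange(n_steps)
    if randomize:
        np.random.default_rng(seed).shuffle(indices)
    n_train = int(train_fraction * n_steps)
    return indices[:n_train], indices[n_train:]


# Three episodes of lengths 4, 3 and 3, concatenated along the first axis
traj = {
    "obs": np.arange(10).reshape(10, 1),
    "episode_starts": np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0], dtype=bool),
}
train_idx, val_idx = split_expert_data(traj, traj_limitation=2)
print(len(train_idx), len(val_idx))
```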

get_next_batch(split=None)

Get the batch from the dataset.

Parameters: split – (str) the type of data split (can be None, ‘train’, ‘val’)
Returns: (np.ndarray, np.ndarray) inputs and labels

init_dataloader(batch_size)

Initialize the dataloader used by GAIL.

Parameters: batch_size – (int)

log_info()

Log the information of the dataset.


plot()

Show a histogram of the episode returns.

class stable_baselines.gail.DataLoader(indices, observations, actions, batch_size, n_workers=1, infinite_loop=True, max_queue_len=1, shuffle=False, start_process=True, backend='threading', sequential=False, partial_minibatch=True)

A custom dataloader to preprocess observations (including images) and feed them to the network.

Original code for the dataloader from https://github.com/araffin/robotics-rl-srl (MIT licence). Authors: Antonin Raffin, René Traoré, Ashley Hill

Parameters:

  • indices – ([int]) list of observations indices
  • observations – (np.ndarray) observations or images path
  • actions – (np.ndarray) actions
  • batch_size – (int) Number of samples per minibatch
  • n_workers – (int) number of preprocessing workers (for loading the images)
  • infinite_loop – (bool) whether to have an iterator that can be reset
  • max_queue_len – (int) Max number of minibatches that can be preprocessed at the same time
  • shuffle – (bool) Shuffle the minibatch after each epoch
  • start_process – (bool) Start the preprocessing process (default: True)
  • backend – (str) joblib backend (one of ‘multiprocessing’, ‘sequential’, ‘threading’ or ‘loky’ in newest versions)
  • sequential – (bool) Do not use a subprocess to preprocess the data (slower, but uses less memory; used for the CI)
  • partial_minibatch – (bool) Allow partial minibatches (minibatches with fewer elements than batch_size)
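The effect of the batch_size, shuffle, and partial_minibatch options can be sketched as a single-process generator over one epoch. This is a hypothetical simplification. It ignores workers, the preprocessing queue, infinite looping, and image loading, and iterate_minibatches is not part of the library.

```python
import numpy as np


def iterate_minibatches(indices, observations, actions, batch_size,
                        shuffle=False, partial_minibatch=True, seed=0):
    """Hypothetical single-process version of one epoch of minibatches."""
    indices = np.array(indices)
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    n_batches = len(indices) // batch_size
    if partial_minibatch and len(indices) % batch_size > 0:
        # Keep the last, smaller minibatch instead of dropping it
        n_batches += 1
    for b in range(n_batches):
        batch_idx = indices[b * batch_size:(b + 1) * batch_size]
        yield observations[batch_idx], actions[batch_idx]


obs = np.arange(10, dtype=np.float32).reshape(10, 1)
acts = np.arange(10).reshape(10, 1)
sizes = [len(o) for o, _ in iterate_minibatches(range(10), obs, acts, batch_size=4)]
print(sizes)  # [4, 4, 2]
```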

sequential_next()

Sequential version of the pre-processing.


start_process()

Start the preprocessing process.

stable_baselines.gail.generate_expert_traj(model, save_path=None, env=None, n_timesteps=0, n_episodes=100, image_folder='recorded_images')

Train expert controller (if needed) and record expert trajectories.


Only Box and Discrete spaces are supported for now.

Parameters:

  • model – (RL model or callable) The expert model. If it needs to be trained, pass n_timesteps > 0.
  • save_path – (str) Path without the extension where the expert dataset will be saved (e.g. ‘expert_cartpole’ -> creates ‘expert_cartpole.npz’). If not specified, the dataset is not saved and the generated expert trajectories are only returned. This parameter must be specified for image-based environments.
  • env – (gym.Env) The environment. If not defined, the function tries to use the model's environment.
  • n_timesteps – (int) Number of training timesteps
  • n_episodes – (int) Number of trajectories (episodes) to record
  • image_folder – (str) When using images, folder that will be used to record images.

Returns: (dict) the generated expert trajectories.
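The function returns a plain dict of arrays, so the result can be stored and reloaded with NumPy. The sketch below writes a small hand-made dictionary (hypothetical data) with np.savez, which appends the .npz extension to the given path. This mirrors the ‘expert_cartpole’ -> ‘expert_cartpole.npz’ convention above. The sketch then loads the archive back into a dict. It is an illustration, not the function's implementation.

```python
import os
import tempfile

import numpy as np

# Hypothetical trajectory data with the expert dataset keys (one 3-step episode)
traj = {
    "actions": np.array([[0], [1], [1]]),
    "obs": np.zeros((3, 4), dtype=np.float32),
    "rewards": np.ones(3, dtype=np.float32),
    "episode_returns": np.array([3.0]),
    "episode_starts": np.array([True, False, False]),
}

with tempfile.TemporaryDirectory() as tmp:
    save_path = os.path.join(tmp, "expert_cartpole")
    np.savez(save_path, **traj)  # creates expert_cartpole.npz
    with np.load(save_path + ".npz") as archive:
        # Copy the arrays out before the archive is closed
        loaded = {key: archive[key] for key in archive.files}

print(sorted(loaded))
```

The reloaded dict has the same keys as the format described for the traj_data argument of ExpertDataset, e.g. ExpertDataset(traj_data=loaded, batch_size=128). Alternatively, the path of the .npz file can be passed as expert_path.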