Generative Adversarial Imitation Learning (GAIL)


If you want to train an imitation learning agent

Step 1: Download expert data

Download the expert data into ./data (download link).

Step 2: Run GAIL

Run with a single thread:

python -m stable_baselines.gail.run_mujoco

Run with multiple threads:

mpirun -np 16 python -m stable_baselines.gail.run_mujoco

See help (-h) for more options.

To run Behavior Cloning (BC) instead:

python -m stable_baselines.gail.behavior_clone

See help (-h) for more options.

OpenAI Maintainers:

  • Yuan-Hong Liao, andrewliao11_at_gmail_dot_com
  • Ryan Julian, ryanjulian_at_gmail_dot_com


Thanks to the open source projects:

  • @openai/imitation
  • @carpedm20/deep-rl-tensorflow

Can I use?

  • Recurrent policies: ✔️
  • Multi processing: ✔️ (using MPI)
  • Gym spaces:

Space          Action  Observation
Discrete       ❌      ✔️
Box            ✔️      ✔️
MultiDiscrete  ❌      ✔️
MultiBinary    ❌      ✔️


class stable_baselines.gail.GAIL(policy, env, pretrained_weight=False, hidden_size_adversary=100, adversary_entcoeff=0.001, expert_dataset=None, save_per_iter=1, checkpoint_dir='/tmp/gail/ckpt/', g_step=1, d_step=1, task_name='task_name', d_stepsize=0.0003, verbose=0, _init_setup_model=True, **kwargs)

Generative Adversarial Imitation Learning (GAIL)

  • policy – (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, …)
  • env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
  • gamma – (float) the discount value
  • timesteps_per_batch – (int) the number of timesteps to run per batch (horizon)
  • max_kl – (float) the Kullback-Leibler loss threshold
  • cg_iters – (int) the number of iterations for the conjugate gradient calculation
  • lam – (float) GAE factor
  • entcoeff – (float) the weight for the entropy loss
  • cg_damping – (float) the conjugate gradient damping factor
  • vf_stepsize – (float) the value function stepsize
  • vf_iters – (int) the number of learning iterations for the value function
  • pretrained_weight – (str) the save location for the pretrained weights
  • hidden_size_adversary – ([int]) the hidden dimension for the MLP
  • expert_dataset – (Dset) the dataset manager
  • save_per_iter – (int) the number of iterations before saving
  • checkpoint_dir – (str) the location for saving checkpoints
  • g_step – (int) number of steps to train policy in each epoch
  • d_step – (int) number of steps to train discriminator in each epoch
  • task_name – (str) the name of the task (can be None)
  • d_stepsize – (float) the reward giver stepsize
  • verbose – (int) the verbosity level: 0 none, 1 training information, 2 TensorFlow debug
  • _init_setup_model – (bool) Whether or not to build the network at the creation of the instance
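The gamma and lam parameters above control how advantages are estimated from rewards and value predictions. The following is a minimal sketch of generalized advantage estimation over a single trajectory segment with no episode boundary inside it. It is illustrative only and is not the library's implementation; the function name gae_advantages and its arguments are hypothetical.

```python
def gae_advantages(rewards, values, last_value, gamma, lam):
    """Compute GAE advantages for one trajectory segment.

    rewards    -- list of rewards r_0 .. r_{T-1}
    values     -- list of value estimates V(s_0) .. V(s_{T-1})
    last_value -- value estimate V(s_T) used to bootstrap the final step
    gamma      -- discount value
    lam        -- GAE factor
    """
    advantages = [0.0] * len(rewards)
    next_value = last_value
    running = 0.0
    # Walk backwards so each step can reuse the advantage of the step after it.
    for t in reversed(range(len(rewards))):
        # One-step temporal-difference error.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD errors.
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages
```

With lam=0 each advantage reduces to the one-step TD error; with lam=1 it becomes the discounted sum of rewards (plus the bootstrapped last value) minus the value estimate.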
action_probability(observation, state=None, mask=None)

Get the model’s action probability distribution from an observation

  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)

Returns: (np.ndarray) the model’s action probability distribution


Returns the current environment (can be None if not defined)

Returns: (Gym Environment) The current environment
learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='GAIL')

Return a trained model.

  • total_timesteps – (int) The total number of samples to train on
  • seed – (int) The initial seed for training, if None: keep current seed
  • callback – (function (dict, dict)) function called at every step with the state of the algorithm. It takes the local and global variables.
  • log_interval – (int) The number of timesteps before logging.
  • tb_log_name – (str) the name of the run for tensorboard log

Returns: (BaseRLModel) the trained model
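The callback argument of learn receives the algorithm's local and global variables as two dicts. A minimal sketch of such a callback follows, assuming only that two-argument call convention. Which keys appear in the dicts depends on the library version, so the example does not read any of them. The helper name make_call_counter is hypothetical.

```python
def make_call_counter(report_every=100):
    """Build a callback that counts how often the training loop invokes it."""
    stats = {"calls": 0}

    def callback(locals_, globals_):
        # locals_ and globals_ are the training loop's variable dicts;
        # their contents are version-specific, so they are left untouched here.
        stats["calls"] += 1
        if stats["calls"] % report_every == 0:
            print("callback invoked {} times".format(stats["calls"]))

    return callback, stats
```

The returned callback could then be passed as learn(total_timesteps, callback=callback), and stats["calls"] inspected once training finishes.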

classmethod load(load_path, env=None, **kwargs)

Load the model from file

  • load_path – (str) the saved parameter location
  • env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
  • kwargs – extra arguments to change the model when loading
predict(observation, state=None, mask=None, deterministic=False)

Get the model’s action from an observation

  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.

Returns: (np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)
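A rollout with a trained model can feed the state returned by predict back into the next call, as sketched here. The sketch assumes a single, non-vectorized environment following the classic Gym interface, where reset() returns an observation and step(action) returns (observation, reward, done, info). The function name run_episode is hypothetical, and mask is left at its default of None for simplicity.

```python
def run_episode(model, env, max_steps=1000, deterministic=True):
    """Run one episode with model.predict and return the summed reward."""
    obs = env.reset()
    state = None  # no recurrent state before the first step
    total_reward = 0.0
    for _ in range(max_steps):
        # predict returns the action and the next state; the state is only
        # meaningful for recurrent policies and is passed straight back in.
        action, state = model.predict(
            obs, state=state, deterministic=deterministic
        )
        obs, reward, done, _info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
```

Any object exposing the predict signature documented above can be passed as model, which also makes the loop easy to exercise with a stub in tests.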


Save the current parameters to file

Parameters: save_path – (str) the save location

Checks the validity of the environment and, if it is coherent, sets it as the current environment.

Parameters: env – (Gym Environment) The environment for learning a policy

Create all the functions and TensorFlow graphs necessary to train the model