GAIL

Generative Adversarial Imitation Learning (GAIL)

Notes

- Original paper: https://arxiv.org/abs/1606.03476
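As the paper above describes, GAIL trains a discriminator to tell expert state-action pairs apart from the policy's, and the policy is rewarded for fooling it via the surrogate reward -log(1 - D(s, a)). A minimal numpy sketch of that reward with a toy logistic discriminator (the function names and weights here are illustrative, not stable_baselines API):

```python
import numpy as np

def discriminator(sa, w, b):
    """Toy logistic discriminator: probability that a concatenated
    state-action vector came from the expert."""
    return 1.0 / (1.0 + np.exp(-(sa @ w + b)))

def gail_reward(sa, w, b, eps=1e-8):
    """GAIL surrogate reward -log(1 - D(s, a)): large when the
    discriminator thinks the pair looks expert-like."""
    d = discriminator(sa, w, b)
    return -np.log(1.0 - d + eps)

# Toy check: a pair the discriminator rates as expert-like earns a
# larger surrogate reward than one it rates as policy-like.
w = np.array([1.0, 1.0])
b = 0.0
expert_like = np.array([2.0, 2.0])    # D(s, a) close to 1
policy_like = np.array([-2.0, -2.0])  # D(s, a) close to 0
assert gail_reward(expert_like, w, b) > gail_reward(policy_like, w, b)
```

Maximizing this reward pushes the policy's state-action distribution toward the expert's, which is the core idea the rest of this page's machinery (TRPO policy steps plus discriminator steps) implements.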
If you want to train an imitation learning agent

Step 1: Download expert data

Download the expert data into `./data` (download link).
Step 2: Run GAIL

Run with a single thread:

```
python -m stable_baselines.gail.run_mujoco
```

Run with multiple threads:

```
mpirun -np 16 python -m stable_baselines.gail.run_mujoco
```

See help (`-h`) for more options.
In case you want to run Behavior Cloning (BC):

```
python -m stable_baselines.gail.behavior_clone
```

See help (`-h`) for more options.
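Behavior cloning is plain supervised learning on expert state-action pairs, with no environment interaction. As a concept sketch only (the script above trains stable_baselines' own neural policy on the downloaded expert data), here is a least-squares fit of a linear policy to synthetic "expert" demonstrations:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "expert" demonstrations: actions are a fixed linear
# function of the state, plus a little noise.
true_w = np.array([[0.5], [-1.0]])
states = rng.normal(size=(200, 2))
actions = states @ true_w + 0.01 * rng.normal(size=(200, 1))

# Behavior cloning step: supervised regression of actions on states.
w_bc, *_ = np.linalg.lstsq(states, actions, rcond=None)

# With enough clean demonstrations, the cloned policy recovers the
# expert's mapping almost exactly.
assert np.allclose(w_bc, true_w, atol=0.05)
```

Because BC never interacts with the environment, it is cheap but can drift off the expert's state distribution at test time; GAIL addresses this by training against rollouts of the current policy.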
OpenAI Maintainers:
- Yuan-Hong Liao, andrewliao11_at_gmail_dot_com
- Ryan Julian, ryanjulian_at_gmail_dot_com
Others

Thanks to these open-source projects:
- @openai/imitation
- @carpedm20/deep-rl-tensorflow
Can I use?

- Recurrent policies: ✔️
- Multi processing: ✔️ (using MPI)
- Gym spaces:

| Space | Action | Observation |
|---|---|---|
| Discrete | ❌ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ❌ | ✔️ |
| MultiBinary | ❌ | ✔️ |
Parameters

class stable_baselines.gail.GAIL(policy, env, pretrained_weight=False, hidden_size_adversary=100, adversary_entcoeff=0.001, expert_dataset=None, save_per_iter=1, checkpoint_dir='/tmp/gail/ckpt/', g_step=1, d_step=1, task_name='task_name', d_stepsize=0.0003, verbose=0, _init_setup_model=True, **kwargs)

Generative Adversarial Imitation Learning (GAIL)
Parameters:
- policy – (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, …)
- env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
- gamma – (float) the discount factor
- timesteps_per_batch – (int) the number of timesteps to run per batch (horizon)
- max_kl – (float) the Kullback-Leibler divergence threshold
- cg_iters – (int) the number of iterations for the conjugate gradient calculation
- lam – (float) GAE factor
- entcoeff – (float) the weight for the entropy loss
- cg_damping – (float) the conjugate gradient damping factor
- vf_stepsize – (float) the value function stepsize
- vf_iters – (int) the number of iterations for value function learning
- pretrained_weight – (str) the save location for the pretrained weights
- hidden_size – ([int]) the hidden dimension for the MLP
- expert_dataset – (Dset) the dataset manager
- save_per_iter – (int) the number of iterations before saving
- checkpoint_dir – (str) the location for saving checkpoints
- g_step – (int) number of steps to train policy in each epoch
- d_step – (int) number of steps to train discriminator in each epoch
- task_name – (str) the name of the task (can be None)
- d_stepsize – (float) the reward giver (discriminator) stepsize
- verbose – (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
- _init_setup_model – (bool) Whether or not to build the network at the creation of the instance
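Each training iteration interleaves g_step policy (generator) updates with d_step discriminator updates. A schematic pure-Python sketch of that schedule; the update bodies are placeholders, not stable_baselines internals:

```python
def alternating_schedule(num_iters, g_step, d_step):
    """Count how often each component is updated under a
    GAIL-style alternating training schedule."""
    g_updates = d_updates = 0
    for _ in range(num_iters):
        for _ in range(g_step):
            g_updates += 1  # placeholder for one TRPO policy step
        for _ in range(d_step):
            d_updates += 1  # placeholder for one discriminator step
    return g_updates, d_updates

# With the defaults g_step=1, d_step=1, both components get exactly
# one update per iteration.
assert alternating_schedule(10, g_step=1, d_step=1) == (10, 10)
assert alternating_schedule(10, g_step=3, d_step=1) == (30, 10)
```

Raising g_step lets the policy take several steps against a fixed discriminator before the discriminator is refreshed, and vice versa for d_step.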
action_probability(observation, state=None, mask=None)

Get the model's action probability distribution from an observation.

Parameters:
- observation – (np.ndarray) the input observation
- state – (np.ndarray) the last states (can be None; used in recurrent policies)
- mask – (np.ndarray) the last masks (can be None; used in recurrent policies)

Returns: (np.ndarray) the model's action probability distribution
get_env()

Returns the current environment (can be None if not defined).

Returns: (Gym Environment) the current environment
learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='GAIL')

Return a trained model.

Parameters:
- total_timesteps – (int) the total number of samples to train on
- seed – (int) the initial seed for training; if None, keep the current seed
- callback – (function (dict, dict)) function called at every step with the state of the algorithm; it takes the local and global variables
- log_interval – (int) the number of timesteps between log entries
- tb_log_name – (str) the name of the run for the tensorboard log

Returns: (BaseRLModel) the trained model
classmethod load(load_path, env=None, **kwargs)

Load the model from file.

Parameters:
- load_path – (str) the saved parameter location
- env – (Gym environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
- kwargs – extra arguments to change the model when loading
predict(observation, state=None, mask=None, deterministic=False)

Get the model's action from an observation.

Parameters:
- observation – (np.ndarray) the input observation
- state – (np.ndarray) the last states (can be None; used in recurrent policies)
- mask – (np.ndarray) the last masks (can be None; used in recurrent policies)
- deterministic – (bool) whether or not to return deterministic actions

Returns: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
save(save_path)

Save the current parameters to file.

Parameters:
- save_path – (str) the save location