Trust Region Policy Optimization (TRPO) is an iterative approach for optimizing policies with guaranteed monotonic improvement.
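
As a rough sketch of what "trust region" means here, each update can be written as a constrained optimization problem. The notation (θ for the policy parameters, Â for an advantage estimate, δ for the trust-region size) follows the usual TRPO formulation and is not defined on this page. Reading δ as the max_kl constructor argument is an assumption based on the parameter name.

```latex
\begin{aligned}
\max_{\theta}\quad
  & \mathbb{E}_{s,a \sim \pi_{\theta_\text{old}}}\!\left[
      \frac{\pi_\theta(a \mid s)}{\pi_{\theta_\text{old}}(a \mid s)}\,\hat{A}(s,a)
    \right] \\
\text{subject to}\quad
  & \mathbb{E}_{s \sim \pi_{\theta_\text{old}}}\!\left[
      D_{\mathrm{KL}}\big(\pi_{\theta_\text{old}}(\cdot \mid s)\,\|\,\pi_\theta(\cdot \mid s)\big)
    \right] \le \delta
\end{aligned}
```

The monotonic-improvement guarantee holds for the theoretical algorithm. Practical implementations solve this problem only approximately.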


Can I use?

  • Recurrent policies: ✔️
  • Multi processing: ✔️ (using MPI)
  • Gym spaces:

Space          Action  Observation
Discrete       ✔️      ✔️
Box            ✔️      ✔️
MultiDiscrete  ✔️      ✔️
MultiBinary    ✔️      ✔️


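import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import TRPO

# Wrap the Gym environment in a vectorized environment
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

model = TRPO(MlpPolicy, env, verbose=1)
# Train, then save, so that the load() call below has a file to read
# (25000 timesteps is only an example value)
model.learn(total_timesteps=25000)
model.save("trpo_cartpole")

del model # remove to demonstrate saving and loading

model = TRPO.load("trpo_cartpole")

# Run the loaded policy; this loop runs until interrupted
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)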


class stable_baselines.trpo_mpi.TRPO(policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, lam=0.98, entcoeff=0.0, cg_damping=0.01, vf_stepsize=0.0003, vf_iters=3, verbose=0, tensorboard_log=None, _init_setup_model=True)

action_probability(observation, state=None, mask=None)

Get the model’s action probability distribution from an observation.

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) the last states (can be None; used in recurrent policies)
  • mask – (np.ndarray) the last masks (can be None; used in recurrent policies)

Returns: (np.ndarray) the model’s action probability distribution


get_env()

Returns the current environment (can be None if not defined).

Returns: (Gym Environment) the current environment

learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='TRPO')

Return a trained model.

Parameters:
  • total_timesteps – (int) the total number of samples to train on
  • callback – (function (dict, dict) -> bool) function called at every step with the state of the algorithm. It takes the local and global variables. If it returns False, training is aborted.
  • seed – (int) the initial seed for training; if None, keep the current seed
  • log_interval – (int) the number of timesteps before logging
  • tb_log_name – (str) the name of the run for the TensorBoard log

Returns: (BaseRLModel) the trained model
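
The callback contract described above (take the local and global variables, return False to abort) can be sketched without the library. In the example below, make_step_limit_callback and fake_training_loop are hypothetical names introduced for illustration. The toy loop only mimics the abort behaviour and is not how TRPO trains.

```python
def make_step_limit_callback(max_calls):
    """Build a callback that asks training to stop on its `max_calls`-th call."""
    calls = {"count": 0}

    def callback(locals_, globals_):
        # locals_ and globals_ are the dicts handed over by the algorithm.
        # This sketch ignores their contents and only counts invocations.
        calls["count"] += 1
        return calls["count"] < max_calls  # returning False aborts training

    return callback


def fake_training_loop(callback, max_iterations=100):
    """Toy driver (an assumption, not library code) that honours the contract."""
    iterations_run = 0
    for _ in range(max_iterations):
        iterations_run += 1
        if callback is not None and callback(locals(), globals()) is False:
            break
    return iterations_run


iterations = fake_training_loop(make_step_limit_callback(max_calls=5))
print(iterations)
```

With a real model, the same function would presumably be passed as model.learn(total_timesteps=..., callback=make_step_limit_callback(5)).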

classmethod load(load_path, env=None, **kwargs)

Load the model from file.

Parameters:
  • load_path – (str or file-like) the saved parameter location
  • env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
  • kwargs – extra arguments to change the model when loading

predict(observation, state=None, mask=None, deterministic=False)

Get the model’s action from an observation.

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) the last states (can be None; used in recurrent policies)
  • mask – (np.ndarray) the last masks (can be None; used in recurrent policies)
  • deterministic – (bool) whether or not to return deterministic actions

Returns: (np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)
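
To illustrate the difference the deterministic flag makes, here is a minimal sketch for a Discrete action space. It starts from a probability vector of the kind action_probability returns. pick_action is a hypothetical helper written for this example. Taking the most likely action when deterministic and sampling otherwise is the usual convention, and is assumed here rather than taken from the TRPO source.

```python
import random


def pick_action(action_probs, deterministic=False, rng=random):
    """Choose an action index from a discrete probability vector."""
    indices = range(len(action_probs))
    if deterministic:
        # Most likely action: index of the largest probability.
        return max(indices, key=lambda i: action_probs[i])
    # Otherwise sample an index in proportion to its probability.
    return rng.choices(indices, weights=action_probs, k=1)[0]


probs = [0.1, 0.7, 0.2]  # made-up probabilities for three actions

greedy = pick_action(probs, deterministic=True)

rng = random.Random(0)  # seeded so that the run is repeatable
samples = [pick_action(probs, rng=rng) for _ in range(1000)]

print(greedy, samples.count(0), samples.count(1), samples.count(2))
```

Deterministic actions are typically preferred when evaluating a trained policy, while sampled actions keep some exploration.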


save(save_path)

Save the current parameters to file.

Parameters: save_path – (str or file-like object) the save location

set_env(env)

Checks the validity of the environment and, if it is coherent, sets it as the current environment.

Parameters: env – (Gym Environment) the environment for learning a policy

setup_model()

Create all the functions and TensorFlow graphs necessary to train the model.