Base RL Class

Common interface for all RL algorithms.

class stable_baselines.common.base_class.BaseRLModel(policy, env, verbose=0, *, requires_vec_env, policy_base)[source]

The base RL model

Parameters:
  • policy – (BasePolicy) Policy object
  • env – (Gym environment) The environment to learn from. If the environment is registered in Gym, it can be given as a str; it can be None when loading a trained model
  • verbose – (int) the verbosity level: 0 none, 1 training information, 2 TensorFlow debug
  • requires_vec_env – (bool) Whether this model requires a vectorized environment
  • policy_base – (BasePolicy) the base policy used by this method
action_probability(observation, state=None, mask=None)[source]

Get the model’s action probability distribution from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
Returns:

(np.ndarray) the model’s action probability distribution

get_env()[source]

Return the current environment (None if no environment is defined)

Returns:

(Gym Environment) The current environment

learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='run')[source]

Return a trained model.

Parameters:
  • total_timesteps – (int) The total number of samples to train on
  • callback – (function (dict, dict)) function called at every step with the state of the algorithm. It takes the local and global variables as its two arguments.
  • seed – (int) The initial seed for training; if None, keep the current seed
  • log_interval – (int) The number of timesteps before logging.
  • tb_log_name – (str) the name of the run for tensorboard log
Returns:

(BaseRLModel) the trained model
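
The callback contract described above (called at every step with the local and global variables) can be sketched without any RL library. In this minimal illustration, `fake_learn` and `StepCounter` are hypothetical names that are not part of Stable Baselines, and the toy loop only stands in for a real training loop. Which keys appear in the locals dictionary depends on the algorithm, so the `step` key used here is an assumption that holds only for this toy loop.

```python
def fake_learn(total_timesteps, callback=None):
    """Hypothetical stand-in for a training loop (not the library's code).

    Calls `callback(locals(), globals())` once per step, mirroring the
    (dict, dict) signature documented for learn().
    """
    for step in range(total_timesteps):
        if callback is not None:
            callback(locals(), globals())
    return total_timesteps


class StepCounter:
    """Callable object that records how often it was invoked."""

    def __init__(self):
        self.calls = 0
        self.last_step = None

    def __call__(self, local_vars, global_vars):
        self.calls += 1
        # `step` exists only in this toy loop; a real algorithm exposes
        # its own local variable names.
        self.last_step = local_vars.get("step")


counter = StepCounter()
fake_learn(5, callback=counter)
print(counter.calls, counter.last_step)
```

A callable object is used instead of a plain function so the callback can keep state between calls without relying on module-level variables.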

classmethod load(load_path, env=None, **kwargs)[source]

Load the model from file

Parameters:
  • load_path – (str) the saved parameter location
  • env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
  • kwargs – extra arguments to change the model when loading
predict(observation, state=None, mask=None, deterministic=False)[source]

Get the model’s action from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)
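
For recurrent policies, the state returned by one call to predict is passed back in on the next call. The sketch below shows only that calling pattern. `ToyRecurrentModel` is a hypothetical stand-in that copies the predict signature, and its internal arithmetic is invented for illustration. Treating `mask` as an "episode just ended" flag that resets the state is also an assumption made here, not something stated above.

```python
import numpy as np


class ToyRecurrentModel:
    """Hypothetical stand-in that mimics the predict() signature only."""

    def predict(self, observation, state=None, mask=None, deterministic=False):
        # Start from a zero state on the first call
        if state is None:
            state = np.zeros(1)
        # Assumption: a truthy mask means the previous step ended an episode
        if mask is not None and mask[0]:
            state = np.zeros(1)
        # Toy recurrence: accumulate the observations seen so far
        next_state = state + observation.sum()
        action = np.array([int(next_state[0] > 1.0)])
        return action, next_state


model = ToyRecurrentModel()
state = None      # no state before the first prediction
done = [False]    # mask for a single environment
actions = []
for obs in [np.array([0.5]), np.array([0.4]), np.array([0.3])]:
    # Feed the state returned by the previous call back into predict()
    action, state = model.predict(obs, state=state, mask=done, deterministic=True)
    actions.append(int(action[0]))

print(actions)
```

For non-recurrent policies, `state` and `mask` can be left at None and the returned state ignored.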

save(save_path)[source]

Save the current parameters to file

Parameters:
  • save_path – (str) the save location
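
The save and load methods form a round trip: an instance writes its parameters to a path, and the classmethod rebuilds a model from that path, optionally with a new environment and overriding keyword arguments. The sketch below mirrors only the shape of that interface. `ToyModel` is a hypothetical class, the use of pickle as the file format is an assumption, and letting `kwargs` overwrite saved values is one plausible reading of "extra arguments to change the model when loading".

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in mirroring the save()/load() interface shape."""

    def __init__(self, env=None, params=None):
        self.env = env
        self.params = params if params is not None else {}

    def save(self, save_path):
        # Assumption: parameters are stored as a pickled dict
        with open(save_path, "wb") as file_handler:
            pickle.dump(self.params, file_handler)

    @classmethod
    def load(cls, load_path, env=None, **kwargs):
        with open(load_path, "rb") as file_handler:
            params = pickle.load(file_handler)
        # Extra keyword arguments override the saved values
        params.update(kwargs)
        return cls(env=env, params=params)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_model.pkl")
    ToyModel(params={"gamma": 0.99, "verbose": 0}).save(path)
    # env is left as None: enough when only predictions are needed
    restored = ToyModel.load(path, verbose=1)

print(restored.params, restored.env)
```

Because load is a classmethod, it is called on the class itself rather than on an existing instance.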
set_env(env)[source]

Checks the validity of the environment and, if it is coherent, sets it as the current environment.

Parameters:
  • env – (Gym Environment) The environment for learning a policy
setup_model()[source]

Create all the functions and TensorFlow graphs necessary to train the model.