Base RL Class
Common interface for all the RL algorithms
-
class stable_baselines.common.base_class.BaseRLModel(policy, env, verbose=0, *, requires_vec_env, policy_base)
The base RL model
Parameters: - policy – (BasePolicy) Policy object
- env – (Gym environment) The environment to learn from (can be a str if the environment is registered in Gym; can be None when loading a trained model)
- verbose – (int) the verbosity level: 0 none, 1 training information, 2 TensorFlow debug
- requires_vec_env – (bool) Whether this model requires a vectorized environment
- policy_base – (BasePolicy) the base policy used by this method
-
action_probability(observation, state=None, mask=None)
Get the model’s action probability distribution from an observation
Parameters: - observation – (np.ndarray) the input observation
- state – (np.ndarray) The last states (can be None, used in recurrent policies)
- mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
Returns: (np.ndarray) the model’s action probability distribution
-
get_env()
Returns the current environment (can be None if not defined)
Returns: (Gym Environment) The current environment
-
learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='run')
Return a trained model.
Parameters: - total_timesteps – (int) The total number of samples to train on
- seed – (int) The initial seed for training, if None: keep current seed
- callback – (function (dict, dict) -> bool) Function called at every step with the state of the algorithm. It takes the local and global variables as its two arguments. If it returns False, training is aborted.
- log_interval – (int) The number of timesteps before logging.
- tb_log_name – (str) the name of the run for tensorboard log
Returns: (BaseRLModel) the trained model
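The callback contract described above can be sketched as follows. This is a minimal illustration, not part of the library: make_step_limit_callback is a hypothetical helper name, and the sketch deliberately ignores the contents of the two dicts because their keys depend on the algorithm.

```python
def make_step_limit_callback(max_calls):
    """Build a callback that asks training to stop at the `max_calls`-th call.

    Hypothetical helper for illustration; not part of stable_baselines.
    """
    calls = {"count": 0}

    def callback(local_vars, global_vars):
        # local_vars and global_vars are the dicts of local and global
        # variables passed in by the algorithm. Their keys are
        # algorithm-specific, so this sketch does not read them.
        calls["count"] += 1
        # Returning False aborts training; returning True lets it continue.
        return calls["count"] < max_calls

    return callback


# Usage, assuming `model` is an instance of a concrete BaseRLModel subclass:
# model.learn(total_timesteps=10000, callback=make_step_limit_callback(50))
```

Keeping the counter in a closure avoids any dependence on the algorithm's internal variable names.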
-
classmethod load(load_path, env=None, **kwargs)
Load the model from file
Parameters: - load_path – (str or file-like) the saved parameter location
- env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
- kwargs – extra arguments to change the model when loading
-
predict(observation, state=None, mask=None, deterministic=False)
Get the model’s action from an observation
Parameters: - observation – (np.ndarray) the input observation
- state – (np.ndarray) The last states (can be None, used in recurrent policies)
- mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
- deterministic – (bool) Whether or not to return deterministic actions.
Returns: (np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)
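Because predict returns the next state along with the action, a caller using a recurrent policy is expected to pass that state back in on the following call. The sketch below shows that threading pattern only. ConstantModel is a hypothetical stand-in that mimics the documented signature; it uses plain Python values rather than np.ndarray and leaves mask at None.

```python
class ConstantModel:
    """Hypothetical stand-in mimicking the documented predict() signature."""

    def predict(self, observation, state=None, mask=None, deterministic=False):
        # A real model would compute the action from the observation.
        # This stand-in always returns action 0 and uses `state` as a
        # simple step counter so the threading is visible.
        next_state = 0 if state is None else state + 1
        return 0, next_state


def run_steps(model, observations):
    """Call model.predict on each observation, passing the state along."""
    state = None  # there is no state before the first call
    actions = []
    for obs in observations:
        action, state = model.predict(obs, state=state, deterministic=True)
        actions.append(action)
    return actions, state


actions, final_state = run_steps(ConstantModel(), [0.1, 0.2, 0.3])
```

For non-recurrent policies the state can simply be left as None on every call.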
-
save(save_path)
Save the current parameters to file
Parameters: save_path – (str or file-like object) the save location
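A save followed by a load can be sketched using only the two signatures documented on this page. save_and_reload is a hypothetical helper, and the sketch assumes `model` is an instance of some concrete BaseRLModel subclass.

```python
def save_and_reload(model, path, env=None):
    """Save `model` to `path`, then load a fresh copy from the same path.

    Hypothetical helper for illustration; it relies only on the documented
    save(save_path) and load(load_path, env=None, **kwargs) signatures.
    """
    model.save(path)
    # load() is a classmethod, so it is called on the class rather than
    # on the existing instance. Leaving env as None is enough when the
    # reloaded model is only used for prediction.
    return type(model).load(path, env=env)
```

Pass an environment as env if the reloaded model should continue training rather than only predict.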