ACER

Sample Efficient Actor-Critic with Experience Replay (ACER) combines several ideas from previous algorithms: it uses multiple workers (as in A2C), a replay buffer (as in DQN), Retrace for Q-value estimation, importance sampling, and a trust region.
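
The Retrace estimate mentioned above can be sketched for a single trajectory in plain Python. This is a minimal illustration of the backward recursion described in the ACER paper, not the library's implementation; the function name `retrace_targets` and its argument names are assumptions made for this example.

```python
def retrace_targets(rewards, dones, q_taken, values, rhos, bootstrap_value, gamma=0.99):
    """Compute Retrace Q-value targets for one trajectory (hypothetical helper).

    rewards[t], dones[t]: reward and episode-end flag (0.0 or 1.0) at step t
    q_taken[t]: critic estimate Q(x_t, a_t) for the action actually taken
    values[t]: state value V(x_t) under the current policy
    rhos[t]: importance weight pi(a_t | x_t) / mu(a_t | x_t)
    bootstrap_value: value estimate of the state following the last step
    """
    q_ret = bootstrap_value
    targets = []
    for t in reversed(range(len(rewards))):
        # One-step backup; the bootstrap is dropped when the episode ended.
        q_ret = rewards[t] + gamma * q_ret * (1.0 - dones[t])
        targets.append(q_ret)
        # Truncate the importance weight at 1, then correct towards V(x_t).
        rho_bar = min(1.0, rhos[t])
        q_ret = rho_bar * (q_ret - q_taken[t]) + values[t]
    targets.reverse()
    return targets


targets = retrace_targets(
    rewards=[1.0, 1.0], dones=[0.0, 1.0],
    q_taken=[0.0, 0.0], values=[0.0, 0.0],
    rhos=[1.0, 1.0], bootstrap_value=5.0, gamma=0.5,
)
print(targets)
```

With importance weights of 1 and zero critic estimates, as in this toy call, the targets reduce to ordinary discounted returns.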

Notes

  • Original paper: https://arxiv.org/abs/1611.01224
  • python -m stable_baselines.acer.run_atari runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (-h) for more options.

Can I use?

  • Recurrent policies: ✔️
  • Multi processing: ✔️
  • Gym spaces:

    Space          Action  Observation
    Discrete       ✔️      ✔️
    Box            ❌      ✔️
    MultiDiscrete  ❌      ✔️
    MultiBinary    ❌      ✔️

Example

import gym

from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common import make_vec_env
from stable_baselines import ACER

# multiprocess environment
env = make_vec_env('CartPole-v1', n_envs=4)

model = ACER(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("acer_cartpole")

del model  # delete the model to demonstrate saving and loading

model = ACER.load("acer_cartpole")

# run the trained agent (loops until interrupted)
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

Parameters

class stable_baselines.acer.ACER(policy, env, gamma=0.99, n_steps=20, num_procs=None, q_coef=0.5, ent_coef=0.01, max_grad_norm=10, learning_rate=0.0007, lr_schedule='linear', rprop_alpha=0.99, rprop_epsilon=1e-05, buffer_size=5000, replay_ratio=4, replay_start=1000, correction_term=10.0, trust_region=True, alpha=0.99, delta=1, verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False, seed=None, n_cpu_tf_sess=1)

The ACER (Actor-Critic with Experience Replay) model class, https://arxiv.org/abs/1611.01224

Parameters:
  • policy – (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, …)
  • env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
  • gamma – (float) The discount value
  • n_steps – (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
  • num_procs – (int) The number of threads for TensorFlow operations. Deprecated since version 2.9.0: Use n_cpu_tf_sess instead.
  • q_coef – (float) The weight for the loss on the Q value
  • ent_coef – (float) The weight for the entropy loss
  • max_grad_norm – (float) The clipping value for the maximum gradient
  • learning_rate – (float) The initial learning rate for the RMSProp optimizer
  • lr_schedule – (str) The type of scheduler for the learning rate update (‘linear’, ‘constant’, ‘double_linear_con’, ‘middle_drop’ or ‘double_middle_drop’)
  • rprop_epsilon – (float) RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)
  • rprop_alpha – (float) RMSProp decay parameter (default: 0.99)
  • buffer_size – (int) The buffer size in number of steps
  • replay_ratio – (float) The average number of replay (off-policy) updates per on-policy update; the actual number is drawn from a Poisson distribution
  • replay_start – (int) The minimum number of steps in the buffer before replay learning starts
  • correction_term – (float) Importance weight clipping factor (default: 10)
  • trust_region – (bool) Whether or not the algorithm estimates the gradient of the KL divergence between the old and updated policy and uses it to determine the step size (default: True)
  • alpha – (float) The decay rate for the exponential moving average of the parameters
  • delta – (float) max KL divergence between the old policy and updated policy (default: 1)
  • verbose – (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
  • tensorboard_log – (str) the log location for tensorboard (if None, no logging)
  • _init_setup_model – (bool) Whether or not to build the network at the creation of the instance
  • policy_kwargs – (dict) additional arguments to be passed to the policy on creation
  • full_tensorboard_log – (bool) enable additional logging when using tensorboard WARNING: this logging can take a lot of space quickly
  • seed – (int) Seed for the pseudo-random generators (python, numpy, tensorflow). If None (default), use random seed. Note that if you want completely deterministic results, you must set n_cpu_tf_sess to 1.
  • n_cpu_tf_sess – (int) The number of threads for TensorFlow operations. If None, the number of CPUs of the current machine will be used.
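
Several of the hyperparameters above reduce to simple arithmetic. The sketch below illustrates four of them in plain Python, following only the descriptions given in the list; the helper names are hypothetical and the code is not taken from the library.

```python
import math
import random


def batch_size(n_steps, n_envs):
    # Per the n_steps description: one update uses n_steps * n_env transitions.
    return n_steps * n_envs


def sample_replay_updates(replay_ratio, rng):
    # Draw a Poisson-distributed count with mean replay_ratio (Knuth's method),
    # i.e. how many replay updates follow one on-policy update.
    threshold = math.exp(-replay_ratio)
    count, product = 0, 1.0
    while True:
        product *= rng.random()
        if product <= threshold:
            return count
        count += 1


def truncate_importance_weight(rho, correction_term=10.0):
    # Clip the importance weight from above, as correction_term is described.
    return min(rho, correction_term)


def ema_update(average, current, alpha=0.99):
    # Exponential moving average of a parameter with decay rate alpha.
    return alpha * average + (1.0 - alpha) * current


rng = random.Random(0)
counts = [sample_replay_updates(4, rng) for _ in range(20000)]
mean_replays = sum(counts) / len(counts)
print(batch_size(20, 4), round(mean_replays, 2))
```

With the default n_steps=20 and four environments, each on-policy update would use 80 transitions, and the sampled replay counts average out close to replay_ratio.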
action_probability(observation, state=None, mask=None, actions=None, logp=False)

If actions is None, then get the model’s action probability distribution from a given observation.

Depending on the action space the output is:
  • Discrete: probability for each possible action
  • Box: mean and standard deviation of the action output

However, if actions is not None, this function returns the probability that the given actions are taken with the given parameters (observation, state, …) on this model. For discrete action spaces, it returns the probability mass; for continuous action spaces, the probability density. This is because the probability mass is always zero in continuous spaces; see http://blog.christianperone.com/2019/01/ for a good explanation.

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
  • actions – (np.ndarray) (OPTIONAL) For calculating the likelihood that the given actions are chosen by the model for each of the given parameters. Must have the same number of actions and observations. (set to None to return the complete action probability distribution)
  • logp – (bool) (OPTIONAL) When specified with actions, returns probability in log-space. This has no effect if actions is None.
Returns:

(np.ndarray) the model’s (log) action probability
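
The difference between probability mass and probability density can be checked with a few lines of plain Python. This is a standalone illustration, assuming a softmax over logits for the discrete case and a one-dimensional Gaussian for the continuous case; it does not call the library.

```python
import math


def softmax(logits):
    # Discrete case: probabilities (mass) over actions, summing to 1.
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gaussian_density(action, mean, std):
    # Continuous case: density of a 1-D Gaussian evaluated at `action`.
    z = (action - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


probs = softmax([1.0, 2.0, 3.0])
density = gaussian_density(action=0.0, mean=0.0, std=0.1)
log_density = math.log(density)  # analogous to requesting logp=True

print(sum(probs), density, log_density)
```

A density may exceed 1, so its logarithm can be positive, whereas each discrete probability stays between 0 and 1.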

get_env()

Return the current environment (can be None if not defined).

Returns: (Gym Environment) The current environment
get_parameter_list()

Get tensorflow Variables of model’s parameters

This includes all variables necessary for continuing training (saving / loading).

Returns: (list) List of tensorflow Variables
get_parameters()

Get current model parameters as dictionary of variable name -> ndarray.

Returns: (OrderedDict) Dictionary of variable name -> ndarray of model’s parameters.
learn(total_timesteps, callback=None, log_interval=100, tb_log_name='ACER', reset_num_timesteps=True)

Return a trained model.

Parameters:
  • total_timesteps – (int) The total number of samples to train on
  • callback – (function (dict, dict) -> boolean) Function called at every step with the state of the algorithm. It takes the local and global variables. If it returns False, training is aborted.
  • log_interval – (int) The number of timesteps before logging.
  • tb_log_name – (str) the name of the run for tensorboard log
  • reset_num_timesteps – (bool) whether or not to reset the current timestep number (used in logging)
Returns:

(BaseRLModel) the trained model
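
A callback matching the (dict, dict) -> boolean signature described above can be written as a closure. The sketch below stops training after a fixed number of calls; `make_stop_callback` and `max_calls` are hypothetical names chosen for this example.

```python
def make_stop_callback(max_calls):
    """Build a callback that returns False once it has been called max_calls times."""
    state = {"calls": 0}

    def callback(_locals, _globals):
        # _locals and _globals are the local and global variables of the algorithm.
        state["calls"] += 1
        return state["calls"] < max_calls

    return callback


callback = make_stop_callback(max_calls=3)
results = [callback({}, {}) for _ in range(3)]
print(results)
```

Such a callback could then be passed to training, for example as model.learn(total_timesteps=25000, callback=make_stop_callback(1000)).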

classmethod load(load_path, env=None, custom_objects=None, **kwargs)

Load the model from file

Parameters:
  • load_path – (str or file-like) the saved parameter location
  • env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
  • custom_objects – (dict) Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in file that can not be deserialized.
  • kwargs – extra arguments to change the model when loading
load_parameters(load_path_or_dict, exact_match=True)

Load model parameters from a file or a dictionary

Dictionary keys should be tensorflow variable names, which can be obtained with the get_parameters function. If exact_match is True, the dictionary should contain keys for all of the model’s parameters; otherwise a RuntimeError is raised. If False, only variables included in the dictionary will be updated.

This does not load agent’s hyper-parameters.

Warning

This function does not update trainer/optimizer variables (e.g. momentum). As such training after using this function may lead to less-than-optimal results.

Parameters:
  • load_path_or_dict – (str or file-like or dict) Save parameter location or dict of parameters as variable.name -> ndarrays to be loaded.
  • exact_match – (bool) If True, expects load dictionary to contain keys for all variables in the model. If False, loads parameters only for variables mentioned in the dictionary. Defaults to True.
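
The exact_match behaviour can be pictured with ordinary dictionaries. The function below is a hypothetical stand-in that mimics the rule described above on plain dicts; it is not the library's loading code, and the variable names are made up for illustration.

```python
def merge_parameters(model_params, new_params, exact_match=True):
    """Return a copy of model_params updated from new_params (illustrative only)."""
    missing = sorted(set(model_params) - set(new_params))
    if exact_match and missing:
        # With exact_match=True every model variable must be provided.
        raise RuntimeError("Missing variables: {}".format(missing))
    # Only variables present in new_params are replaced; the rest are kept.
    return {name: new_params.get(name, value) for name, value in model_params.items()}


model_params = {"model/pi/w": 1.0, "model/pi/b": 0.0}  # made-up variable names
partial = merge_parameters(model_params, {"model/pi/w": 2.0}, exact_match=False)

try:
    merge_parameters(model_params, {"model/pi/w": 2.0}, exact_match=True)
    raised = False
except RuntimeError:
    raised = True

print(partial, raised)
```

Passing exact_match=False is the natural choice when only part of a network, such as a single layer, should be overwritten.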
predict(observation, state=None, mask=None, deterministic=False)

Get the model’s action from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)

pretrain(dataset, n_epochs=10, learning_rate=0.0001, adam_epsilon=1e-08, val_interval=None)

Pretrain a model using behavior cloning: supervised learning given an expert dataset.

NOTE: only Box and Discrete spaces are supported for now.

Parameters:
  • dataset – (ExpertDataset) Dataset manager
  • n_epochs – (int) Number of iterations on the training set
  • learning_rate – (float) Learning rate
  • adam_epsilon – (float) the epsilon value for the adam optimizer
  • val_interval – (int) Report training and validation losses every n epochs. By default, every tenth of the total number of epochs.
Returns:

(BaseRLModel) the pretrained model

save(save_path, cloudpickle=False)

Save the current parameters to file

Parameters:
  • save_path – (str or file-like) The save location
  • cloudpickle – (bool) Use older cloudpickle format instead of zip-archives.
set_env(env)

Checks the validity of the environment and, if it is coherent, sets it as the current environment.

Parameters: env – (Gym Environment) The environment for learning a policy
set_random_seed(seed)
Parameters: seed – (int) Seed for the pseudo-random generators. If None, do not change the seeds.
setup_model()

Create all the functions and tensorflow graphs necessary to train the model