A2C

A synchronous, deterministic variant of Asynchronous Advantage Actor Critic (A3C). It uses multiple workers to avoid the use of a replay buffer.
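
Each of the parallel workers collects a short rollout of n_steps transitions; the returns are bootstrapped with the critic's value estimate, and the resulting advantages weight the policy gradient, with an entropy bonus to encourage exploration. The sketch below is a plain-NumPy illustration of the n-step return and advantage computation for a single worker, not the library's internal code.

import numpy as np

def n_step_returns(rewards, dones, last_value, gamma=0.99):
    """Bootstrapped n-step returns for one worker's rollout (illustrative only)."""
    returns = np.zeros(len(rewards))
    running = last_value
    for t in reversed(range(len(rewards))):
        # If the episode ended at step t, do not bootstrap past it
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns

# Toy rollout of n_steps = 5 transitions from a single worker
rewards = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
dones = np.array([0.0, 0.0, 0.0, 0.0, 0.0])
values = np.array([9.5, 9.0, 8.6, 8.1, 7.7])  # critic estimates V(s_t)
last_value = 7.3                              # critic estimate V(s_{t+n}) for bootstrapping

returns = n_step_returns(rewards, dones, last_value)
advantages = returns - values  # weight the policy gradient; returns are the value targets
print(returns)
print(advantages)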

Notes

Can I use?

  • Recurrent policies: ✔️
  • Multiprocessing: ✔️
  • Gym spaces:
Space          Action    Observation
Discrete       ✔️        ✔️
Box            ✔️        ✔️
MultiDiscrete  ✔️        ✔️
MultiBinary    ✔️        ✔️

Example

Train an A2C agent on CartPole-v1 using 4 processes.

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import A2C

# multiprocess environment
n_cpu = 4
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])

model = A2C(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")

del model # remove to demonstrate saving and loading

model = A2C.load("a2c_cartpole")

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
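
Recurrent policies (see the table above) work with the same setup; the sketch below simply swaps in MlpLstmPolicy on the same CartPole environment. The LSTM keeps one hidden state per environment copy, which is why a vectorized environment is used.

import gym

from stable_baselines.common.policies import MlpLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import A2C

# One LSTM hidden state is kept per environment copy
n_cpu = 4
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])

model = A2C(MlpLstmPolicy, env, verbose=1)
model.learn(total_timesteps=25000)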

Parameters

class stable_baselines.a2c.A2C(policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5, learning_rate=0.0007, alpha=0.99, epsilon=1e-05, lr_schedule='linear', verbose=0, tensorboard_log=None, _init_setup_model=True)[source]

The A2C (Advantage Actor Critic) model class, https://arxiv.org/abs/1602.01783

Parameters:
  • policy – (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, …)
  • env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
  • gamma – (float) Discount factor
  • n_steps – (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
  • vf_coef – (float) Value function coefficient for the loss calculation
  • ent_coef – (float) Entropy coefficient for the loss calculation
  • max_grad_norm – (float) The maximum value for the gradient clipping
  • learning_rate – (float) The learning rate
  • alpha – (float) RMSProp decay parameter (default: 0.99)
  • epsilon – (float) RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)
  • lr_schedule – (str) The type of scheduler for the learning rate update (‘linear’, ‘constant’, ‘double_linear_con’, ‘middle_drop’ or ‘double_middle_drop’)
  • verbose – (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
  • tensorboard_log – (str) the log location for tensorboard (if None, no logging)
  • _init_setup_model – (bool) Whether or not to build the network at the creation of the instance (used only for loading)
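
As a quick illustration of the constructor arguments, the sketch below creates a model with a few non-default values (the hyperparameters are arbitrary choices for the example, not recommendations) and enables TensorBoard logging.

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])

# Policies can also be passed by name ('MlpPolicy', 'CnnPolicy', ...)
model = A2C('MlpPolicy', env,
            gamma=0.99,
            n_steps=20,              # batch size is n_steps * n_envs = 20 * 1
            learning_rate=7e-4,
            lr_schedule='constant',
            ent_coef=0.01,
            tensorboard_log="./a2c_cartpole_tensorboard/",
            verbose=1)
model.learn(total_timesteps=10000)
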
action_probability(observation, state=None, mask=None)

Get the model’s action probability distribution from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
Returns:

(np.ndarray) the model’s action probability distribution
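
A short usage sketch (the model and environment mirror the example above); for CartPole's Discrete(2) action space, the result is two probabilities per environment:

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = A2C('MlpPolicy', env).learn(total_timesteps=1000)

obs = env.reset()
# One row of action probabilities per environment in the vectorized env
print(model.action_probability(obs))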

get_env()

Returns the current environment (can be None if not defined)

Returns: (Gym Environment) The current environment
learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='A2C')[source]

Return a trained model.

Parameters:
  • total_timesteps – (int) The total number of samples to train on
  • seed – (int) The initial seed for training, if None: keep current seed
  • callback – (function (dict, dict)) function called at every step with the state of the algorithm. It takes the local and global variables.
  • log_interval – (int) The number of timesteps before logging.
  • tb_log_name – (str) the name of the run for tensorboard log
Returns:

(BaseRLModel) the trained model
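
A minimal callback sketch; the two dictionaries are the local and global variables of the training loop, and which keys they contain depends on the algorithm, so the callback below only counts how often it is called.

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

def simple_callback(locals_, globals_):
    # Receives the local and global variables of the training loop
    simple_callback.n_calls = getattr(simple_callback, 'n_calls', 0) + 1
    return True

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = A2C('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=5000, callback=simple_callback, seed=0)
print(simple_callback.n_calls, "callback calls")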

classmethod load(load_path, env=None, **kwargs)

Load the model from file

Parameters:
  • load_path – (str) the saved parameter location
  • env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
  • kwargs – extra arguments to change the model when loading
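
A usage sketch (assuming "a2c_cartpole" was saved as in the example above); the extra keyword arguments override the corresponding saved attributes:

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

# Without an environment: the model can only be used for prediction
model = A2C.load("a2c_cartpole")

# With a new environment and an overridden attribute, training can continue
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = A2C.load("a2c_cartpole", env=env, verbose=1)
model.learn(total_timesteps=10000)
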
predict(observation, state=None, mask=None, deterministic=False)

Get the model’s action from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)
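
A short sketch of deterministic prediction (model and environment mirror the example above); with a recurrent policy, the returned state would be fed back into the next call:

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = A2C('MlpPolicy', env).learn(total_timesteps=1000)

obs = env.reset()
# Greedy (deterministic) action instead of sampling from the policy
action, _states = model.predict(obs, deterministic=True)

# With a recurrent policy, the returned state is passed to the next call:
#     state = None
#     action, state = model.predict(obs, state=state, mask=dones)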

save(save_path)[source]

Save the current parameters to file

Parameters: save_path – (str) the save location
set_env(env)

Checks the validity of the environment and, if it is coherent, sets it as the current environment.

Parameters: env – (Gym Environment) The environment for learning a policy
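
A typical use is attaching a fresh environment to a loaded model before continuing training (sketch, assuming the saved "a2c_cartpole" model from the example above):

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

model = A2C.load("a2c_cartpole")

# Attach a fresh environment, then continue training on it
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model.set_env(env)
model.learn(total_timesteps=10000)
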
setup_model()[source]

Create all the functions and tensorflow graphs necessary to train the model