DQN

Deep Q Network (DQN) and its extensions (Double-DQN, Dueling-DQN, Prioritized Experience Replay).

Warning

The DQN model does not support stable_baselines.common.policies; as a result, it must use its own policy models (see DQN Policies).

Available Policies

MlpPolicy – Policy object that implements DQN policy, using an MLP (2 layers of 64)
LnMlpPolicy – Policy object that implements DQN policy, using an MLP (2 layers of 64), with layer normalisation
CnnPolicy – Policy object that implements DQN policy, using a CNN (the nature CNN)
LnCnnPolicy – Policy object that implements DQN policy, using a CNN (the nature CNN), with layer normalisation
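
As the warning above notes, these policies are imported from stable_baselines.deepq.policies rather than stable_baselines.common.policies; alternatively, the policy can be referred to by name. A minimal sketch (per the parameters below, the environment can also be given as a registered Gym id):

from stable_baselines import DQN
from stable_baselines.deepq.policies import MlpPolicy

# Pass the policy class directly...
model = DQN(MlpPolicy, 'CartPole-v1', verbose=1)

# ...or refer to it by its name as a string.
model = DQN('MlpPolicy', 'CartPole-v1', verbose=1)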

Notes

Can I use?

  • Recurrent policies: ❌
  • Multi processing: ❌
  • Gym spaces (a compatibility check is sketched after this table):

Space          Action   Observation
Discrete       ✔️        ✔️
Box            ❌        ✔️
MultiDiscrete  ❌        ✔️
MultiBinary    ❌        ✔️
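
Since only Discrete action spaces are supported, it can help to check the environment before building the model. A small sketch of such a check:

import gym
from gym import spaces

from stable_baselines import DQN

env = gym.make('CartPole-v1')

# DQN requires a discrete action space; observations may be Discrete, Box,
# MultiDiscrete or MultiBinary.
assert isinstance(env.action_space, spaces.Discrete), \
    "DQN only supports Discrete action spaces"

model = DQN('MlpPolicy', env, verbose=1)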

Example

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

model = DQN(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("deepq_cartpole")

del model # remove to demonstrate saving and loading

model = DQN.load("deepq_cartpole")

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

With Atari:

from stable_baselines.common.atari_wrappers import make_atari
from stable_baselines.deepq.policies import MlpPolicy, CnnPolicy
from stable_baselines import DQN

env = make_atari('BreakoutNoFrameskip-v4')

model = DQN(CnnPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("deepq_breakout")

del model # remove to demonstrate saving and loading

model = DQN.load("deepq_breakout")

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

Parameters

class stable_baselines.deepq.DQN(policy, env, gamma=0.99, learning_rate=0.0005, buffer_size=50000, exploration_fraction=0.1, exploration_final_eps=0.02, train_freq=1, batch_size=32, checkpoint_freq=10000, checkpoint_path=None, learning_starts=1000, target_network_update_freq=500, prioritized_replay=False, prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None, prioritized_replay_eps=1e-06, param_noise=False, verbose=0, tensorboard_log=None, _init_setup_model=True)[source]

The DQN model class. DQN paper: https://arxiv.org/pdf/1312.5602.pdf

Parameters:
  • policy – (DQNPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, …)
  • env – (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
  • gamma – (float) discount factor
  • learning_rate – (float) learning rate for adam optimizer
  • buffer_size – (int) size of the replay buffer
  • exploration_fraction – (float) fraction of the entire training period over which the exploration rate is annealed
  • exploration_final_eps – (float) final value of random action probability
  • train_freq – (int) update the model every train_freq steps
  • batch_size – (int) size of a batch sampled from the replay buffer for training
  • checkpoint_freq – (int) how often to save the model. This is so that the best version is restored at the end of training. If you do not wish to restore the best version at the end of training, set this variable to None.
  • checkpoint_path – (str) replacement path used if you need to log to somewhere other than a temporary directory.
  • learning_starts – (int) how many steps of the model to collect transitions for before learning starts
  • target_network_update_freq – (int) update the target network every target_network_update_freq steps.
  • prioritized_replay – (bool) if True prioritized replay buffer will be used.
  • prioritized_replay_alpha – (float) alpha parameter for prioritized replay buffer
  • prioritized_replay_beta0 – (float) initial value of beta for prioritized replay buffer
  • prioritized_replay_beta_iters – (int) number of iterations over which beta will be annealed from its initial value to 1.0. If set to None, it defaults to max_timesteps.
  • prioritized_replay_eps – (float) epsilon to add to the TD errors when updating priorities.
  • param_noise – (bool) Whether or not to apply noise to the parameters of the policy.
  • verbose – (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
  • tensorboard_log – (str) the log location for tensorboard (if None, no logging)
  • _init_setup_model – (bool) Whether or not to build the network at the creation of the instance
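
To illustrate these parameters, a hedged sketch that enables prioritized replay and a slower exploration schedule (the values are arbitrary examples, not tuned recommendations):

import gym

from stable_baselines import DQN

env = gym.make('CartPole-v1')

# Illustrative values only: anneal epsilon over 20% of training down to 0.05,
# and sample batches of 64 from a 100000-transition prioritized buffer.
model = DQN('MlpPolicy', env,
            gamma=0.99,
            learning_rate=5e-4,
            buffer_size=100000,
            exploration_fraction=0.2,
            exploration_final_eps=0.05,
            train_freq=1,
            batch_size=64,
            learning_starts=1000,
            target_network_update_freq=500,
            prioritized_replay=True,
            prioritized_replay_alpha=0.6,
            prioritized_replay_beta0=0.4,
            verbose=1)
model.learn(total_timesteps=50000)
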
action_probability(observation, state=None, mask=None)[source]

Get the model’s action probability distribution from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
Returns:

(np.ndarray) the model’s action probability distribution
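
For the discrete action spaces used by DQN, this is one probability per action. A short usage sketch (the exact output shape is assumed here, not specified above):

import gym

from stable_baselines import DQN

env = gym.make('CartPole-v1')
model = DQN('MlpPolicy', env, verbose=0)

obs = env.reset()
# Expected to contain one probability per discrete action
# (e.g. two entries for CartPole-v1).
print(model.action_probability(obs))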

get_env()

Returns the current environment (can be None if not defined).

Returns: (Gym Environment) The current environment
learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name='DQN')[source]

Return a trained model.

Parameters:
  • total_timesteps – (int) The total number of samples to train on
  • seed – (int) The initial seed for training, if None: keep current seed
  • callback – (function (dict, dict)) function called at every step with the state of the algorithm; it takes the local and global variables (see the sketch below)
  • log_interval – (int) The number of timesteps before logging.
  • tb_log_name – (str) the name of the run for tensorboard log
Returns:

(BaseRLModel) the trained model
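
As a sketch of the callback parameter, a function that receives the local and global variables of learn() at every step. Only locals_['self'] (the model itself) is relied on here; other keys depend on the implementation, and the convention that returning False stops training early is an assumption for this version:

from stable_baselines import DQN

def simple_callback(locals_, globals_):
    # locals_['self'] is the DQN instance running learn()
    model_instance = locals_['self']
    # Lightweight custom logic (e.g. logging) would go here.
    # Returning True continues training; returning False is commonly used to
    # stop it early (assumed behaviour, not guaranteed for every release).
    return True

model = DQN('MlpPolicy', 'CartPole-v1', verbose=1)
model.learn(total_timesteps=10000, callback=simple_callback)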

classmethod load(load_path, env=None, **kwargs)[source]

Load the model from file

Parameters:
  • load_path – (str) the saved parameter location
  • env – (Gym Environment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model)
  • kwargs – extra arguments to change the model when loading
predict(observation, state=None, mask=None, deterministic=True)[source]

Get the model’s action from an observation

Parameters:
  • observation – (np.ndarray) the input observation
  • state – (np.ndarray) The last states (can be None, used in recurrent policies)
  • mask – (np.ndarray) The last masks (can be None, used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray, np.ndarray) the model’s action and the next state (used in recurrent policies)

save(save_path)[source]

Save the current parameters to file

Parameters: save_path – (str) the save location
set_env(env)

Checks the validity of the environment and, if it is coherent, sets it as the current environment.

Parameters: env – (Gym Environment) The environment for learning a policy
setup_model()[source]

Create all the functions and tensorflow graphs necessary to train the model

DQN Policies

class stable_baselines.deepq.MlpPolicy(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, obs_phs=None, dueling=True, **_kwargs)[source]

Policy object that implements DQN policy, using an MLP (2 layers of 64)

Parameters:
  • sess – (TensorFlow session) The current TensorFlow session
  • ob_space – (Gym Space) The observation space of the environment
  • ac_space – (Gym Space) The action space of the environment
  • n_env – (int) The number of environments to run
  • n_steps – (int) The number of steps to run for each environment
  • n_batch – (int) The number of batches to run (n_envs * n_steps)
  • reuse – (bool) If the policy is reusable or not
  • obs_phs – (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for the observation placeholder and the processed observation placeholder, respectively
  • dueling – (bool) if True, double the output MLP to compute a baseline for action scores
  • _kwargs – (dict) Extra keyword arguments for the policy's feature extraction
proba_step(obs, state=None, mask=None)

Returns the action probability for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
Returns:

(np.ndarray float) the action probability

step(obs, state=None, mask=None, deterministic=True)

Returns the q_values for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray int, np.ndarray float, np.ndarray float) actions, q_values, states

class stable_baselines.deepq.LnMlpPolicy(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, obs_phs=None, dueling=True, **_kwargs)[source]

Policy object that implements DQN policy, using an MLP (2 layers of 64), with layer normalisation

Parameters:
  • sess – (TensorFlow session) The current TensorFlow session
  • ob_space – (Gym Space) The observation space of the environment
  • ac_space – (Gym Space) The action space of the environment
  • n_env – (int) The number of environments to run
  • n_steps – (int) The number of steps to run for each environment
  • n_batch – (int) The number of batches to run (n_envs * n_steps)
  • reuse – (bool) If the policy is reusable or not
  • obs_phs – (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for the observation placeholder and the processed observation placeholder, respectively
  • dueling – (bool) if True, double the output MLP to compute a baseline for action scores
  • _kwargs – (dict) Extra keyword arguments for the policy's feature extraction
proba_step(obs, state=None, mask=None)

Returns the action probability for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
Returns:

(np.ndarray float) the action probability

step(obs, state=None, mask=None, deterministic=True)

Returns the q_values for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray int, np.ndarray float, np.ndarray float) actions, q_values, states

class stable_baselines.deepq.CnnPolicy(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, obs_phs=None, dueling=True, **_kwargs)[source]

Policy object that implements DQN policy, using a CNN (the nature CNN)

Parameters:
  • sess – (TensorFlow session) The current TensorFlow session
  • ob_space – (Gym Space) The observation space of the environment
  • ac_space – (Gym Space) The action space of the environment
  • n_env – (int) The number of environments to run
  • n_steps – (int) The number of steps to run for each environment
  • n_batch – (int) The number of batches to run (n_envs * n_steps)
  • reuse – (bool) If the policy is reusable or not
  • obs_phs – (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for the observation placeholder and the processed observation placeholder, respectively
  • dueling – (bool) if True, double the output MLP to compute a baseline for action scores
  • _kwargs – (dict) Extra keyword arguments for the nature CNN feature extraction
proba_step(obs, state=None, mask=None)

Returns the action probability for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
Returns:

(np.ndarray float) the action probability

step(obs, state=None, mask=None, deterministic=True)

Returns the q_values for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray int, np.ndarray float, np.ndarray float) actions, q_values, states

class stable_baselines.deepq.LnCnnPolicy(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, obs_phs=None, dueling=True, **_kwargs)[source]

Policy object that implements DQN policy, using a CNN (the nature CNN), with layer normalisation

Parameters:
  • sess – (TensorFlow session) The current TensorFlow session
  • ob_space – (Gym Space) The observation space of the environment
  • ac_space – (Gym Space) The action space of the environment
  • n_env – (int) The number of environments to run
  • n_steps – (int) The number of steps to run for each environment
  • n_batch – (int) The number of batches to run (n_envs * n_steps)
  • reuse – (bool) If the policy is reusable or not
  • obs_phs – (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for the observation placeholder and the processed observation placeholder, respectively
  • dueling – (bool) if True, double the output MLP to compute a baseline for action scores
  • _kwargs – (dict) Extra keyword arguments for the nature CNN feature extraction
proba_step(obs, state=None, mask=None)

Returns the action probability for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
Returns:

(np.ndarray float) the action probability

step(obs, state=None, mask=None, deterministic=True)

Returns the q_values for a single step

Parameters:
  • obs – (np.ndarray float or int) The current observation of the environment
  • state – (np.ndarray float) The last states (used in recurrent policies)
  • mask – (np.ndarray float) The last masks (used in recurrent policies)
  • deterministic – (bool) Whether or not to return deterministic actions.
Returns:

(np.ndarray int, np.ndarray float, np.ndarray float) actions, q_values, states

Custom Policy Network

Similarly to the example given on the examples page, you can easily define a custom architecture for the policy network:

import gym

from stable_baselines.deepq.policies import FeedForwardPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import DQN

# Custom MLP policy of three layers of size 128 each
class CustomPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs,
                                           layers=[128, 128, 128],
                                           layer_norm=False,
                                           feature_extraction="mlp")

# Create and wrap the environment
env = gym.make('LunarLander-v2')
env = DummyVecEnv([lambda: env])

model = DQN(CustomPolicy, env, verbose=1)
# Train the agent
model.learn(total_timesteps=100000)