from functools import partial
import tensorflow as tf
import numpy as np
import gym
from stable_baselines import logger, deepq
from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter
from stable_baselines.common.vec_env import VecEnv
from stable_baselines.common.schedules import LinearSchedule
from stable_baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from stable_baselines.deepq.policies import DQNPolicy
from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger
class DQN(OffPolicyRLModel):
"""
The DQN model class. DQN paper: https://arxiv.org/pdf/1312.5602.pdf
:param policy: (DQNPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...)
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
:param gamma: (float) discount factor
:param learning_rate: (float) learning rate for adam optimizer
:param buffer_size: (int) size of the replay buffer
:param exploration_fraction: (float) fraction of entire training period over which the exploration rate is
annealed
:param exploration_final_eps: (float) final value of random action probability
    :param train_freq: (int) update the model every `train_freq` steps
    :param batch_size: (int) size of a batch sampled from the replay buffer for training
    :param checkpoint_freq: (int) how often to save the model. This is so that the best version is restored at the
        end of the training. If you do not wish to restore the best version at the end of training,
        set this variable to None.
    :param checkpoint_path: (str) replacement path used if you need to log to somewhere other than a temporary
        directory.
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
:param target_network_update_freq: (int) update the target network every `target_network_update_freq` steps.
:param prioritized_replay: (bool) if True prioritized replay buffer will be used.
:param prioritized_replay_alpha: (float) alpha parameter for prioritized replay buffer
:param prioritized_replay_beta0: (float) initial value of beta for prioritized replay buffer
    :param prioritized_replay_beta_iters: (int) number of iterations over which beta will be annealed from its initial
        value to 1.0. If set to None, it defaults to total_timesteps.
:param prioritized_replay_eps: (float) epsilon to add to the TD errors when updating priorities.
:param param_noise: (bool) Whether or not to apply noise to the parameters of the policy.
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
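
    A minimal usage sketch (illustrative only, not part of this module; it assumes ``gym``
    is installed and the ``CartPole-v1`` environment is registered):

    .. code-block:: python

        from stable_baselines import DQN
        from stable_baselines.deepq.policies import MlpPolicy

        # train a DQN agent on CartPole, then save and reload it
        model = DQN(MlpPolicy, "CartPole-v1", verbose=1)
        model.learn(total_timesteps=25000)
        model.save("dqn_cartpole")

        del model
        model = DQN.load("dqn_cartpole")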
"""
def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=50000, exploration_fraction=0.1,
exploration_final_eps=0.02, train_freq=1, batch_size=32, checkpoint_freq=10000, checkpoint_path=None,
learning_starts=1000, target_network_update_freq=500, prioritized_replay=False,
prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
prioritized_replay_eps=1e-6, param_noise=False, verbose=0, tensorboard_log=None,
_init_setup_model=True):
# TODO: replay_buffer refactoring
super(DQN, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DQNPolicy,
requires_vec_env=False)
self.checkpoint_path = checkpoint_path
self.param_noise = param_noise
self.learning_starts = learning_starts
self.train_freq = train_freq
self.prioritized_replay = prioritized_replay
self.prioritized_replay_eps = prioritized_replay_eps
self.batch_size = batch_size
self.target_network_update_freq = target_network_update_freq
self.checkpoint_freq = checkpoint_freq
self.prioritized_replay_alpha = prioritized_replay_alpha
self.prioritized_replay_beta0 = prioritized_replay_beta0
self.prioritized_replay_beta_iters = prioritized_replay_beta_iters
self.exploration_final_eps = exploration_final_eps
self.exploration_fraction = exploration_fraction
self.buffer_size = buffer_size
self.learning_rate = learning_rate
self.gamma = gamma
self.tensorboard_log = tensorboard_log
self.graph = None
self.sess = None
self._train_step = None
self.step_model = None
self.update_target = None
self.act = None
self.proba_step = None
self.replay_buffer = None
self.beta_schedule = None
self.exploration = None
self.params = None
self.summary = None
self.episode_reward = None
if _init_setup_model:
self.setup_model()
    def setup_model(self):
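        """
        Create the TensorFlow graph, session and ops needed to train the model
        (act function, train step, target-network update and step model).
        """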
with SetVerbosity(self.verbose):
assert not isinstance(self.action_space, gym.spaces.Box), \
"Error: DQN cannot output a gym.spaces.Box action space."
            # If the policy is wrapped in functools.partial (e.g. to disable dueling),
            # unwrap it to check the class type
if isinstance(self.policy, partial):
test_policy = self.policy.func
else:
test_policy = self.policy
            assert issubclass(test_policy, DQNPolicy), "Error: the input policy for the DQN model must be " \
                                                       "a subclass of DQNPolicy."
self.graph = tf.Graph()
with self.graph.as_default():
self.sess = tf_util.make_session(graph=self.graph)
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.act, self._train_step, self.update_target, self.step_model = deepq.build_train(
q_func=self.policy,
ob_space=self.observation_space,
ac_space=self.action_space,
optimizer=optimizer,
gamma=self.gamma,
grad_norm_clipping=10,
param_noise=self.param_noise,
sess=self.sess
)
self.proba_step = self.step_model.proba_step
self.params = find_trainable_variables("deepq")
# Initialize the parameters and copy them to the target network.
tf_util.initialize(self.sess)
self.update_target(sess=self.sess)
self.summary = tf.summary.merge_all()
    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DQN"):
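        """
        Train the DQN model by interacting with ``self.env`` for ``total_timesteps`` steps.

        :param total_timesteps: (int) the total number of environment steps to train on
        :param callback: (callable) function called at every step with ``locals()`` and ``globals()``
        :param seed: (int) the initial seed for the training (can be None)
        :param log_interval: (int) log training statistics every ``log_interval`` completed episodes
        :param tb_log_name: (str) the name of the run for TensorBoard logging
        :return: (DQN) the trained model (self)
        """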
with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer:
self._setup_learn(seed)
# Create the replay buffer
if self.prioritized_replay:
self.replay_buffer = PrioritizedReplayBuffer(self.buffer_size, alpha=self.prioritized_replay_alpha)
if self.prioritized_replay_beta_iters is None:
prioritized_replay_beta_iters = total_timesteps
self.beta_schedule = LinearSchedule(prioritized_replay_beta_iters,
initial_p=self.prioritized_replay_beta0,
final_p=1.0)
else:
self.replay_buffer = ReplayBuffer(self.buffer_size)
self.beta_schedule = None
# Create the schedule for exploration starting from 1.
self.exploration = LinearSchedule(schedule_timesteps=int(self.exploration_fraction * total_timesteps),
initial_p=1.0,
final_p=self.exploration_final_eps)
episode_rewards = [0.0]
obs = self.env.reset()
reset = True
self.episode_reward = np.zeros((1,))
for step in range(total_timesteps):
if callback is not None:
callback(locals(), globals())
# Take action and update exploration to the newest value
kwargs = {}
if not self.param_noise:
update_eps = self.exploration.value(step)
update_param_noise_threshold = 0.
else:
update_eps = 0.
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = \
-np.log(1. - self.exploration.value(step) +
self.exploration.value(step) / float(self.env.action_space.n))
kwargs['reset'] = reset
kwargs['update_param_noise_threshold'] = update_param_noise_threshold
kwargs['update_param_noise_scale'] = True
with self.sess.as_default():
action = self.act(np.array(obs)[None], update_eps=update_eps, **kwargs)[0]
env_action = action
reset = False
new_obs, rew, done, _ = self.env.step(env_action)
# Store transition in the replay buffer.
self.replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
if writer is not None:
ep_rew = np.array([rew]).reshape((1, -1))
ep_done = np.array([done]).reshape((1, -1))
self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_rew, ep_done, writer,
step)
episode_rewards[-1] += rew
if done:
if not isinstance(self.env, VecEnv):
obs = self.env.reset()
episode_rewards.append(0.0)
reset = True
if step > self.learning_starts and step % self.train_freq == 0:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if self.prioritized_replay:
experience = self.replay_buffer.sample(self.batch_size, beta=self.beta_schedule.value(step))
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = self.replay_buffer.sample(self.batch_size)
weights, batch_idxes = np.ones_like(rewards), None
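                    # NOTE: obses_tp1 is fed twice to the train op below: the training graph takes one
                    # batch of next observations for the target network and, presumably, one for the
                    # online network used to select actions when double Q-learning is enabled.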
if writer is not None:
# run loss backprop with summary, but once every 100 steps save the metadata
# (memory, compute time, ...)
if (1 + step) % 100 == 0:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, td_errors = self._train_step(obses_t, actions, rewards, obses_tp1, obses_tp1,
dones, weights, sess=self.sess, options=run_options,
run_metadata=run_metadata)
writer.add_run_metadata(run_metadata, 'step%d' % step)
else:
summary, td_errors = self._train_step(obses_t, actions, rewards, obses_tp1, obses_tp1,
dones, weights, sess=self.sess)
writer.add_summary(summary, step)
else:
_, td_errors = self._train_step(obses_t, actions, rewards, obses_tp1, obses_tp1, dones, weights,
sess=self.sess)
if self.prioritized_replay:
new_priorities = np.abs(td_errors) + self.prioritized_replay_eps
self.replay_buffer.update_priorities(batch_idxes, new_priorities)
if step > self.learning_starts and step % self.target_network_update_freq == 0:
# Update target network periodically.
self.update_target(sess=self.sess)
if len(episode_rewards[-101:-1]) == 0:
mean_100ep_reward = -np.inf
else:
mean_100ep_reward = round(float(np.mean(episode_rewards[-101:-1])), 1)
num_episodes = len(episode_rewards)
if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0:
logger.record_tabular("steps", step)
logger.record_tabular("episodes", num_episodes)
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
logger.record_tabular("% time spent exploring", int(100 * self.exploration.value(step)))
logger.dump_tabular()
return self
    def predict(self, observation, state=None, mask=None, deterministic=True):
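        """
        Get the model's action for a given observation.

        :param observation: (np.ndarray) the input observation
        :param state: (np.ndarray) the last states (unused here, kept for API consistency)
        :param mask: (np.ndarray) the last masks (unused here, kept for API consistency)
        :param deterministic: (bool) whether or not to return deterministic actions
        :return: (np.ndarray, None) the model's action and None (DQN has no recurrent state)
        """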
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
observation = observation.reshape((-1,) + self.observation_space.shape)
with self.sess.as_default():
actions, _, _ = self.step_model.step(observation, deterministic=deterministic)
if not vectorized_env:
actions = actions[0]
return actions, None
    def action_probability(self, observation, state=None, mask=None):
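        """
        Get the action probability distribution for a given observation.

        :param observation: (np.ndarray) the input observation
        :param state: (np.ndarray) the last states (can be None, used for recurrent policies)
        :param mask: (np.ndarray) the last masks (can be None, used for recurrent policies)
        :return: (np.ndarray) the probability of each action
        """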
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
observation = observation.reshape((-1,) + self.observation_space.shape)
actions_proba = self.proba_step(observation, state, mask)
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions_proba = actions_proba[0]
return actions_proba
    def save(self, save_path):
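        """
        Save the current hyperparameters and trained weights to ``save_path``.

        :param save_path: (str) the path where to save the model
        """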
# params
data = {
"checkpoint_path": self.checkpoint_path,
"param_noise": self.param_noise,
"learning_starts": self.learning_starts,
"train_freq": self.train_freq,
"prioritized_replay": self.prioritized_replay,
"prioritized_replay_eps": self.prioritized_replay_eps,
"batch_size": self.batch_size,
"target_network_update_freq": self.target_network_update_freq,
"checkpoint_freq": self.checkpoint_freq,
"prioritized_replay_alpha": self.prioritized_replay_alpha,
"prioritized_replay_beta0": self.prioritized_replay_beta0,
"prioritized_replay_beta_iters": self.prioritized_replay_beta_iters,
"exploration_final_eps": self.exploration_final_eps,
"exploration_fraction": self.exploration_fraction,
"learning_rate": self.learning_rate,
"gamma": self.gamma,
"verbose": self.verbose,
"observation_space": self.observation_space,
"action_space": self.action_space,
"policy": self.policy,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action
}
params = self.sess.run(self.params)
self._save_to_file(save_path, data=data, params=params)
    @classmethod
def load(cls, load_path, env=None, **kwargs):
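        """
        Load a DQN model from a saved file, optionally replacing its environment.

        :param load_path: (str) the saved parameter location
        :param env: (Gym environment) the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model)
        :param kwargs: extra arguments overriding the saved values
        :return: (DQN) the loaded model
        """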
data, params = cls._load_from_file(load_path)
model = cls(policy=data["policy"], env=env, _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
model.set_env(env)
model.setup_model()
restores = []
for param, loaded_p in zip(model.params, params):
restores.append(param.assign(loaded_p))
model.sess.run(restores)
return model