Source code for stable_baselines.ppo2.ppo2

import time
import sys
import multiprocessing
from collections import deque

import gym
import numpy as np
import tensorflow as tf

from stable_baselines import logger
from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
from stable_baselines.common.runners import AbstractEnvRunner
from stable_baselines.common.policies import LstmPolicy, ActorCriticPolicy
from stable_baselines.a2c.utils import total_episode_reward_logger


[docs]class PPO2(ActorCriticRLModel): """ Proximal Policy Optimization algorithm (GPU version). Paper: https://arxiv.org/abs/1707.06347 :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...) :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) :param gamma: (float) Discount factor :param n_steps: (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) :param ent_coef: (float) Entropy coefficient for the loss caculation :param learning_rate: (float or callable) The learning rate, it can be a function :param vf_coef: (float) Value function coefficient for the loss calculation :param max_grad_norm: (float) The maximum value for the gradient clipping :param lam: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param nminibatches: (int) Number of training minibatches per update. For recurrent policies, the number of environments run in parallel should be a multiple of nminibatches. :param noptepochs: (int) Number of epoch when optimizing the surrogate :param cliprange: (float or callable) Clipping parameter, it can be a function :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning_rate=2.5e-4, vf_coef=0.5, max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, verbose=0, tensorboard_log=None, _init_setup_model=True): super(PPO2, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, _init_setup_model=_init_setup_model) if isinstance(learning_rate, float): learning_rate = constfn(learning_rate) else: assert callable(learning_rate) if isinstance(cliprange, float): cliprange = constfn(cliprange) else: assert callable(cliprange) self.learning_rate = learning_rate self.cliprange = cliprange self.n_steps = n_steps self.ent_coef = ent_coef self.vf_coef = vf_coef self.max_grad_norm = max_grad_norm self.gamma = gamma self.lam = lam self.nminibatches = nminibatches self.noptepochs = noptepochs self.tensorboard_log = tensorboard_log self.graph = None self.sess = None self.action_ph = None self.advs_ph = None self.rewards_ph = None self.old_neglog_pac_ph = None self.old_vpred_ph = None self.learning_rate_ph = None self.clip_range_ph = None self.entropy = None self.vf_loss = None self.pg_loss = None self.approxkl = None self.clipfrac = None self.params = None self._train = None self.loss_names = None self.train_model = None self.act_model = None self.step = None self.proba_step = None self.value = None self.initial_state = None self.n_batch = None self.summary = None self.episode_reward = None if _init_setup_model: self.setup_model()
[docs] def setup_model(self): with SetVerbosity(self.verbose): assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the PPO2 model must be " \ "an instance of common.policies.ActorCriticPolicy." self.n_batch = self.n_envs * self.n_steps n_cpu = multiprocessing.cpu_count() if sys.platform == 'darwin': n_cpu //= 2 self.graph = tf.Graph() with self.graph.as_default(): self.sess = tf_util.make_session(num_cpu=n_cpu, graph=self.graph) n_batch_step = None n_batch_train = None if issubclass(self.policy, LstmPolicy): assert self.n_envs % self.nminibatches == 0, "For recurrent policies, "\ "the number of environments run in parallel should be a multiple of nminibatches." n_batch_step = self.n_envs n_batch_train = self.n_batch // self.nminibatches act_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, n_batch_step, reuse=False) with tf.variable_scope("train_model", reuse=True, custom_getter=tf_util.outer_scope_getter("train_model")): train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs // self.nminibatches, self.n_steps, n_batch_train, reuse=True) with tf.variable_scope("loss", reuse=False): self.action_ph = train_model.pdtype.sample_placeholder([None], name="action_ph") self.advs_ph = tf.placeholder(tf.float32, [None], name="advs_ph") self.rewards_ph = tf.placeholder(tf.float32, [None], name="rewards_ph") self.old_neglog_pac_ph = tf.placeholder(tf.float32, [None], name="old_neglog_pac_ph") self.old_vpred_ph = tf.placeholder(tf.float32, [None], name="old_vpred_ph") self.learning_rate_ph = tf.placeholder(tf.float32, [], name="learning_rate_ph") self.clip_range_ph = tf.placeholder(tf.float32, [], name="clip_range_ph") neglogpac = train_model.proba_distribution.neglogp(self.action_ph) self.entropy = tf.reduce_mean(train_model.proba_distribution.entropy()) vpred = train_model._value vpredclipped = self.old_vpred_ph + tf.clip_by_value( train_model._value - self.old_vpred_ph, - self.clip_range_ph, self.clip_range_ph) vf_losses1 = tf.square(vpred - self.rewards_ph) vf_losses2 = tf.square(vpredclipped - self.rewards_ph) self.vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2)) ratio = tf.exp(self.old_neglog_pac_ph - neglogpac) pg_losses = -self.advs_ph * ratio pg_losses2 = -self.advs_ph * tf.clip_by_value(ratio, 1.0 - self.clip_range_ph, 1.0 + self.clip_range_ph) self.pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2)) self.approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.old_neglog_pac_ph)) self.clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), self.clip_range_ph))) loss = self.pg_loss - self.entropy * self.ent_coef + self.vf_loss * self.vf_coef tf.summary.scalar('entropy_loss', self.entropy) tf.summary.scalar('policy_gradient_loss', self.pg_loss) tf.summary.scalar('value_function_loss', self.vf_loss) tf.summary.scalar('approximate_kullback-leiber', self.approxkl) tf.summary.scalar('clip_factor', self.clipfrac) tf.summary.scalar('loss', loss) with tf.variable_scope('model'): self.params = tf.trainable_variables() grads = tf.gradients(loss, self.params) if self.max_grad_norm is not None: grads, _grad_norm = tf.clip_by_global_norm(grads, self.max_grad_norm) grads = list(zip(grads, self.params)) trainer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_ph, epsilon=1e-5) self._train = trainer.apply_gradients(grads) self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac'] with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph)) tf.summary.histogram('discounted_rewards', self.rewards_ph) tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph)) tf.summary.histogram('learning_rate', self.learning_rate_ph) tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph)) tf.summary.histogram('advantage', self.advs_ph) tf.summary.scalar('clip_range', tf.reduce_mean(self.clip_range_ph)) tf.summary.histogram('clip_range', self.clip_range_ph) tf.summary.scalar('old_neglog_action_probabilty', tf.reduce_mean(self.old_neglog_pac_ph)) tf.summary.histogram('old_neglog_action_probabilty', self.old_neglog_pac_ph) tf.summary.scalar('old_value_pred', tf.reduce_mean(self.old_vpred_ph)) tf.summary.histogram('old_value_pred', self.old_vpred_ph) if len(self.observation_space.shape) == 3: tf.summary.image('observation', train_model.obs_ph) else: tf.summary.histogram('observation', train_model.obs_ph) self.train_model = train_model self.act_model = act_model self.step = act_model.step self.proba_step = act_model.proba_step self.value = act_model.value self.initial_state = act_model.initial_state tf.global_variables_initializer().run(session=self.sess) # pylint: disable=E1101 self.summary = tf.summary.merge_all()
def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, values, neglogpacs, update, writer, states=None): """ Training of PPO2 Algorithm :param learning_rate: (float) learning rate :param cliprange: (float) Clipping factor :param obs: (np.ndarray) The current observation of the environment :param returns: (np.ndarray) the rewards :param masks: (np.ndarray) The last masks for done episodes (used in recurent policies) :param actions: (np.ndarray) the actions :param values: (np.ndarray) the values :param neglogpacs: (np.ndarray) Negative Log-likelihood probability of Actions :param update: (int) the current step iteration :param writer: (TensorFlow Summary.writer) the writer for tensorboard :param states: (np.ndarray) For recurrent policies, the internal state of the recurrent model :return: policy gradient loss, value function loss, policy entropy, approximation of kl divergence, updated clipping range, training update operation """ advs = returns - values advs = (advs - advs.mean()) / (advs.std() + 1e-8) td_map = {self.train_model.obs_ph: obs, self.action_ph: actions, self.advs_ph: advs, self.rewards_ph: returns, self.learning_rate_ph: learning_rate, self.clip_range_ph: cliprange, self.old_neglog_pac_ph: neglogpacs, self.old_vpred_ph: values} if states is not None: td_map[self.train_model.states_ph] = states td_map[self.train_model.masks_ph] = masks if states is None: update_fac = self.n_batch // self.nminibatches // self.noptepochs else: update_fac = self.n_batch // self.nminibatches // self.noptepochs // self.n_steps if writer is not None: # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...) if (1 + update) % 10 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run( [self.summary, self.pg_loss, self.vf_loss, self.entropy, self.approxkl, self.clipfrac, self._train], td_map, options=run_options, run_metadata=run_metadata) writer.add_run_metadata(run_metadata, 'step%d' % (update * update_fac)) else: summary, policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run( [self.summary, self.pg_loss, self.vf_loss, self.entropy, self.approxkl, self.clipfrac, self._train], td_map) writer.add_summary(summary, (update * update_fac)) else: policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run( [self.pg_loss, self.vf_loss, self.entropy, self.approxkl, self.clipfrac, self._train], td_map) return policy_loss, value_loss, policy_entropy, approxkl, clipfrac
[docs] def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_log_name="PPO2"): with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: self._setup_learn(seed) runner = Runner(env=self.env, model=self, n_steps=self.n_steps, gamma=self.gamma, lam=self.lam) self.episode_reward = np.zeros((self.n_envs,)) ep_info_buf = deque(maxlen=100) t_first_start = time.time() nupdates = total_timesteps // self.n_batch for update in range(1, nupdates + 1): assert self.n_batch % self.nminibatches == 0 n_batch_train = self.n_batch // self.nminibatches t_start = time.time() frac = 1.0 - (update - 1.0) / nupdates lr_now = self.learning_rate(frac) cliprangenow = self.cliprange(frac) # true_reward is the reward without discount obs, returns, masks, actions, values, neglogpacs, states, ep_infos, true_reward = runner.run() ep_info_buf.extend(ep_infos) mb_loss_vals = [] if states is None: # nonrecurrent version inds = np.arange(self.n_batch) for epoch_num in range(self.noptepochs): np.random.shuffle(inds) for start in range(0, self.n_batch, n_batch_train): timestep = ((update * self.noptepochs * self.n_batch + epoch_num * self.n_batch + start) // n_batch_train) end = start + n_batch_train mbinds = inds[start:end] slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs)) mb_loss_vals.append(self._train_step(lr_now, cliprangenow, *slices, writer=writer, update=timestep)) else: # recurrent version assert self.n_envs % self.nminibatches == 0 envinds = np.arange(self.n_envs) flatinds = np.arange(self.n_envs * self.n_steps).reshape(self.n_envs, self.n_steps) envsperbatch = n_batch_train // self.n_steps for epoch_num in range(self.noptepochs): np.random.shuffle(envinds) for start in range(0, self.n_envs, envsperbatch): timestep = ((update * self.noptepochs * self.n_envs + epoch_num * self.n_envs + start) // envsperbatch) end = start + envsperbatch mb_env_inds = envinds[start:end] mb_flat_inds = flatinds[mb_env_inds].ravel() slices = (arr[mb_flat_inds] for arr in (obs, returns, masks, actions, values, neglogpacs)) mb_states = states[mb_env_inds] mb_loss_vals.append(self._train_step(lr_now, cliprangenow, *slices, update=timestep, writer=writer, states=mb_states)) loss_vals = np.mean(mb_loss_vals, axis=0) t_now = time.time() fps = int(self.n_batch / (t_now - t_start)) if writer is not None: self.episode_reward = total_episode_reward_logger(self.episode_reward, true_reward.reshape((self.n_envs, self.n_steps)), masks.reshape((self.n_envs, self.n_steps)), writer, update * (self.n_batch + 1)) if callback is not None: callback(locals(), globals()) if self.verbose >= 1 and (update % log_interval == 0 or update == 1): explained_var = explained_variance(values, returns) logger.logkv("serial_timesteps", (update + 1) * self.n_steps) logger.logkv("nupdates", (update + 1)) logger.logkv("total_timesteps", (update + 1) * self.n_batch) logger.logkv("fps", fps) logger.logkv("explained_variance", float(explained_var)) logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in ep_info_buf])) logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in ep_info_buf])) logger.logkv('time_elapsed', t_start - t_first_start) for (loss_val, loss_name) in zip(loss_vals, self.loss_names): logger.logkv(loss_name, loss_val) logger.dumpkvs() return self
[docs] def save(self, save_path): data = { "gamma": self.gamma, "n_steps": self.n_steps, "vf_coef": self.vf_coef, "ent_coef": self.ent_coef, "max_grad_norm": self.max_grad_norm, "learning_rate": self.learning_rate, "lam": self.lam, "nminibatches": self.nminibatches, "noptepochs": self.noptepochs, "cliprange": self.cliprange, "verbose": self.verbose, "policy": self.policy, "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, "_vectorize_action": self._vectorize_action } params = self.sess.run(self.params) self._save_to_file(save_path, data=data, params=params)
class Runner(AbstractEnvRunner): def __init__(self, *, env, model, n_steps, gamma, lam): """ A runner to learn the policy of an environment for a model :param env: (Gym environment) The environment to learn from :param model: (Model) The model to learn :param n_steps: (int) The number of steps to run for each environment :param gamma: (float) Discount factor :param lam: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator """ super().__init__(env=env, model=model, n_steps=n_steps) self.lam = lam self.gamma = gamma def run(self): """ Run a learning step of the model :return: - observations: (np.ndarray) the observations - rewards: (np.ndarray) the rewards - masks: (numpy bool) whether an episode is over or not - actions: (np.ndarray) the actions - values: (np.ndarray) the value function output - negative log probabilities: (np.ndarray) - states: (np.ndarray) the internal states of the recurrent policies - infos: (dict) the extra information of the model """ # mb stands for minibatch mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], [], [] mb_states = self.states ep_infos = [] for _ in range(self.n_steps): actions, values, self.states, neglogpacs = self.model.step(self.obs, self.states, self.dones) mb_obs.append(self.obs.copy()) mb_actions.append(actions) mb_values.append(values) mb_neglogpacs.append(neglogpacs) mb_dones.append(self.dones) clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.env.action_space, gym.spaces.Box): clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high) self.obs[:], rewards, self.dones, infos = self.env.step(clipped_actions) for info in infos: maybeep_info = info.get('episode') if maybeep_info: ep_infos.append(maybeep_info) mb_rewards.append(rewards) # batch of steps to batch of rollouts mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype) mb_rewards = np.asarray(mb_rewards, dtype=np.float32) mb_actions = np.asarray(mb_actions) mb_values = np.asarray(mb_values, dtype=np.float32) mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32) mb_dones = np.asarray(mb_dones, dtype=np.bool) last_values = self.model.value(self.obs, self.states, self.dones) # discount/bootstrap off value fn mb_advs = np.zeros_like(mb_rewards) true_reward = np.copy(mb_rewards) last_gae_lam = 0 for step in reversed(range(self.n_steps)): if step == self.n_steps - 1: nextnonterminal = 1.0 - self.dones nextvalues = last_values else: nextnonterminal = 1.0 - mb_dones[step + 1] nextvalues = mb_values[step + 1] delta = mb_rewards[step] + self.gamma * nextvalues * nextnonterminal - mb_values[step] mb_advs[step] = last_gae_lam = delta + self.gamma * self.lam * nextnonterminal * last_gae_lam mb_returns = mb_advs + mb_values mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs = \ map(swap_and_flatten, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)) return mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_states, ep_infos, true_reward # obs, returns, masks, actions, values, neglogpacs, states = runner.run() def swap_and_flatten(arr): """ swap and then flatten axes 0 and 1 :param arr: (np.ndarray) :return: (np.ndarray) """ shape = arr.shape return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) def constfn(val): """ Create a function that returns a constant It is useful for learning rate schedule (to avoid code duplication) :param val: (float) :return: (function) """ def func(_): return val return func def safe_mean(arr): """ Compute the mean of an array if there is at least one element. For empty array, return zero. It is used for logging only. :param arr: (np.ndarray) :return: (float) """ return np.nan if len(arr) == 0 else np.mean(arr)