from abc import ABC
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete
from stable_baselines.a2c.utils import conv, linear, conv_to_fc, batch_to_seq, seq_to_batch, lstm
from stable_baselines.common.distributions import make_proba_dist_type
from stable_baselines.common.input import observation_input
def nature_cnn(scaled_images, **kwargs):
"""
CNN from Nature paper.
:param scaled_images: (TensorFlow Tensor) Image input placeholder
    :param kwargs: (dict) Extra keyword parameters for the convolutional layers of the CNN
:return: (TensorFlow Tensor) The CNN output layer
"""
activ = tf.nn.relu
layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=8, stride=4, init_scale=np.sqrt(2), **kwargs))
layer_2 = activ(conv(layer_1, 'c2', n_filters=64, filter_size=4, stride=2, init_scale=np.sqrt(2), **kwargs))
layer_3 = activ(conv(layer_2, 'c3', n_filters=64, filter_size=3, stride=1, init_scale=np.sqrt(2), **kwargs))
layer_3 = conv_to_fc(layer_3)
return activ(linear(layer_3, 'fc1', n_hidden=512, init_scale=np.sqrt(2)))
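
# Usage sketch (illustrative, not part of the module): running a batch of
# Atari-style frames through the nature CNN. The 84x84x4 input shape is the
# usual Atari preprocessing assumption, not a requirement of nature_cnn.
#
#     obs_ph = tf.placeholder(tf.float32, shape=[None, 84, 84, 4])
#     features = nature_cnn(obs_ph)  # -> Tensor of shape [None, 512]
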
class BasePolicy(ABC):
"""
The base policy object
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param scale: (bool) whether or not to scale the input
    :param obs_phs: (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for the observation
        placeholder and the processed observation placeholder, respectively
:param add_action_ph: (bool) whether or not to create an action placeholder
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, scale=False,
obs_phs=None, add_action_ph=False):
self.n_env = n_env
self.n_steps = n_steps
with tf.variable_scope("input", reuse=False):
if obs_phs is None:
self.obs_ph, self.processed_x = observation_input(ob_space, n_batch, scale=scale)
else:
self.obs_ph, self.processed_x = obs_phs
self.masks_ph = tf.placeholder(tf.float32, [n_batch], name="masks_ph") # mask (done t-1)
self.states_ph = tf.placeholder(tf.float32, [self.n_env, n_lstm * 2], name="states_ph") # states
self.action_ph = None
if add_action_ph:
self.action_ph = tf.placeholder(dtype=ac_space.dtype, shape=(None,) + ac_space.shape, name="action_ph")
self.sess = sess
self.reuse = reuse
self.ob_space = ob_space
self.ac_space = ac_space
def step(self, obs, state=None, mask=None):
"""
        Runs the policy for a single step
:param obs: ([float] or [int]) The current observation of the environment
:param state: ([float]) The last states (used in recurrent policies)
:param mask: ([float]) The last masks (used in recurrent policies)
:return: ([float], [float], [float], [float]) actions, values, states, neglogp
"""
raise NotImplementedError
def proba_step(self, obs, state=None, mask=None):
"""
Returns the action probability for a single step
:param obs: ([float] or [int]) The current observation of the environment
:param state: ([float]) The last states (used in recurrent policies)
:param mask: ([float]) The last masks (used in recurrent policies)
:return: ([float]) the action probability
"""
raise NotImplementedError
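
# Interface sketch (illustrative; the base class itself is abstract): a
# concrete subclass returns a tuple from step(), e.g.
#
#     actions, values, states, neglogps = policy.step(obs, state, mask)
#     action_probas = policy.proba_step(obs, state, mask)
#
# where `state` and `mask` only matter for recurrent policies.
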
class ActorCriticPolicy(BasePolicy):
"""
    Policy object that implements an actor-critic algorithm
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param scale: (bool) whether or not to scale the input
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, scale=False):
super(ActorCriticPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=n_lstm,
reuse=reuse, scale=scale)
self.pdtype = make_proba_dist_type(ac_space)
self.is_discrete = isinstance(ac_space, Discrete)
self.policy = None
self.proba_distribution = None
self.value_fn = None
self.deterministic_action = None
self.initial_state = None
def _setup_init(self):
"""
        Sets up the distributions, actions, and value
"""
with tf.variable_scope("output", reuse=True):
assert self.policy is not None and self.proba_distribution is not None and self.value_fn is not None
self.action = self.proba_distribution.sample()
self.deterministic_action = self.proba_distribution.mode()
self.neglogp = self.proba_distribution.neglogp(self.action)
self.policy_proba = self.policy
if self.is_discrete:
self.policy_proba = tf.nn.softmax(self.policy_proba)
self._value = self.value_fn[:, 0]
    def step(self, obs, state=None, mask=None, deterministic=False):
"""
        Runs the policy for a single step
:param obs: ([float] or [int]) The current observation of the environment
:param state: ([float]) The last states (used in recurrent policies)
:param mask: ([float]) The last masks (used in recurrent policies)
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: ([float], [float], [float], [float]) actions, values, states, neglogp
"""
raise NotImplementedError
    def proba_step(self, obs, state=None, mask=None):
"""
Returns the action probability for a single step
:param obs: ([float] or [int]) The current observation of the environment
:param state: ([float]) The last states (used in recurrent policies)
:param mask: ([float]) The last masks (used in recurrent policies)
:return: ([float]) the action probability
"""
raise NotImplementedError
    def value(self, obs, state=None, mask=None):
"""
Returns the value for a single step
:param obs: ([float] or [int]) The current observation of the environment
:param state: ([float]) The last states (used in recurrent policies)
:param mask: ([float]) The last masks (used in recurrent policies)
        :return: ([float]) The value estimate for the given observation
"""
raise NotImplementedError
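
# Summary of the tensors that _setup_init() exposes on subclasses (the
# policy_proba line assumes a Discrete action space):
#
#     self.action                # stochastic sample from proba_distribution
#     self.deterministic_action  # mode of the distribution (e.g. argmax of logits)
#     self.neglogp               # -log p(self.action)
#     self.policy_proba          # softmax over the policy logits
#     self._value                # value_fn squeezed to shape [n_batch]
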
class LstmPolicy(ActorCriticPolicy):
"""
    Policy object that implements an actor-critic algorithm, using LSTMs.
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
    :param layers: ([int]) The sizes of the hidden layers before the LSTM layer (if None, defaults to [64, 64])
:param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
    :param layer_norm: (bool) Whether or not to use layer-normalized LSTMs
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn", **kwargs):
super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
scale=(feature_extraction == "cnn"))
if layers is None:
layers = [64, 64]
with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
extracted_features = cnn_extractor(self.processed_x, **kwargs)
else:
activ = tf.tanh
extracted_features = tf.layers.flatten(self.processed_x)
for i, layer_size in enumerate(layers):
extracted_features = activ(linear(extracted_features, 'pi_fc' + str(i), n_hidden=layer_size,
init_scale=np.sqrt(2)))
input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
masks = batch_to_seq(self.masks_ph, self.n_env, n_steps)
rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
layer_norm=layer_norm)
rnn_output = seq_to_batch(rnn_output)
value_fn = linear(rnn_output, 'vf', 1)
self.proba_distribution, self.policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)
self.value_fn = value_fn
self.initial_state = np.zeros((self.n_env, n_lstm * 2), dtype=np.float32)
self._setup_init()
    def step(self, obs, state=None, mask=None, deterministic=False):
if deterministic:
return self.sess.run([self.deterministic_action, self._value, self.snew, self.neglogp],
{self.obs_ph: obs, self.states_ph: state, self.masks_ph: mask})
else:
return self.sess.run([self.action, self._value, self.snew, self.neglogp],
{self.obs_ph: obs, self.states_ph: state, self.masks_ph: mask})
    def proba_step(self, obs, state=None, mask=None):
return self.sess.run(self.policy_proba, {self.obs_ph: obs, self.states_ph: state, self.masks_ph: mask})
    def value(self, obs, state=None, mask=None):
return self.sess.run(self._value, {self.obs_ph: obs, self.states_ph: state, self.masks_ph: mask})
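
# Recurrent rollout sketch (illustrative, assuming a step model built with
# n_steps=1 so that n_batch == n_env): the LSTM state and the previous
# done-mask must be threaded through successive calls.
#
#     state = policy.initial_state               # zeros, shape (n_env, 2 * n_lstm)
#     mask = np.zeros((policy.n_env,))           # 1.0 where the last step ended an episode
#     actions, values, state, neglogps = policy.step(obs, state, mask)
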
class FeedForwardPolicy(ActorCriticPolicy):
"""
    Policy object that implements an actor-critic algorithm, using a feed-forward neural network.
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param reuse: (bool) If the policy is reusable or not
    :param layers: ([int]) The sizes of the hidden layers of the policy network (if None, defaults to [64, 64])
:param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None,
cnn_extractor=nature_cnn, feature_extraction="cnn", **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256,
reuse=reuse, scale=(feature_extraction == "cnn"))
if layers is None:
layers = [64, 64]
with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
extracted_features = cnn_extractor(self.processed_x, **kwargs)
value_fn = linear(extracted_features, 'vf', 1)
pi_latent = extracted_features
vf_latent = extracted_features
else:
activ = tf.tanh
processed_x = tf.layers.flatten(self.processed_x)
pi_h = processed_x
vf_h = processed_x
for i, layer_size in enumerate(layers):
pi_h = activ(linear(pi_h, 'pi_fc' + str(i), n_hidden=layer_size, init_scale=np.sqrt(2)))
vf_h = activ(linear(vf_h, 'vf_fc' + str(i), n_hidden=layer_size, init_scale=np.sqrt(2)))
value_fn = linear(vf_h, 'vf', 1)
pi_latent = pi_h
vf_latent = vf_h
self.proba_distribution, self.policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)
self.value_fn = value_fn
self.initial_state = None
self._setup_init()
    def step(self, obs, state=None, mask=None, deterministic=False):
if deterministic:
action, value, neglogp = self.sess.run([self.deterministic_action, self._value, self.neglogp],
{self.obs_ph: obs})
else:
action, value, neglogp = self.sess.run([self.action, self._value, self.neglogp],
{self.obs_ph: obs})
return action, value, self.initial_state, neglogp
    def proba_step(self, obs, state=None, mask=None):
return self.sess.run(self.policy_proba, {self.obs_ph: obs})
    def value(self, obs, state=None, mask=None):
return self.sess.run(self._value, {self.obs_ph: obs})
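
# Feed-forward rollout sketch (illustrative): no recurrent state is needed,
# so the returned state is always None (self.initial_state).
#
#     actions, values, _, neglogps = policy.step(obs)
#     actions, values, _, neglogps = policy.step(obs, deterministic=True)
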
class CnnPolicy(FeedForwardPolicy):
"""
    Policy object that implements an actor-critic algorithm, using a CNN (the nature CNN)
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param reuse: (bool) If the policy is reusable or not
:param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
super(CnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
feature_extraction="cnn", **_kwargs)
class CnnLstmPolicy(LstmPolicy):
"""
    Policy object that implements an actor-critic algorithm, using LSTMs with a CNN feature extraction
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, **_kwargs):
super(CnnLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
layer_norm=False, feature_extraction="cnn", **_kwargs)
class CnnLnLstmPolicy(LstmPolicy):
"""
    Policy object that implements an actor-critic algorithm, using layer-normalized LSTMs with a CNN feature extraction
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, **_kwargs):
super(CnnLnLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
layer_norm=True, feature_extraction="cnn", **_kwargs)
class MlpPolicy(FeedForwardPolicy):
"""
    Policy object that implements an actor-critic algorithm, using an MLP (2 layers of 64)
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments (unused with the MLP feature extraction)
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
super(MlpPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
feature_extraction="mlp", **_kwargs)
class MlpLstmPolicy(LstmPolicy):
"""
    Policy object that implements an actor-critic algorithm, using LSTMs with an MLP feature extraction
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments (unused with the MLP feature extraction)
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, **_kwargs):
super(MlpLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
layer_norm=False, feature_extraction="mlp", **_kwargs)
class MlpLnLstmPolicy(LstmPolicy):
"""
    Policy object that implements an actor-critic algorithm, using layer-normalized LSTMs with an MLP feature extraction
:param sess: (TensorFlow session) The current TensorFlow session
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The batch size (n_env * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments (unused with the MLP feature extraction)
"""
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, **_kwargs):
super(MlpLnLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
layer_norm=True, feature_extraction="mlp", **_kwargs)
_policy_registry = {
ActorCriticPolicy: {
"CnnPolicy": CnnPolicy,
"CnnLstmPolicy": CnnLstmPolicy,
"CnnLnLstmPolicy": CnnLnLstmPolicy,
"MlpPolicy": MlpPolicy,
"MlpLstmPolicy": MlpLstmPolicy,
"MlpLnLstmPolicy": MlpLnLstmPolicy,
}
}
def get_policy_from_name(base_policy_type, name):
"""
    Returns the registered policy from the base type and name
:param base_policy_type: (BasePolicy) the base policy object
:param name: (str) the policy name
:return: (base_policy_type) the policy
"""
if base_policy_type not in _policy_registry:
raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type))
if name not in _policy_registry[base_policy_type]:
raise ValueError("Error: unknown policy type {}, the only registed policy type are: {}!"
.format(name, list(_policy_registry[base_policy_type].keys())))
return _policy_registry[base_policy_type][name]
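
# Lookup sketch: resolving a registered policy name to its class.
#
#     policy_cls = get_policy_from_name(ActorCriticPolicy, "MlpPolicy")
#     assert policy_cls is MlpPolicy
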
def register_policy(name, policy):
"""
    Registers a policy, so that it can be retrieved by name later
:param name: (str) the policy name
:param policy: (subclass of BasePolicy) the policy
"""
sub_class = None
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
if sub_class is None:
raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy))
if sub_class not in _policy_registry:
_policy_registry[sub_class] = {}
if name in _policy_registry[sub_class]:
raise ValueError("Error: the name {} is alreay registered for a different policy, will not override."
.format(name))
_policy_registry[sub_class][name] = policy
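

# Minimal registration sketch (runnable without a TensorFlow session). The
# class name "CustomMlpPolicy" and the [128, 128] layer sizes below are
# illustrative assumptions, not library defaults.
if __name__ == "__main__":
    class CustomMlpPolicy(FeedForwardPolicy):
        """Hypothetical feed-forward policy with two 128-unit hidden layers."""
        def __init__(self, *args, **kwargs):
            super(CustomMlpPolicy, self).__init__(*args, layers=[128, 128],
                                                  feature_extraction="mlp", **kwargs)

    register_policy("CustomMlpPolicy", CustomMlpPolicy)
    # The custom class can now be resolved by name, like the built-in policies:
    assert get_policy_from_name(ActorCriticPolicy, "CustomMlpPolicy") is CustomMlpPolicy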