Source code for stable_baselines.common.base_class

from abc import ABC, abstractmethod
import os
import glob

import cloudpickle
import numpy as np
import gym
import tensorflow as tf

from stable_baselines.common import set_global_seeds
from stable_baselines.common.policies import LstmPolicy, get_policy_from_name, ActorCriticPolicy
from stable_baselines.common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv
from stable_baselines import logger


class BaseRLModel(ABC):
    """
    The base RL model

    :param policy: (BasePolicy) Policy object
    :param env: (Gym environment) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param requires_vec_env: (bool) Does this model require a vectorized environment
    :param policy_base: (BasePolicy) the base policy used by this method
    """

    def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base):
        if isinstance(policy, str):
            self.policy = get_policy_from_name(policy_base, policy)
        else:
            self.policy = policy
        self.env = env
        self.verbose = verbose
        self._requires_vec_env = requires_vec_env
        self.observation_space = None
        self.action_space = None
        self.n_envs = None
        self._vectorize_action = False

        if env is not None:
            if isinstance(env, str):
                if self.verbose >= 1:
                    print("Creating environment from the given name, wrapped in a DummyVecEnv.")
                self.env = env = DummyVecEnv([lambda: gym.make(env)])

            self.observation_space = env.observation_space
            self.action_space = env.action_space
            if requires_vec_env:
                if isinstance(env, VecEnv):
                    self.n_envs = env.num_envs
                else:
                    raise ValueError("Error: the model requires a vectorized environment, "
                                     "please use a VecEnv wrapper.")
            else:
                if isinstance(env, VecEnv):
                    if env.num_envs == 1:
                        self.env = _UnvecWrapper(env)
                        self._vectorize_action = True
                    else:
                        raise ValueError("Error: the model requires a non vectorized environment or a single"
                                         " vectorized environment.")
                self.n_envs = 1

    def get_env(self):
        """
        returns the current environment (can be None if not defined)

        :return: (Gym Environment) The current environment
        """
        return self.env

    def set_env(self, env):
        """
        Checks the validity of the environment, and if it is coherent, set it as the current environment.

        :param env: (Gym Environment) The environment for learning a policy
        """
        if env is None and self.env is None:
            if self.verbose >= 1:
                print("Loading a model without an environment, "
                      "this model cannot be trained until it has a valid environment.")
            return
        elif env is None:
            raise ValueError("Error: trying to replace the current environment with None")

        # sanity checking the environment
        assert self.observation_space == env.observation_space, \
            "Error: the environment passed must have at least the same observation space as the model was trained on."
        assert self.action_space == env.action_space, \
            "Error: the environment passed must have at least the same action space as the model was trained on."
        if self._requires_vec_env:
            assert isinstance(env, VecEnv), \
                "Error: the environment passed is not a vectorized environment, however {} requires it".format(
                    self.__class__.__name__)
            assert not issubclass(self.policy, LstmPolicy) or self.n_envs == env.num_envs, \
                "Error: the environment passed must have the same number of environments as the model was trained on. " \
                "This is due to the Lstm policy not being capable of changing the number of environments."
            self.n_envs = env.num_envs
        else:
            # for models that don't want a vectorized environment, check if they make sense and adapt them.
            # Otherwise tell the user about this issue
            if isinstance(env, VecEnv):
                if env.num_envs == 1:
                    env = _UnvecWrapper(env)
                    self._vectorize_action = True
                else:
                    raise ValueError("Error: the model requires a non vectorized environment or a single vectorized "
                                     "environment.")
            else:
                self._vectorize_action = False

            self.n_envs = 1

        self.env = env

    @abstractmethod
    def setup_model(self):
        """
        Create all the functions and tensorflow graphs necessary to train the model
        """
        pass

    def _setup_learn(self, seed):
        """
        check the environment, set the seed, and set the logger

        :param seed: (int) the seed value
        """
        if self.env is None:
            raise ValueError("Error: cannot train the model without a valid environment, please set an environment "
                             "with the set_env(self, env) method.")
        if seed is not None:
            set_global_seeds(seed)

    @abstractmethod
    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"):
        """
        Return a trained model.

        :param total_timesteps: (int) The total number of samples to train on
        :param seed: (int) The initial seed for training, if None: keep current seed
        :param callback: (function (dict, dict)) function called at every step with the state of the algorithm.
            It takes the local and global variables.
        :param log_interval: (int) The number of timesteps before logging.
        :param tb_log_name: (str) the name of the run for tensorboard log
        :return: (BaseRLModel) the trained model
        """
        pass

    @abstractmethod
    def predict(self, observation, state=None, mask=None, deterministic=False):
        """
        Get the model's action from an observation

        :param observation: (np.ndarray) the input observation
        :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
        :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
        :param deterministic: (bool) Whether or not to return deterministic actions.
        :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
        """
        pass

    @abstractmethod
    def action_probability(self, observation, state=None, mask=None):
        """
        Get the model's action probability distribution from an observation

        :param observation: (np.ndarray) the input observation
        :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
        :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
        :return: (np.ndarray) the model's action probability distribution
        """
        pass

    @abstractmethod
    def save(self, save_path):
        """
        Save the current parameters to file

        :param save_path: (str) the save location
        """
        # self._save_to_file(save_path, data={}, params=None)
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def load(cls, load_path, env=None, **kwargs):
        """
        Load the model from file

        :param load_path: (str) the saved parameter location
        :param env: (Gym Environment) the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model)
        :param kwargs: extra arguments to change the model when loading
        """
        # data, param = cls._load_from_file(load_path)
        raise NotImplementedError()

    @staticmethod
    def _save_to_file(save_path, data=None, params=None):
        _, ext = os.path.splitext(save_path)
        if ext == "":
            save_path += ".pkl"

        with open(save_path, "wb") as file:
            cloudpickle.dump((data, params), file)

    @staticmethod
    def _load_from_file(load_path):
        if not os.path.exists(load_path):
            if os.path.exists(load_path + ".pkl"):
                load_path += ".pkl"
            else:
                raise ValueError("Error: the file {} could not be found".format(load_path))

        with open(load_path, "rb") as file:
            data, params = cloudpickle.load(file)

        return data, params

    @staticmethod
    def _softmax(x_input):
        """
        An implementation of softmax.

        :param x_input: (numpy float) input vector
        :return: (numpy float) output vector
        """
        x_exp = np.exp(x_input.T - np.max(x_input.T, axis=0))
        return (x_exp / x_exp.sum(axis=0)).T

    @staticmethod
    def _is_vectorized_observation(observation, observation_space):
        """
        For every observation type, detects and validates the shape,
        then returns whether or not the observation is vectorized.

        :param observation: (np.ndarray) the input observation to validate
        :param observation_space: (gym.spaces) the observation space
        :return: (bool) whether the given observation is vectorized or not
        """
        if isinstance(observation_space, gym.spaces.Box):
            if observation.shape == observation_space.shape:
                return False
            elif observation.shape[1:] == observation_space.shape:
                return True
            else:
                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
                                 "Box environment, please use {} ".format(observation_space.shape) +
                                 "or (n_env, {}) for the observation shape."
                                 .format(", ".join(map(str, observation_space.shape))))
        elif isinstance(observation_space, gym.spaces.Discrete):
            if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
                return False
            elif len(observation.shape) == 1:
                return True
            else:
                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
                                 "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
        elif isinstance(observation_space, gym.spaces.MultiDiscrete):
            if observation.shape == (len(observation_space.nvec),):
                return False
            elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
                return True
            else:
                raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
                                 "environment, please use ({},) or ".format(len(observation_space.nvec)) +
                                 "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
        elif isinstance(observation_space, gym.spaces.MultiBinary):
            if observation.shape == (observation_space.n,):
                return False
            elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
                return True
            else:
                raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
                                 "environment, please use ({},) or ".format(observation_space.n) +
                                 "(n_env, {}) for the observation shape.".format(observation_space.n))
        else:
            raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
                             .format(observation_space))
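
For illustration, a minimal usage sketch of the interface defined above, assuming the PPO2 subclass from stable_baselines and a Gym-registered environment id (both are examples, not part of this module):

    from stable_baselines import PPO2

    model = PPO2("MlpPolicy", "CartPole-v1", verbose=1)        # a str env id is wrapped in a DummyVecEnv
    model.learn(total_timesteps=10000)
    model.save("ppo2_cartpole")                                 # _save_to_file appends ".pkl" if no extension

    loaded = PPO2.load("ppo2_cartpole", env=model.get_env())   # _load_from_file, then set_env + setup_model
    obs = loaded.get_env().reset()
    action, _states = loaded.predict(obs, deterministic=True)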

class ActorCriticRLModel(BaseRLModel):
    """
    The base class for Actor critic model

    :param policy: (BasePolicy) Policy object
    :param env: (Gym environment) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param policy_base: (BasePolicy) the base policy used by this method (default=ActorCriticPolicy)
    :param requires_vec_env: (bool) Does this model require a vectorized environment
    """

    def __init__(self, policy, env, _init_setup_model, verbose=0, policy_base=ActorCriticPolicy,
                 requires_vec_env=False):
        super(ActorCriticRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env,
                                                 policy_base=policy_base)

        self.sess = None
        self.initial_state = None
        self.step = None
        self.proba_step = None
        self.params = None

    @abstractmethod
    def setup_model(self):
        pass

    @abstractmethod
    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"):
        pass

    def predict(self, observation, state=None, mask=None, deterministic=False):
        if state is None:
            state = self.initial_state
        if mask is None:
            mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation)
        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

        observation = observation.reshape((-1,) + self.observation_space.shape)
        actions, _, states, _ = self.step(observation, state, mask, deterministic=deterministic)

        if not vectorized_env:
            if state is not None:
                raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
            actions = actions[0]

        return actions, states

    def action_probability(self, observation, state=None, mask=None):
        if state is None:
            state = self.initial_state
        if mask is None:
            mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation)
        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

        observation = observation.reshape((-1,) + self.observation_space.shape)
        actions_proba = self.proba_step(observation, state, mask)

        if not vectorized_env:
            if state is not None:
                raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
            actions_proba = actions_proba[0]

        return actions_proba

    @abstractmethod
    def save(self, save_path):
        pass

    @classmethod
    def load(cls, load_path, env=None, **kwargs):
        data, params = cls._load_from_file(load_path)

        model = cls(policy=data["policy"], env=None, _init_setup_model=False)
        model.__dict__.update(data)
        model.__dict__.update(kwargs)
        model.set_env(env)
        model.setup_model()

        restores = []
        for param, loaded_p in zip(model.params, params):
            restores.append(param.assign(loaded_p))
        model.sess.run(restores)

        return model


class OffPolicyRLModel(BaseRLModel):
    """
    The base class for off policy RL model

    :param policy: (BasePolicy) Policy object
    :param env: (Gym environment) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param replay_buffer: (ReplayBuffer) the type of replay buffer
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param requires_vec_env: (bool) Does this model require a vectorized environment
    :param policy_base: (BasePolicy) the base policy used by this method
    """

    def __init__(self, policy, env, replay_buffer, verbose=0, *, requires_vec_env, policy_base):
        super(OffPolicyRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env,
                                               policy_base=policy_base)

        self.replay_buffer = replay_buffer

    @abstractmethod
    def setup_model(self):
        pass

    @abstractmethod
    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"):
        pass

    @abstractmethod
    def predict(self, observation, state=None, mask=None, deterministic=False):
        pass

    @abstractmethod
    def action_probability(self, observation, state=None, mask=None):
        pass

    @abstractmethod
    def save(self, save_path):
        pass

    @classmethod
    @abstractmethod
    def load(cls, load_path, env=None, **kwargs):
        pass


class _UnvecWrapper(VecEnvWrapper):
    def __init__(self, venv):
        """
        Unvectorize a vectorized environment, for vectorized environments that only have one environment

        :param venv: (VecEnv) the vectorized environment to wrap
        """
        super().__init__(venv)
        assert venv.num_envs == 1, "Error: cannot unwrap an environment wrapper that has more than one environment."

    def reset(self):
        return self.venv.reset()[0]

    def step_async(self, actions):
        self.venv.step_async([actions])

    def step_wait(self):
        actions, values, states, information = self.venv.step_wait()
        return actions[0], float(values[0]), states[0], information[0]

    def render(self, mode='human'):
        return self.venv.render(mode=mode)


class SetVerbosity:
    def __init__(self, verbose=0):
        """
        define a region of code for a certain level of verbosity

        :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
        """
        self.verbose = verbose

    def __enter__(self):
        self.tf_level = os.environ.get('TF_CPP_MIN_LOG_LEVEL', '0')
        self.log_level = logger.get_level()
        self.gym_level = gym.logger.MIN_LEVEL

        if self.verbose <= 1:
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        if self.verbose <= 0:
            logger.set_level(logger.DISABLED)
            gym.logger.set_level(gym.logger.DISABLED)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.verbose <= 1:
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = self.tf_level

        if self.verbose <= 0:
            logger.set_level(self.log_level)
            gym.logger.set_level(self.gym_level)


class TensorboardWriter:
    def __init__(self, graph, tensorboard_log_path, tb_log_name):
        """
        Create a Tensorboard writer for a code segment, and saves it to the log directory as its own run

        :param graph: (Tensorflow Graph) the model graph
        :param tensorboard_log_path: (str) the save path for the log (can be None for no logging)
        :param tb_log_name: (str) the name of the run for tensorboard log
        """
        self.graph = graph
        self.tensorboard_log_path = tensorboard_log_path
        self.tb_log_name = tb_log_name
        self.writer = None

    def __enter__(self):
        if self.tensorboard_log_path is not None:
            save_path = os.path.join(self.tensorboard_log_path,
                                     "{}_{}".format(self.tb_log_name, self._get_latest_run_id() + 1))
            self.writer = tf.summary.FileWriter(save_path, graph=self.graph)
        return self.writer

    def _get_latest_run_id(self):
        """
        returns the latest run number for the given log name and log path,
        by finding the greatest number in the directories.

        :return: (int) latest run number
        """
        max_run_id = 0
        for path in glob.glob(self.tensorboard_log_path + "/{}_[0-9]*".format(self.tb_log_name)):
            file_name = path.split("/")[-1]
            ext = file_name.split("_")[-1]
            if self.tb_log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
                max_run_id = int(ext)
        return max_run_id

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.writer is not None:
            self.writer.add_graph(self.graph)
            self.writer.flush()
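
For context, a minimal sketch of how the SetVerbosity and TensorboardWriter helpers are typically combined inside a concrete algorithm's learn() method; the self.graph and self.tensorboard_log attributes are assumed to be set by the concrete model and are used here only for illustration:

    def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"):
        # illustrative sketch: self.graph and self.tensorboard_log are assumed model attributes
        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer:
            self._setup_learn(seed)
            # ... training loop: `writer` is None when no tensorboard log path was given,
            # otherwise summaries can be added with writer.add_summary(summary, step)
        return self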