# Source code for stable_baselines.common.cmd_util

"""
Helpers for scripts like run_atari.py.
"""

import os
import warnings

import gym

from stable_baselines import logger
from stable_baselines.bench import Monitor
from stable_baselines.common.misc_util import set_global_seeds
from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind
from stable_baselines.common.misc_util import mpi_rank_or_zero
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv


def make_vec_env(env_id, n_envs=1, seed=None, start_index=0,
                 monitor_dir=None, wrapper_class=None, env_kwargs=None,
                 vec_env_cls=None, vec_env_kwargs=None):
    """
    Build a vectorized environment whose sub-environments are each wrapped
    in a ``Monitor``.

    When no ``vec_env_cls`` is supplied a ``DummyVecEnv`` is used, which is
    usually faster than a ``SubprocVecEnv``.

    :param env_id: (str or Type[gym.Env]) either a registered environment ID
        (passed to ``gym.make``) or an environment class to instantiate
    :param n_envs: (int) how many environments to run in parallel
    :param seed: (int) base seed; sub-environment ``k`` is seeded with
        ``seed + k + start_index``. ``None`` leaves seeding untouched.
    :param start_index: (int) rank of the first sub-environment
    :param monitor_dir: (str) folder for the per-rank monitor files. With
        ``None`` no file is written, but each env is still wrapped in a
        ``Monitor`` so training statistics remain available.
    :param wrapper_class: (gym.Wrapper or callable) optional extra wrapper,
        or any single-argument callable that wraps the environment
    :param env_kwargs: (dict) keyword arguments for the environment class
        constructor (ignored, with a warning, when ``env_id`` is a string)
    :param vec_env_cls: (Type[VecEnv]) custom ``VecEnv`` class; defaults to
        ``DummyVecEnv`` when ``None``
    :param vec_env_kwargs: (dict) keyword arguments for the ``VecEnv``
        constructor
    :return: (VecEnv) the wrapped, vectorized environment
    """
    if env_kwargs is None:
        env_kwargs = {}
    if vec_env_kwargs is None:
        vec_env_kwargs = {}

    def _build_thunk(env_rank):
        # Return a zero-argument factory; the VecEnv decides when (and in
        # which process) to call it.
        def _create():
            if not isinstance(env_id, str):
                instance = env_id(**env_kwargs)
            else:
                instance = gym.make(env_id)
                if len(env_kwargs) > 0:
                    warnings.warn("No environment class was passed (only an env ID) so `env_kwargs` will be ignored")

            if seed is not None:
                instance.seed(seed + env_rank)
                instance.action_space.seed(seed + env_rank)

            # Every sub-environment gets a Monitor so episode statistics are
            # tracked; a log file is written only when a folder was given.
            if monitor_dir is None:
                log_path = None
            else:
                log_path = os.path.join(monitor_dir, str(env_rank))
                os.makedirs(monitor_dir, exist_ok=True)
            instance = Monitor(instance, filename=log_path)

            return instance if wrapper_class is None else wrapper_class(instance)

        return _create

    factory = DummyVecEnv if vec_env_cls is None else vec_env_cls
    thunks = [_build_thunk(offset + start_index) for offset in range(n_envs)]
    return factory(thunks, **vec_env_kwargs)
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0,
                   allow_early_resets=True, start_method=None, use_subprocess=False):
    """
    Build a vectorized, monitored Atari environment.

    :param env_id: (str) the environment ID
    :param num_env: (int) number of environment copies to create
    :param seed: (int) base RNG seed; sub-environment ``k`` is seeded with
        ``seed + k + start_index``
    :param wrapper_kwargs: (dict) keyword arguments forwarded to ``wrap_deepmind``
    :param start_index: (int) rank of the first sub-environment
    :param allow_early_resets: (bool) allow resetting the environment before
        an episode is done
    :param start_method: (str) method used to start the subprocesses; see the
        ``SubprocVecEnv`` documentation
    :param use_subprocess: (bool) when ``num_env`` > 1, use ``SubprocVecEnv``
        instead of ``DummyVecEnv`` (the latter is usually faster). Default: False
    :return: (VecEnv) the Atari environment
    """
    deepmind_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs

    def _env_factory(env_rank):
        def _create():
            atari_env = make_atari(env_id)
            atari_env.seed(seed + env_rank)
            # With no logger directory configured the filename is falsy and
            # Monitor writes no file.
            log_dir = logger.get_dir()
            atari_env = Monitor(atari_env, log_dir and os.path.join(log_dir, str(env_rank)),
                                allow_early_resets=allow_early_resets)
            return wrap_deepmind(atari_env, **deepmind_kwargs)

        return _create

    set_global_seeds(seed)
    thunks = [_env_factory(offset + start_index) for offset in range(num_env)]
    # Subprocesses are only worth starting for several envs, and only on request.
    if num_env != 1 and use_subprocess:
        return SubprocVecEnv(thunks, start_method=start_method)
    return DummyVecEnv(thunks)
def make_mujoco_env(env_id, seed, allow_early_resets=True):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.

    :param env_id: (str) the environment ID
    :param seed: (int) the initial seed for RNG
    :param allow_early_resets: (bool) allows early reset of the environment
    :return: (Gym Environment) The mujoco environment
    """
    # Offset the global seed by the value of mpi_rank_or_zero() so that
    # processes with different ranks do not share the same global RNG stream.
    set_global_seeds(seed + 10000 * mpi_rank_or_zero())
    env = gym.make(env_id)
    # logger.get_dir() may be None when no output directory is configured;
    # os.path.join(None, ...) would raise TypeError. Pass None to Monitor
    # instead (no log file), as make_atari_env / make_robotics_env do.
    log_dir = logger.get_dir()
    monitor_path = os.path.join(log_dir, '0') if log_dir is not None else None
    env = Monitor(env, monitor_path, allow_early_resets=allow_early_resets)
    # NOTE(review): the env itself is seeded with the raw seed (no rank
    # offset), unlike the global seed above — kept for backward compatibility.
    env.seed(seed)
    return env
def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
    """
    Create a wrapped, monitored gym.Env for the goal-based robotics tasks.

    The dict observation is reduced to its 'observation' and 'desired_goal'
    entries and flattened into a single vector. 'is_success' from the step
    info is recorded by the Monitor.

    :param env_id: (str) the environment ID
    :param seed: (int) the initial seed for RNG
    :param rank: (int) the rank of the environment (used for the log file name)
    :param allow_early_resets: (bool) allows early reset of the environment
    :return: (Gym Environment) The robotic environment
    """
    set_global_seeds(seed)
    robot_env = gym.make(env_id)
    kept_keys = ['observation', 'desired_goal']
    # TODO: remove try-except once most users are running modern Gym
    try:
        # for modern Gym (>=0.15.4)
        from gym.wrappers import FilterObservation, FlattenObservation
        robot_env = FlattenObservation(FilterObservation(robot_env, kept_keys))
    except ImportError:
        # for older gym (<=0.15.3)
        from gym.wrappers import FlattenDictWrapper  # pytype:disable=import-error
        robot_env = FlattenDictWrapper(robot_env, kept_keys)
    # A falsy logger directory yields a falsy filename, so no file is written.
    log_dir = logger.get_dir()
    robot_env = Monitor(
        robot_env,
        log_dir and os.path.join(log_dir, str(rank)),
        info_keywords=('is_success',),
        allow_early_resets=allow_early_resets,
    )
    robot_env.seed(seed)
    return robot_env
def arg_parser():
    """
    Build a fresh argparse.ArgumentParser with no arguments registered.

    The parser's help output shows each argument's default value.

    :return: (ArgumentParser)
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
    return ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
def atari_arg_parser():
    """
    Create an argparse.ArgumentParser for run_atari.py.

    Registered options (with defaults): ``--env`` ('BreakoutNoFrameskip-v4'),
    ``--seed`` (0) and ``--num-timesteps`` (int(1e7)).

    :return: (ArgumentParser) parser {'--env': 'BreakoutNoFrameskip-v4', '--seed': 0, '--num-timesteps': int(1e7)}
    """
    atari_parser = arg_parser()
    atari_parser.add_argument('--env', default='BreakoutNoFrameskip-v4', help='environment ID')
    atari_parser.add_argument('--seed', type=int, default=0, help='RNG seed')
    atari_parser.add_argument('--num-timesteps', default=int(1e7), type=int)
    return atari_parser
def mujoco_arg_parser():
    """
    Create an argparse.ArgumentParser for run_mujoco.py.

    Registered options (with defaults): ``--env`` ('Reacher-v2'), ``--seed``
    (0), ``--num-timesteps`` (int(1e6)) and the ``--play`` flag (False).

    :return: (ArgumentParser) parser {'--env': 'Reacher-v2', '--seed': 0, '--num-timesteps': int(1e6), '--play': False}
    """
    mujoco_parser = arg_parser()
    mujoco_parser.add_argument('--env', type=str, default='Reacher-v2', help='environment ID')
    mujoco_parser.add_argument('--seed', type=int, default=0, help='RNG seed')
    mujoco_parser.add_argument('--num-timesteps', default=int(1e6), type=int)
    mujoco_parser.add_argument('--play', action='store_true', default=False)
    return mujoco_parser
def robotics_arg_parser():
    """
    Create an argparse.ArgumentParser for the robotics (e.g. Fetch) run scripts.

    Registered options (with defaults): ``--env`` ('FetchReach-v0'),
    ``--seed`` (0) and ``--num-timesteps`` (int(1e6)).

    :return: (ArgumentParser) parser {'--env': 'FetchReach-v0', '--seed': 0, '--num-timesteps': int(1e6)}
    """
    robotics_parser = arg_parser()
    robotics_parser.add_argument('--env', type=str, default='FetchReach-v0', help='environment ID')
    robotics_parser.add_argument('--seed', type=int, default=0, help='RNG seed')
    robotics_parser.add_argument('--num-timesteps', default=int(1e6), type=int)
    return robotics_parser