from collections import OrderedDict
import numpy as np
from gym import spaces
# Important: gym mixes up ordered and unordered keys,
# and the Dict space may return the keys in a different order than the actual one
# Canonical layout used when flattening / unflattening goal-env observations:
# [observation | achieved_goal | desired_goal].
KEY_ORDER = [
    'observation',
    'achieved_goal',
    'desired_goal',
]
class HERGoalEnvWrapper(object):
    """
    A wrapper that allows using a dict observation space (coming from a GoalEnv)
    with RL algorithms that expect a single flat observation.

    The ``observation``, ``achieved_goal`` and ``desired_goal`` entries are
    concatenated, in ``KEY_ORDER``, into one array; ``convert_obs_to_dict``
    performs the inverse split.
    It assumes that all the spaces of the dict space are of the same type.

    :param env: (gym.GoalEnv)
    """

    def __init__(self, env):
        super(HERGoalEnvWrapper, self).__init__()
        self.env = env
        self.metadata = self.env.metadata
        self.action_space = env.action_space
        # Collect the sub-spaces in KEY_ORDER instead of the Dict space's own
        # key order: gym's Dict may store its keys in a different order than
        # the one used by convert_dict_to_obs, which would misalign the Box
        # bounds computed below with the flattened observation layout.
        self.spaces = [env.observation_space.spaces[key] for key in KEY_ORDER]
        # Check that all spaces are of the same type
        # (current limitation of the wrapper)
        space_types = [type(space) for space in self.spaces]
        assert len(set(space_types)) == 1, (
            "The spaces for goal and observation must be of the same type"
        )

        if isinstance(self.spaces[0], spaces.Discrete):
            # A Discrete sample is flattened to a single integer entry.
            self.obs_dim = 1
            self.goal_dim = 1
        else:
            goal_space_shape = env.observation_space.spaces['achieved_goal'].shape
            self.obs_dim = env.observation_space.spaces['observation'].shape[0]
            self.goal_dim = goal_space_shape[0]
            # Goals may be (n,) or (n, 1); anything else is rejected.
            if len(goal_space_shape) == 2:
                assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet"
            else:
                assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet"

        if isinstance(self.spaces[0], spaces.MultiBinary):
            # observation + achieved_goal + desired_goal (both goals share goal_dim)
            total_dim = self.obs_dim + 2 * self.goal_dim
            self.observation_space = spaces.MultiBinary(total_dim)
        elif isinstance(self.spaces[0], spaces.Box):
            # Bounds are concatenated in KEY_ORDER to match convert_dict_to_obs.
            lows = np.concatenate([space.low for space in self.spaces])
            highs = np.concatenate([space.high for space in self.spaces])
            self.observation_space = spaces.Box(lows, highs, dtype=np.float32)
        elif isinstance(self.spaces[0], spaces.Discrete):
            dimensions = [space.n for space in self.spaces]
            self.observation_space = spaces.MultiDiscrete(dimensions)
        else:
            raise NotImplementedError("{} space is not supported".format(type(self.spaces[0])))

    def convert_dict_to_obs(self, obs_dict):
        """
        Flatten a goal-env observation dict into a single array.

        Entries are concatenated in KEY_ORDER. Note: the achieved goal is not
        removed from the observation; this keeps the transformation
        reversible (see convert_obs_to_dict).

        :param obs_dict: (dict<np.ndarray>)
        :return: (np.ndarray)
        """
        if isinstance(self.observation_space, spaces.MultiDiscrete):
            # Special case for MultiDiscrete: each entry is a single Discrete
            # value, so wrap it in a list to make it concatenable.
            return np.concatenate([[int(obs_dict[key])] for key in KEY_ORDER])
        return np.concatenate([obs_dict[key] for key in KEY_ORDER])

    def convert_obs_to_dict(self, observations):
        """
        Inverse operation of convert_dict_to_obs.

        Splits the flat array along its first axis using obs_dim and goal_dim.

        :param observations: (np.ndarray)
        :return: (OrderedDict<np.ndarray>)
        """
        return OrderedDict([
            ('observation', observations[:self.obs_dim]),
            ('achieved_goal', observations[self.obs_dim:self.obs_dim + self.goal_dim]),
            ('desired_goal', observations[self.obs_dim + self.goal_dim:]),
        ])

    def step(self, action):
        """
        Step the wrapped env and flatten the returned observation dict.

        :param action: action forwarded to the wrapped env
        :return: (np.ndarray, float, bool, dict) flat obs, reward, done, info
        """
        obs, reward, done, info = self.env.step(action)
        return self.convert_dict_to_obs(obs), reward, done, info

    def seed(self, seed=None):
        """Forward seeding to the wrapped env."""
        return self.env.seed(seed)

    def reset(self):
        """Reset the wrapped env and return the flattened initial observation."""
        return self.convert_dict_to_obs(self.env.reset())

    def compute_reward(self, achieved_goal, desired_goal, info):
        """Delegate reward computation to the wrapped GoalEnv."""
        return self.env.compute_reward(achieved_goal, desired_goal, info)

    def render(self, mode='human'):
        """Forward rendering to the wrapped env."""
        return self.env.render(mode)

    def close(self):
        """Close the wrapped env."""
        return self.env.close()