import queue
import time
from multiprocessing import Queue, Process
import cv2
import numpy as np
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from stable_baselines import logger
class ExpertDataset(object):
"""
Dataset for behavior cloning or GAIL.
Data structure of the expert dataset, an ".npz" archive:
the data is saved in python dictionary format with keys: 'actions', 'episode_returns',
'rewards', 'obs', 'episode_starts'
In case of images, 'obs' contains the relative paths to the images.
:param expert_path: (str) the path to trajectory data (.npz file)
:param train_fraction: (float) the train validation split (0 to 1)
for pre-training using behavior cloning (BC)
:param batch_size: (int) the minibatch size for behavior cloning
:param traj_limitation: (int) the number of trajectories to use (if -1, load all)
:param randomize: (bool) if the dataset should be shuffled
:param verbose: (int) Verbosity
:param sequential_preprocessing: (bool) Do not use a subprocess to preprocess
    the data (slower, but uses less memory for the CI)
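
Example (a minimal sketch; the archive name and hyper-parameter values
are illustrative, not prescribed by this class)::

    dataset = ExpertDataset(expert_path='expert_data.npz',
                            traj_limitation=10, batch_size=64)
    obs, actions = dataset.get_next_batch('train')
    dataset.plot()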
"""
def __init__(self, expert_path, train_fraction=0.7, batch_size=64,
traj_limitation=-1, randomize=True, verbose=1,
sequential_preprocessing=False):
traj_data = np.load(expert_path)
if verbose > 0:
for key, val in traj_data.items():
print(key, val.shape)
# Array of bool where episode_starts[i] = True for each new episode
episode_starts = traj_data['episode_starts']
traj_limit_idx = len(traj_data['obs'])
if traj_limitation > 0:
n_episodes = 0
# Retrieve the index corresponding
# to the traj_limitation trajectory
for idx, episode_start in enumerate(episode_starts):
n_episodes += int(episode_start)
if n_episodes == (traj_limitation + 1):
    traj_limit_idx = idx - 1
    break
observations = traj_data['obs'][:traj_limit_idx]
actions = traj_data['actions'][:traj_limit_idx]
# obs, actions: shape (N * L, ) + S
# where N = # episodes, L = episode length
# and S is the environment observation/action space.
# S = (1, ) for discrete space
# Flatten to (N * L, prod(S))
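# e.g. image observations of shape (N * L, 64, 64, 3)
# are flattened to (N * L, 12288); the shapes are illustrative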
if len(observations.shape) > 2:
observations = np.reshape(observations, [-1, np.prod(observations.shape[1:])])
if len(actions.shape) > 2:
actions = np.reshape(actions, [-1, np.prod(actions.shape[1:])])
indices = np.random.permutation(len(observations)).astype(np.int64)
# Train/Validation split when using behavior cloning
train_indices = indices[:int(train_fraction * len(indices))]
val_indices = indices[int(train_fraction * len(indices)):]
assert len(train_indices) > 0, "No sample for the training set"
assert len(val_indices) > 0, "No sample for the validation set"
self.observations = observations
self.actions = actions
self.returns = traj_data['episode_returns'][:traj_limit_idx]
self.avg_ret = sum(self.returns) / len(self.returns)
self.std_ret = np.std(np.array(self.returns))
self.verbose = verbose
assert len(self.observations) == len(self.actions), "The number of actions and observations differ, " \
                                                    "please check your expert dataset"
# Report the true number of trajectories when no limit is set
self.num_traj = min(traj_limitation, np.sum(episode_starts)) if traj_limitation > 0 else np.sum(episode_starts)
self.num_transition = len(self.observations)
self.randomize = randomize
self.sequential_preprocessing = sequential_preprocessing
self.dataloader = None
self.train_loader = DataLoader(train_indices, self.observations, self.actions, batch_size,
shuffle=self.randomize, start_process=False,
sequential=sequential_preprocessing)
self.val_loader = DataLoader(val_indices, self.observations, self.actions, batch_size,
shuffle=self.randomize, start_process=False,
sequential=sequential_preprocessing)
if self.verbose >= 1:
self.log_info()
def init_dataloader(self, batch_size):
"""
Initialize the dataloader used by GAIL.
:param batch_size: (int)
"""
indices = np.random.permutation(len(self.observations)).astype(np.int64)
self.dataloader = DataLoader(indices, self.observations, self.actions, batch_size,
shuffle=self.randomize, start_process=False,
sequential=self.sequential_preprocessing)
def __del__(self):
del self.dataloader, self.train_loader, self.val_loader
def prepare_pickling(self):
"""
Exit processes in order to pickle the dataset.
"""
self.dataloader, self.train_loader, self.val_loader = None, None, None
def log_info(self):
"""
Log the information of the dataset.
"""
logger.log("Total trajectories: {}".format(self.num_traj))
logger.log("Total transitions: {}".format(self.num_transition))
logger.log("Average returns: {}".format(self.avg_ret))
logger.log("Std for returns: {}".format(self.std_ret))
def get_next_batch(self, split=None):
"""
Get the batch from the dataset.
:param split: (str) the type of data split (can be None, 'train', 'val')
:return: (np.ndarray, np.ndarray) inputs and labels
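
Example (a sketch; assumes ``dataset`` is an ``ExpertDataset`` instance)::

    obs, actions = dataset.get_next_batch('train')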
"""
dataloader = {
None: self.dataloader,
'train': self.train_loader,
'val': self.val_loader
}[split]
if dataloader.process is None:
dataloader.start_process()
try:
return next(dataloader)
except StopIteration:
dataloader = iter(dataloader)
return next(dataloader)
def plot(self):
"""
Show a histogram of the episode returns
"""
plt.hist(self.returns)
plt.show()
class DataLoader(object):
"""
A custom dataloader to preprocess observations (including images)
and feed them to the network.
Original code for the dataloader from https://github.com/araffin/robotics-rl-srl
(MIT licence)
Authors: Antonin Raffin, René Traoré, Ashley Hill
:param indices: ([int]) list of observation indices
:param observations: (np.ndarray) observations or images path
:param actions: (np.ndarray) actions
:param batch_size: (int) Number of samples per minibatch
:param n_workers: (int) number of preprocessing workers (for loading the images)
:param infinite_loop: (bool) whether to have an iterator that can be reset
:param max_queue_len: (int) Max number of minibatches that can be preprocessed at the same time
:param shuffle: (bool) Shuffle the indices at the start of each epoch
:param start_process: (bool) Start the preprocessing process (default: True)
:param backend: (str) joblib backend (one of 'multiprocessing', 'sequential', 'threading'
or 'loky' in newest versions)
:param sequential: (bool) Do not use a subprocess to preprocess the data
    (slower, but uses less memory for the CI)
:param partial_minibatch: (bool) Allow partial minibatches (minibatches with fewer
    elements than batch_size)
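
Example (a minimal sketch; ``observations`` and ``actions`` stand for
pre-loaded numpy arrays and are illustrative)::

    indices = np.arange(len(observations))
    dataloader = DataLoader(indices, observations, actions,
                            batch_size=32, sequential=True)
    for obs_batch, action_batch in dataloader:
        ...  # feed the minibatch to the network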
"""
def __init__(self, indices, observations, actions, batch_size, n_workers=1,
infinite_loop=True, max_queue_len=1, shuffle=False,
start_process=True, backend='threading', sequential=False, partial_minibatch=True):
super(DataLoader, self).__init__()
self.n_workers = n_workers
self.infinite_loop = infinite_loop
self.indices = indices
self.original_indices = indices.copy()
self.n_minibatches = len(indices) // batch_size
# Add a partial minibatch when there are not
# enough samples to fill the last one
if partial_minibatch and len(indices) % batch_size > 0:
    self.n_minibatches += 1
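# e.g. 100 samples with batch_size=64 -> one full minibatch
# plus one partial minibatch of 36 samples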
self.batch_size = batch_size
self.observations = observations
self.actions = actions
self.shuffle = shuffle
self.queue = Queue(max_queue_len)
self.process = None
self.load_images = isinstance(observations[0], str)
self.backend = backend
self.sequential = sequential
self.start_idx = 0
if start_process:
self.start_process()
def start_process(self):
"""Start preprocessing process"""
# Skip if in sequential mode
if self.sequential:
return
self.process = Process(target=self._run)
# Make it a daemon, so it will be terminated
# along with the main process
self.process.daemon = True
self.process.start()
@property
def _minibatch_indices(self):
"""
Current minibatch indices given the current pointer
(start_idx) and the minibatch size
:return: (np.ndarray) 1D array of indices
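e.g. with ``start_idx=64`` and ``batch_size=32``,
this returns ``indices[64:96]``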
"""
return self.indices[self.start_idx:self.start_idx + self.batch_size]
def sequential_next(self):
"""
Sequential version of the pre-processing.
"""
if self.start_idx >= len(self.indices):
raise StopIteration
if self.start_idx == 0:
if self.shuffle:
# Shuffle indices
np.random.shuffle(self.indices)
obs = self.observations[self._minibatch_indices]
if self.load_images:
obs = np.concatenate([self._make_batch_element(image_path) for image_path in obs],
axis=0)
actions = self.actions[self._minibatch_indices]
self.start_idx += self.batch_size
return obs, actions
def _run(self):
start = True
with Parallel(n_jobs=self.n_workers, batch_size="auto", backend=self.backend) as parallel:
while start or self.infinite_loop:
start = False
if self.shuffle:
np.random.shuffle(self.indices)
for minibatch_idx in range(self.n_minibatches):
self.start_idx = minibatch_idx * self.batch_size
obs = self.observations[self._minibatch_indices]
if self.load_images:
if self.n_workers <= 1:
obs = [self._make_batch_element(image_path)
for image_path in obs]
else:
obs = parallel(delayed(self._make_batch_element)(image_path)
for image_path in obs)
obs = np.concatenate(obs, axis=0)
actions = self.actions[self._minibatch_indices]
self.queue.put((obs, actions))
# Free memory
del obs
self.queue.put(None)
@classmethod
def _make_batch_element(cls, image_path):
"""
Process one element.
:param image_path: (str) path to an image
:return: (np.ndarray)
"""
# cv2.IMREAD_UNCHANGED is needed to load
# grey and RGBA images
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
if image is None:
    raise ValueError("Tried to load {}, but it was not found".format(image_path))
# Grey image: add the missing channel dimension
if len(image.shape) == 2:
    image = image[:, :, np.newaxis]
# Convert from BGR to RGB
if image.shape[-1] == 3:
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.reshape((1,) + image.shape)
return image
def __len__(self):
return self.n_minibatches
def __iter__(self):
self.start_idx = 0
self.indices = self.original_indices.copy()
return self
def __next__(self):
if self.sequential:
return self.sequential_next()
if self.process is None:
raise ValueError("You must call .start_process() before using the dataloader")
while True:
try:
val = self.queue.get_nowait()
break
except queue.Empty:
time.sleep(0.001)
continue
if val is None:
raise StopIteration
return val
def __del__(self):
if self.process is not None:
self.process.terminate()