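"""
Source code for stable_baselines.results_plotter: plotting helpers for the episode
results stored in Monitor log directories (monitor.csv files).
"""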

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from stable_baselines.bench.monitor import load_results

# To force a backend, uncomment: matplotlib.use('TkAgg')  (or 'Agg' for non-interactive mode)
# 'none' writes SVG text as text elements instead of glyph paths (the viewer needs the fonts).
matplotlib.rcParams['svg.fonttype'] = 'none'

# Accepted values for the ``xaxis`` argument of ts2xy / plot_curves / plot_results.
X_TIMESTEPS = 'timesteps'
X_EPISODES = 'episodes'
X_WALLTIME = 'walltime_hrs'
# Number of consecutive points averaged by the rolling mean drawn in plot_curves;
# curves with fewer points than this are only drawn as a scatter.
EPISODES_WINDOW = 100
# Colors assigned to curves in order (curve i gets COLORS[i]).
# Every entry must be a valid matplotlib named color.
COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
          'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
          'darkgreen', 'tan', 'salmon', 'gold', 'mediumpurple', 'darkred', 'darkblue']

def rolling_window(array, window):
    """
    Build a view of ``array`` containing every length-``window`` slice of its last axis.

    The result has shape ``array.shape[:-1] + (array.shape[-1] - window + 1, window)``.
    It is produced with ``np.lib.stride_tricks.as_strided``, so it shares memory with
    ``array`` (nothing is copied) and consecutive windows overlap.

    :param array: (np.ndarray) the input Array
    :param window: (int) length of the rolling window
    :return: (np.ndarray) rolling window on the input array
    """
    n_last = array.shape[-1]
    n_windows = n_last - window + 1
    # Both the window-index axis and the within-window axis advance by one element
    # of the original last axis.
    step = array.strides[-1]
    return np.lib.stride_tricks.as_strided(
        array,
        shape=array.shape[:-1] + (n_windows, window),
        strides=array.strides + (step,),
    )
def window_func(var_1, var_2, window, func):
    """
    Reduce each rolling window of ``var_2`` with ``func`` and align ``var_1`` to the result.

    :param var_1: (np.ndarray) variable 1, trimmed so it lines up with the windowed output
    :param var_2: (np.ndarray) variable 2, the values the rolling window slides over
    :param window: (int) length of the rolling window
    :param func: (numpy function) reduction applied along the last axis of each window
        (such as np.mean)
    :return: (np.ndarray, np.ndarray) ``var_1`` without its first ``window - 1`` entries,
        and ``func`` applied to each window of ``var_2``
    """
    reduced = func(rolling_window(var_2, window), axis=-1)
    # Window k covers indices k .. k + window - 1, so pair it with var_1[k + window - 1].
    aligned = var_1[window - 1:]
    return aligned, reduced
def ts2xy(timesteps, xaxis):
    """
    Split a results DataFrame into the x and y arrays to plot.

    The y values are always column ``r``. The x values depend on ``xaxis``:

    - X_TIMESTEPS: cumulative sum of column ``l``
    - X_EPISODES: the row index 0 .. len(timesteps) - 1
    - X_WALLTIME: column ``t`` divided by 3600 (seconds to hours, assuming ``t`` is in seconds)

    :param timesteps: (Pandas DataFrame) the input data
    :param xaxis: (str) the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :return: (np.ndarray, np.ndarray) the x and y output
    :raises NotImplementedError: if ``xaxis`` is none of the three supported values
    """
    if xaxis == X_TIMESTEPS:
        x_values = np.cumsum(timesteps.l.values)
    elif xaxis == X_EPISODES:
        x_values = np.arange(len(timesteps))
    elif xaxis == X_WALLTIME:
        x_values = timesteps.t.values / 3600.
    else:
        raise NotImplementedError
    y_values = timesteps.r.values
    return x_values, y_values
def plot_curves(xy_list, xaxis, title):
    """
    plot the curves

    Opens a new figure. Each (x, y) pair is drawn as a scatter of the raw points. When the
    pair has at least EPISODES_WINDOW points, its rolling mean over EPISODES_WINDOW points
    is also drawn, in the same color as the scatter.

    :param xy_list: ([(np.ndarray, np.ndarray)]) the x and y coordinates to plot
    :param xaxis: (str) the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param title: (str) the title of the plot
    """
    plt.figure(figsize=(8, 2))
    # Curves with no points (e.g. everything filtered out upstream) have no last x value,
    # so they must not take part in computing the x-axis limit.
    x_ends = [x_var[-1] for x_var, _ in xy_list if len(x_var) > 0]
    for i, (x_var, y_var) in enumerate(xy_list):
        # Cycle through the palette so more than len(COLORS) curves do not raise IndexError.
        color = COLORS[i % len(COLORS)]
        # Use the same color for the raw points as for the smoothed line of this curve.
        plt.scatter(x_var, y_var, s=2, color=color)
        # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
        if x_var.shape[0] >= EPISODES_WINDOW:
            x_smooth, y_mean = window_func(x_var, y_var, EPISODES_WINDOW, np.mean)
            plt.plot(x_smooth, y_mean, color=color)
    if x_ends:
        plt.xlim(0, max(x_ends))
    plt.title(title)
    plt.xlabel(xaxis)
    plt.ylabel("Episode Rewards")
    plt.tight_layout()
def plot_results(dirs, num_timesteps, xaxis, task_name):
    """
    Load the results of several log directories and plot them together in one figure.

    :param dirs: ([str]) the save location of the results to plot
    :param num_timesteps: (int or None) if not None, keep only the episodes whose running
        total of column ``l`` is at most this value
    :param xaxis: (str) the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param task_name: (str) the title of the task to plot
    """
    def _truncate(frame):
        # With no budget every loaded episode is kept.
        if num_timesteps is None:
            return frame
        return frame[frame.l.cumsum() <= num_timesteps]

    frames = [_truncate(load_results(folder)) for folder in dirs]
    curves = [ts2xy(frame, xaxis) for frame in frames]
    plot_curves(curves, xaxis, task_name)
def main():
    """
    Example usage in jupyter-notebook

    .. code-block:: python

        from stable_baselines import results_plotter
        %matplotlib inline
        results_plotter.plot_results(["./log"], 10e6, results_plotter.X_TIMESTEPS, "Breakout")

    Here ./log is a directory containing the monitor.csv files

    When run as a script, the plot options are read from the command line
    (--dirs, --num_timesteps, --xaxis, --task_name). The results are plotted and the
    figure is shown.
    """
    import argparse
    import os
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dirs', help='List of log directories', nargs='*', default=['./log'])
    parser.add_argument('--num_timesteps', type=int, default=int(10e6))
    # Restrict to the axes ts2xy supports, so a typo fails at parse time with a usage message.
    parser.add_argument('--xaxis', help='Variable on X-axis', default=X_TIMESTEPS,
                        choices=[X_TIMESTEPS, X_EPISODES, X_WALLTIME])
    parser.add_argument('--task_name', help='Title of plot', default='Breakout')
    args = parser.parse_args()
    args.dirs = [os.path.abspath(folder) for folder in args.dirs]
    plot_results(args.dirs, args.num_timesteps, args.xaxis, args.task_name)
    # plot_results only builds the figure; without show() the script would exit having
    # displayed nothing.
    plt.show()
if __name__ == '__main__':
    # Command-line entry point.
    main()