Source code for mushroom_rl.utils.plot

import numpy as np
import scipy.stats as st


def get_mean_and_confidence(data, confidence=0.95):
    """
    Compute the mean and the confidence interval half-width of a dataset.

    Args:
        data (np.ndarray): Array of experiment data of shape
            (n_runs, n_epochs);
        confidence (float, 0.95): confidence level of the interval, strictly
            between 0 and 1.

    Returns:
        The mean of the dataset at each epoch along with the half-width of
        the Student's t confidence interval, i.e. the interval is
        ``[mean - interval, mean + interval]``.

    Raises:
        ValueError: if ``confidence`` is not strictly between 0 and 1.

    """
    if not 0. < confidence < 1.:
        raise ValueError(
            'confidence must be strictly between 0 and 1, got '
            + str(confidence)
        )

    mean = np.mean(data, axis=0)
    se = st.sem(data, axis=0)
    n = len(data)

    # The t interval is built around zero (default loc) and is symmetric, so
    # its upper bound alone is the half-width to add/subtract from the mean.
    # The level is passed positionally because its keyword name differs
    # across scipy versions.
    _, interval = st.t.interval(confidence, n - 1, scale=se)

    return mean, interval
def plot_mean_conf(data, ax, color='blue', line='-', facecolor=None,
                   alpha=0.4, label=None):
    """
    Draw the mean curve of the data on matplotlib axes, together with a
    shaded band for the confidence interval computed by
    ``get_mean_and_confidence``.

    Args:
        data (np.ndarray): Array of experiment data of shape
            (n_runs, n_epochs);
        ax (plt.Axes): matplotlib axes where to create the curve;
        color (str, 'blue'): matplotlib color identifier for the mean curve;
        line (str, '-'): matplotlib line type to be used for the mean curve;
        facecolor (str, None): matplotlib color identifier for the confidence
            interval. If None, the color of the mean curve is used;
        alpha (float, 0.4): transparency of the confidence interval;
        label (str, None): legend label for the plotted curve.

    """
    # The shaded band falls back to the colour of the mean curve.
    if facecolor is None:
        facecolor = color

    center, half_width = get_mean_and_confidence(np.array(data))
    band_top = center + half_width
    band_bottom = center - half_width

    # One x position per entry of the mean curve.
    x_positions = np.arange(np.size(center))

    ax.plot(center, color=color, linestyle=line, label=label)
    ax.fill_between(x_positions, band_top, band_bottom,
                    facecolor=facecolor, alpha=alpha)