# Source code for mushroom_rl.rl_utils.running_stats

import numpy as np
from collections import deque
from mushroom_rl.core import Serializable, ArrayBackend


class RunningStandardization(Serializable):
    """
    Compute a running standardization of values according to Welford's
    online algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    """
    def __init__(self, shape, backend, alpha=1e-32):
        """
        Constructor.

        Args:
            shape (tuple): shape of the data to standardize;
            backend (str): name of the backend to be used;
            alpha (float, 1e-32): minimum learning rate.

        """
        assert backend in ["numpy", "torch"]
        self._shape = shape

        assert 0. < alpha < 1.
        self._alpha = alpha

        self._array_backend = ArrayBackend.get_array_backend(backend)

        # Delegate the initialization of the statistics to reset(), so that
        # _m and _s always have the same shape, (1, *shape), before and after
        # a reset. (Previously __init__ used zeros(*shape) while reset() used
        # zeros(1, *shape), giving differently-shaped buffers.)
        self.reset()

        self._add_save_attr(
            _shape='primitive',
            _alpha='primitive',
            _array_backend='pickle',
            _n='primitive',
            _m='primitive',
            _s='primitive'
        )

    def reset(self):
        """
        Reset the mean and standard deviation.

        """
        # n starts at 1 (not 0) so the first update's learning rate stays
        # strictly below 1 and the division in ``std`` is always defined.
        self._n = 1
        self._m = self._array_backend.zeros(1, *self._shape)
        self._s = self._array_backend.ones(1, *self._shape)

    def update_stats(self, value):
        """
        Update the statistics with the current data value.

        Args:
            value (Array): current data value to use for the update; a single
                sample or a batch (a leading batch dimension is added if
                missing).

        """
        value = self._array_backend.atleast_2d(value)
        batch_size = len(value)
        self._n += batch_size
        # Learning rate decays as 1/n but is floored at ``alpha`` so the
        # statistics never stop adapting entirely.
        alpha = max(batch_size / self._n, self._alpha)
        batch_mean = value.mean(0)  # hoisted: used three times below
        new_m = (1 - alpha) * self._m + alpha * batch_mean
        new_s = self._s + (batch_mean - self._m) * (batch_mean - new_m)
        self._m, self._s = new_m, new_s

    @property
    def mean(self):
        """
        Returns:
            The estimated mean value.

        """
        return self._array_backend.squeeze(self._m)

    @property
    def std(self):
        """
        Returns:
            The estimated standard deviation value.

        """
        return self._array_backend.squeeze(
            self._array_backend.sqrt(self._s / self._n))
class RunningExpWeightedAverage(Serializable):
    """
    Compute an exponentially weighted moving average.

    """
    def __init__(self, shape, alpha, backend, init_value=None):
        """
        Constructor.

        Args:
            shape (tuple): shape of the data to standardize;
            alpha (float): learning rate;
            backend (str): name of the backend to be used;
            init_value (np.ndarray): initial value of the filter.

        """
        assert backend in ["numpy", "torch"]
        self._shape = shape
        self._alpha = alpha
        self._array_backend = ArrayBackend.get_array_backend(backend)

        self.reset(init_value)

        self._add_save_attr(
            _shape='primitive',
            _alpha='primitive',
            _array_backend="pickle",
            _avg_value='primitive',
        )

    def reset(self, init_value=None):
        """
        Reset the filter state.

        Args:
            init_value (Array): initial value of the filter; zeros are used
                when not provided.

        """
        self._avg_value = (
            self._array_backend.zeros(1, *self._shape)
            if init_value is None
            else self._array_backend.atleast_2d(
                self._array_backend.convert(init_value))
        )

    def update_stats(self, value):
        """
        Update the average with the current data value.

        Args:
            value (Array): current data value to use for the update; a batch
                is folded in one sample at a time.

        """
        batch = self._array_backend.atleast_2d(value)
        # Fold each sample into the average sequentially, most recent last.
        for sample in batch:
            self._avg_value = \
                (1. - self._alpha) * self._avg_value + self._alpha * sample

    @property
    def mean(self):
        """
        Returns:
            The estimated mean value.

        """
        return self._avg_value
class RunningAveragedWindow(Serializable):
    """
    Compute the running average using a window of fixed size.

    """
    def __init__(self, shape, window_size, backend, init_value=None):
        """
        Constructor.

        Args:
            shape (tuple): shape of the data to standardize;
            window_size (int): size of the windows;
            backend (str): name of the backend to be used;
            init_value (np.ndarray): initial value of the filter.

        """
        assert backend in ["numpy", "torch"]
        self._shape = shape
        self._window_size = window_size
        self._array_backend = ArrayBackend.get_array_backend(backend)

        self.reset(init_value)

        self._add_save_attr(
            _shape='primitive',
            _window_size='primitive',
            _array_backend='pickle',
            _avg_buffer='primitive',
        )

    def reset(self, init_value=None):
        """
        Reset the window.

        Args:
            init_value (np.ndarray): initial value of the filter; a single
                zero sample is used when not provided.

        """
        # Oldest samples are evicted automatically once the deque reaches
        # ``window_size`` elements.
        self._avg_buffer = (
            deque(self._array_backend.zeros(1, *self._shape),
                  maxlen=self._window_size)
            if init_value is None
            else deque([self._array_backend.convert(init_value)],
                       maxlen=self._window_size)
        )

    def update_stats(self, value):
        """
        Update the window with the current data value.

        Args:
            value (np.ndarray): current data value to use for the update.

        """
        batch = self._array_backend.atleast_2d(value)
        # extend() appends each sample of the batch in order, exactly like a
        # per-element append loop.
        self._avg_buffer.extend(batch)

    @property
    def mean(self):
        """
        Returns:
            The estimated mean value.

        """
        window = self._array_backend.convert(self._avg_buffer)
        return window.mean(0)