import numpy as np
from mushroom_rl.core import Serializable, ArrayBackend
from mushroom_rl.rl_utils.running_stats import RunningStandardization
class Preprocessor(Serializable):
    """
    Abstract base class for observation preprocessors.
    """
    def __call__(self, obs):
        """
        Preprocess the observations.

        Args:
            obs (Array): observations to be preprocessed.

        Returns:
            The preprocessed observations.
        """
        # TODO: Support vectorized environment and batch preprocessing.
        raise NotImplementedError

    def update(self, obs):
        """
        Update the internal state of the preprocessor from the current
        observations. By default this is a no-op; stateful preprocessors
        override it.

        Args:
            obs (Array): observations used to update the internal state.
        """
        # TODO: Support vectorized environment and batch update.
        pass
class StandardizationPreprocessor(Preprocessor):
    """
    Preprocessor that standardizes the environment observations with a
    running mean and standard deviation, clipping the result.
    """
    def __init__(self, mdp_info, backend, clip_obs=10., alpha=1e-32):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information of the MDP;
            backend (str): name of the backend to be used;
            clip_obs (float, 10.): values to clip the normalized observations;
            alpha (float, 1e-32): moving average catchup parameter for the
                normalization.
        """
        self._clip_obs = clip_obs
        self._obs_shape = mdp_info.observation_space.shape
        self._array_backend = ArrayBackend.get_array_backend(backend)
        self._obs_runstand = RunningStandardization(
            shape=self._obs_shape, backend=backend, alpha=alpha)

        self._add_save_attr(
            _clip_obs='primitive',
            _obs_shape='primitive',
            _array_backend='pickle',
            _obs_runstand='mushroom'
        )

    def __call__(self, obs):
        """
        Standardize the observations with the running statistics and clip
        the result to ``[-clip_obs, clip_obs]``.

        Args:
            obs (Array): observations to be preprocessed.

        Returns:
            The standardized, clipped observations.
        """
        standardized = (obs - self._obs_runstand.mean) / self._obs_runstand.std

        return self._array_backend.clip(standardized, -self._clip_obs,
                                        self._clip_obs)

    def update(self, obs):
        """
        Update the running standardization statistics with the current
        observations.

        Args:
            obs (Array): observations used to update the statistics.
        """
        self._obs_runstand.update_stats(obs)
class MinMaxPreprocessor(StandardizationPreprocessor):
    """
    Preprocess observations from the environment using the bounds of the
    observation space of the environment. For observations that are not
    bounded, it falls back to running mean standardization.
    """
    def __init__(self, mdp_info, backend, clip_obs=10., alpha=1e-32):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information of the MDP;
            backend (str): name of the backend to be used;
            clip_obs (float, 10.): values to clip the normalized observations;
            alpha (float, 1e-32): moving average catchup parameter for the
                normalization.
        """
        super(MinMaxPreprocessor, self).__init__(mdp_info, backend, clip_obs,
                                                 alpha)

        obs_low, obs_high = self._array_backend.convert(
            mdp_info.observation_space.low.copy(),
            mdp_info.observation_space.high.copy())

        # Indices of the components with a finite (reasonably sized) range:
        # these can be min-max scaled; the remaining ones fall back to the
        # parent running standardization.
        self._obs_mask = self._array_backend.where(
            (self._array_backend.abs(obs_low) < 1e20)
            & (self._array_backend.abs(obs_high) < 1e20))
        self._obs_mask = self._array_backend.concatenate(self._obs_mask)

        assert self._obs_mask.sum() > 0, \
            "All observations have unlimited/extremely large range, " \
            "you should use StandardizationPreprocessor instead."

        # Run the parent standardization only when some components are
        # unbounded and thus cannot be min-max scaled.
        self._run_norm_obs = \
            len(self._array_backend.squeeze(self._obs_mask)) \
            != obs_low.shape[0]

        # Center and half-range of the bounded components; unbounded ones
        # keep mean 0 and delta 1 so the scaling below is the identity there.
        self._obs_mean = self._array_backend.zeros_like(obs_low)
        self._obs_delta = self._array_backend.ones_like(obs_low)
        self._obs_mean[self._obs_mask] = \
            (obs_high[self._obs_mask] + obs_low[self._obs_mask]) / 2.
        self._obs_delta[self._obs_mask] = \
            (obs_high[self._obs_mask] - obs_low[self._obs_mask]) / 2.

        self._add_save_attr(
            _array_backend='pickle',
            _run_norm_obs='primitive',
            _obs_mask='numpy',
            _obs_mean='numpy',
            _obs_delta='numpy'
        )

    def __call__(self, obs):
        """
        Normalize the observations: min-max scaling on the bounded
        components, running standardization (from the parent class) on the
        unbounded ones.

        Args:
            obs (Array): observations to be preprocessed.

        Returns:
            A new array with the normalized observations. The input ``obs``
            is never modified.
        """
        # Min-max scale every component; only the masked (bounded) entries
        # of this array are used below.
        scaled_obs = (obs - self._obs_mean) / self._obs_delta

        if self._run_norm_obs:
            # The parent __call__ returns a fresh clipped array, safe to
            # write into.
            norm_obs = super(MinMaxPreprocessor, self).__call__(obs)
        else:
            # BUGFIX: the original wrote into ``obs`` directly here, mutating
            # the caller's observation array in place. Copy first so the
            # input is preserved.
            norm_obs = self._array_backend.copy(obs)

        norm_obs[self._obs_mask] = scaled_obs[self._obs_mask]

        return norm_obs