from mushroom_rl.utils.callbacks.callback import Callback
import numpy as np
class CollectParameters(Callback):
    """
    Callback that records the value taken by a parameter (for instance a
    learning rate) every time it is invoked during a run of the agent.

    """

    def __init__(self, parameter, *idx):
        """
        Constructor.

        Args:
            parameter (Parameter): the parameter whose values are to be
                collected;
            *idx (list): index of the entry to read when ``parameter`` is
                tabular.

        """
        self._parameter = parameter
        self._idx = idx
        # The base constructor is invoked last, as in the original code;
        # presumably it sets up ``self._data_list`` — confirm in Callback.
        super().__init__()

    def __call__(self, dataset):
        """
        Read the current value of the parameter (at ``idx``, if given) and
        append it to the collected data.

        Args:
            dataset: the dataset supplied by the caller; it is not used here.

        """
        current = self._parameter.get_value(*self._idx)
        # ``np.array`` copies ndarray values, so the stored entry does not
        # alias the array that ``get_value`` returned (useful if that array
        # is later modified in place — NOTE(review): confirm Parameter's
        # storage semantics). Non-array values are stored as returned.
        snapshot = np.array(current) if isinstance(current, np.ndarray) else current
        self._data_list.append(snapshot)