Source code for mushroom_rl.utils.callbacks.collect_q

import numpy as np
from copy import deepcopy

from mushroom_rl.utils.callbacks.callback import Callback
from mushroom_rl.utils.table import EnsembleTable


class CollectQ(Callback):
    """
    This callback can be used to collect the action values in all states at
    the current time step.

    """
    def __init__(self, approximator):
        """
        Constructor.

        Args:
            approximator ([Table, EnsembleTable]): the approximator to use to
                predict the action values.

        """
        self._approximator = approximator

        super().__init__()

    def __call__(self, dataset):
        """
        Append a snapshot of the approximator's current table to the
        collected data.

        For an ``EnsembleTable`` the snapshot is the element-wise mean
        (along axis 0) of the tables of all models in the ensemble;
        otherwise it is a deep copy of the approximator's own table.

        Args:
            dataset: not read by this method.

        """
        if isinstance(self._approximator, EnsembleTable):
            tables = [m.table for m in self._approximator.model]
            # np.mean returns a newly allocated result, so the snapshot is
            # already independent of the live tables; no extra copy needed.
            snapshot = np.mean(tables, 0)
        else:
            # Copy so that later updates to the table do not alter the
            # stored snapshot.
            snapshot = deepcopy(self._approximator.table)

        self._data_list.append(snapshot)