Source code for mushroom_rl.algorithms.value.dqn.maxmin_dqn

import numpy as np

from mushroom_rl.algorithms.value.dqn import DQN
from mushroom_rl.approximators.regressor import Regressor


class MaxminDQN(DQN):
    """
    MaxminDQN algorithm.
    "Maxmin Q-learning: Controlling the Estimation Bias of Q-learning".
    Lan Q. et al., 2020.

    """
    def __init__(self, mdp_info, policy, approximator, n_approximators,
                 **params):
        """
        Constructor.

        Args:
            n_approximators (int): the number of approximators in the
                ensemble.

        """
        assert n_approximators > 1

        self._n_approximators = n_approximators

        super().__init__(mdp_info, policy, approximator, **params)

    def fit(self, dataset):
        # At each fit step, a single randomly chosen member of the ensemble
        # is updated; its index is forwarded to the approximator through the
        # fit parameters.
        self._fit_params['idx'] = np.random.randint(self._n_approximators)

        super().fit(dataset)

    def _initialize_regressors(self, approximator, apprx_params_train,
                               apprx_params_target):
        # Both the online and the target approximators are ensembles of
        # n_approximators models whose prediction is the minimum over the
        # ensemble, as prescribed by Maxmin Q-learning.
        self.approximator = Regressor(approximator,
                                      n_models=self._n_approximators,
                                      prediction='min',
                                      **apprx_params_train)
        self.target_approximator = Regressor(approximator,
                                             n_models=self._n_approximators,
                                             prediction='min',
                                             **apprx_params_target)

        self._update_target()

    def _update_target(self):
        # Copy the weights of every online network into the corresponding
        # target network.
        for i in range(len(self.target_approximator)):
            self.target_approximator[i].set_weights(
                self.approximator[i].get_weights())
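
Below is a minimal usage sketch, not part of the module above: it illustrates how MaxminDQN could be constructed and trained on a small control task, assuming MushroomRL's Gym wrapper, EpsGreedy policy, TorchApproximator and Core, plus a small hypothetical Q-network (`Network`). The remaining keyword arguments (approximator_params, batch_size, replay sizes, target_update_frequency) come from the parent DQN class, and exact import paths may differ across MushroomRL versions.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.core import Core
from mushroom_rl.environments import Gym
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import Parameter


class Network(nn.Module):
    # Hypothetical Q-network: maps a state to one Q-value per action.
    def __init__(self, input_shape, output_shape, **kwargs):
        super().__init__()
        self._h = nn.Linear(input_shape[-1], 128)
        self._q = nn.Linear(128, output_shape[0])

    def forward(self, state, action=None):
        q = self._q(torch.relu(self._h(state.float())))
        if action is None:
            return q
        return torch.squeeze(q.gather(1, action.long()))


mdp = Gym('CartPole-v1', horizon=500, gamma=.99)
policy = EpsGreedy(epsilon=Parameter(value=.1))

approximator_params = dict(
    network=Network,
    input_shape=mdp.info.observation_space.shape,
    output_shape=(mdp.info.action_space.n,),
    n_actions=mdp.info.action_space.n,
    optimizer={'class': optim.Adam, 'params': {'lr': 1e-3}},
    loss=F.smooth_l1_loss
)

# n_approximators is the Maxmin-specific argument: the size of the Q ensemble,
# which must be greater than 1; the other arguments follow the parent DQN.
agent = MaxminDQN(mdp.info, policy, TorchApproximator,
                  n_approximators=4,
                  approximator_params=approximator_params,
                  batch_size=32,
                  initial_replay_size=500,
                  max_replay_size=5000,
                  target_update_frequency=100)

core = Core(agent, mdp)
core.learn(n_steps=10000, n_steps_per_fit=1)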