import numpy as np
from mushroom_rl.algorithms.value.dqn import AbstractDQN
from mushroom_rl.approximators.regressor import Regressor
[docs]class AveragedDQN(AbstractDQN):
"""
Averaged-DQN algorithm.
"Averaged-DQN: Variance Reduction and Stabilization for Deep Reinforcement
Learning". Anschel O. et al.. 2017.
"""
[docs] def __init__(self, mdp_info, policy, approximator, n_approximators,
**params):
"""
Constructor.
Args:
n_approximators (int): the number of target approximators to store.
"""
assert n_approximators > 1
self._n_approximators = n_approximators
super().__init__(mdp_info, policy, approximator, **params)
self._n_fitted_target_models = 1
self._add_save_attr(_n_fitted_target_models='primitive')
def _initialize_regressors(self, approximator, apprx_params_train,
apprx_params_target):
self.approximator = Regressor(approximator, **apprx_params_train)
self.target_approximator = Regressor(approximator,
n_models=self._n_approximators,
**apprx_params_target)
for i in range(len(self.target_approximator)):
self.target_approximator[i].set_weights(
self.approximator.get_weights()
)
[docs] def _update_target(self):
idx = self._n_updates // self._target_update_frequency\
% self._n_approximators
self.target_approximator[idx].set_weights(
self.approximator.get_weights())
if self._n_fitted_target_models < self._n_approximators:
self._n_fitted_target_models += 1
[docs] def _next_q(self, next_state, absorbing):
q = list()
for idx in range(self._n_fitted_target_models):
q_target_idx = self.target_approximator.predict(next_state, idx=idx, **self._predict_params)
q.append(q_target_idx)
q = np.mean(q, axis=0)
if np.any(absorbing):
q *= 1 - absorbing.reshape(-1, 1)
return np.max(q, axis=1)