Source code for mushroom_rl.algorithms.value.dqn.dqn
import numpy as np
from mushroom_rl.algorithms.value.dqn import AbstractDQN
[docs]class DQN(AbstractDQN):
"""
Deep Q-Network algorithm.
"Human-Level Control Through Deep Reinforcement Learning".
Mnih V. et al.. 2015.
"""
[docs] def _next_q(self, next_state, absorbing):
q = self.target_approximator.predict(next_state, **self._predict_params)
if np.any(absorbing):
q *= 1 - absorbing.reshape(-1, 1)
return np.max(q, axis=1)