Source code for mushroom_rl.algorithms.value.dqn.dqn

import numpy as np

from mushroom_rl.algorithms.value.dqn import AbstractDQN


class DQN(AbstractDQN):
    """
    Deep Q-Network algorithm.
    "Human-Level Control Through Deep Reinforcement Learning".
    Mnih V. et al., 2015.

    """
    def _next_q(self, next_state, absorbing):
        """
        Compute the bootstrap value of the next states using the target
        network, i.e. the maximum target action-value of each next state.

        Args:
            next_state (np.ndarray): the batch of next states, forwarded
                unchanged to ``self.target_approximator.predict``;
            absorbing (np.ndarray): per-sample flags, truthy where the
                corresponding next state is absorbing. Assumed to hold one
                boolean-like entry per sample (0/1 or False/True).

        Returns:
            A 1-D array with one value per sample: the maximum over axis 1 of
            the target approximator's prediction, or 0 where ``absorbing`` is
            set.

        """
        q = self.target_approximator.predict(next_state,
                                             **self._predict_params)

        # Reduce to one value per sample first. ``np.max`` along an axis
        # allocates a new array, so the masking below never writes into the
        # buffer returned by the approximator.
        max_q = np.max(q, axis=1)

        # Assign 0 instead of multiplying by (1 - absorbing), so a
        # non-finite estimate for a terminal state cannot turn into NaN
        # (0 * inf == nan).
        absorbing_mask = np.asarray(absorbing, dtype=bool).reshape(-1)
        if absorbing_mask.any():
            max_q[absorbing_mask] = 0

        return max_q