Source code for mushroom_rl.algorithms.value.td.weighted_q_learning

import numpy as np

from mushroom_rl.algorithms.value.td import TD
from mushroom_rl.utils.table import Table


class WeightedQLearning(TD):
    """
    Weighted Q-Learning algorithm.
    "Estimating the Maximum Expected Value through Gaussian Approximation".
    D'Eramo C. et al., 2016.

    """

    def __init__(self, mdp_info, policy, learning_rate, sampling=True,
                 precision=1000):
        """
        Constructor.

        Args:
            sampling (bool, True): use the approximated version to speed up
                the computation;
            precision (int, 1000): number of samples to use in the
                approximated version.

        """
        self.Q = Table(mdp_info.size)
        self._sampling = sampling
        self._precision = precision

        self._add_save_attr(
            Q='pickle',
            _sampling='numpy',
            _precision='numpy',
            _n_updates='pickle',
            _sigma='pickle',
            _Q='pickle',
            _Q2='pickle',
            _weights_var='pickle',
            _w='numpy'
        )

        super().__init__(mdp_info, policy, self.Q, learning_rate)

        self._n_updates = Table(mdp_info.size)
        self._sigma = Table(mdp_info.size, initial_value=1e10)
        self._Q = Table(mdp_info.size)
        self._Q2 = Table(mdp_info.size)
        self._weights_var = Table(mdp_info.size)

    def _update(self, state, action, reward, next_state, absorbing):
        q_current = self.Q[state, action]
        q_next = self._next_q(next_state) if not absorbing else 0.

        target = reward + self.mdp_info.gamma * q_next

        alpha = self.alpha(state, action)

        # Standard TD update of the action-value table.
        self.Q[state, action] = q_current + alpha * (target - q_current)

        # Running estimates of the mean and second moment of the target,
        # used to compute the sample variance of the Q estimate.
        self._n_updates[state, action] += 1

        self._Q[state, action] += (
            target - self._Q[state, action]) / self._n_updates[state, action]
        self._Q2[state, action] += (
            target ** 2. - self._Q2[state, action]) / self._n_updates[
                state, action]
        self._weights_var[state, action] = (
            1 - alpha) ** 2. * self._weights_var[state, action] + alpha ** 2.

        if self._n_updates[state, action] > 1:
            var = self._n_updates[state, action] * (
                self._Q2[state, action] - self._Q[state, action] ** 2.) / (
                    self._n_updates[state, action] - 1.)

            # Standard deviation of the Gaussian approximation of Q(s, a).
            var_estimator = var * self._weights_var[state, action]
            var_estimator = np.maximum(var_estimator, 1e-10)
            self._sigma[state, action] = np.sqrt(var_estimator)

    def _next_q(self, next_state):
        """
        Args:
            next_state (np.ndarray): the state where next action has to be
                evaluated.

        Returns:
            The weighted estimator value in ``next_state``.

        """
        means = self.Q[next_state, :]
        sigmas = np.zeros(self.Q.shape[-1])

        for a in range(sigmas.size):
            sigmas[a] = self._sigma[next_state, np.array([a])]

        if self._sampling:
            samples = np.random.normal(np.repeat([means], self._precision, 0),
                                       np.repeat([sigmas], self._precision,
                                                 0))
            max_idx = np.argmax(samples, axis=1)
            max_idx, max_count = np.unique(max_idx, return_counts=True)
            count = np.zeros(means.size)
            count[max_idx] = max_count

            self._w = count / self._precision
        else:
            raise NotImplementedError

        return np.dot(self._w, means)
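
Below is a minimal usage sketch (not part of the listing above) showing how this agent is typically wired into a mushroom_rl experiment. It assumes the standard GridWorld, EpsGreedy, Parameter and Core interfaces of mushroom_rl; the environment and parameter values are illustrative only.

from mushroom_rl.algorithms.value import WeightedQLearning
from mushroom_rl.core import Core
from mushroom_rl.environments import GridWorld
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import Parameter

# Illustrative setup: a small grid world and an epsilon-greedy policy.
env = GridWorld(height=3, width=3, goal=(2, 2), start=(0, 0))
pi = EpsGreedy(epsilon=Parameter(value=.1))

# sampling=True uses the Monte Carlo approximation of the weights w,
# i.e. the probability that each action attains the maximum Q-value.
agent = WeightedQLearning(env.info, pi, learning_rate=Parameter(value=.1),
                          sampling=True, precision=1000)

# Learn online, fitting the agent after every environment step.
core = Core(agent, env)
core.learn(n_steps=10000, n_steps_per_fit=1)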