Source code for mushroom_rl.algorithms.value.batch_td.boosted_fqi

import numpy as np
from tqdm import trange

from .fqi import FQI


class BoostedFQI(FQI):
    """
    Boosted Fitted Q-Iteration algorithm.
    "Boosted Fitted Q-Iteration". Tosatto S. et al., 2017.

    """
    def __init__(self, mdp_info, policy, approximator, n_iterations,
                 approximator_params=None, fit_params=None, quiet=False):
        self._prediction = 0.
        self._next_q = 0.
        self._idx = 0

        assert approximator_params['n_models'] == n_iterations

        self._add_save_attr(
            _n_iterations='primitive',
            _quiet='primitive',
            _prediction='primitive',
            _next_q='numpy',
            _idx='primitive',
            _target='pickle'
        )

        super().__init__(mdp_info, policy, approximator, n_iterations,
                         approximator_params, fit_params, quiet)
    def fit(self, dataset):
        state, action, reward, next_state, absorbing, _ = dataset.parse(
            to='numpy')

        for _ in trange(self._n_iterations(), dynamic_ncols=True,
                        disable=self._quiet, leave=False):
            if self._target is None:
                # First iteration: the target is the immediate reward.
                self._target = reward
            else:
                # Add the prediction of the model fit at the previous
                # iteration, so that self._next_q accumulates the sum of the
                # residual models, i.e. the current Q estimate at next_state.
                self._next_q += self.approximator.predict(next_state,
                                                          idx=self._idx - 1)
                if np.any(absorbing):
                    self._next_q *= 1 - absorbing.reshape(-1, 1)

                max_q = np.max(self._next_q, axis=1)
                self._target = reward + self.mdp_info.gamma * max_q

            # Boosting step: subtract the running sum of previous targets to
            # obtain the residual to fit, then update the running sum.
            self._target -= self._prediction
            self._prediction += self._target

            # Fit only the model of the current iteration on the residual.
            self.approximator.fit(state, action, self._target, idx=self._idx,
                                  **self._fit_params)

            self._idx += 1
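
A minimal usage sketch, not part of the module above: it mirrors the library's standard FQI example on CarOnHill, swapping in BoostedFQI and setting n_models equal to n_iterations as required by the assertion in __init__. The module paths, the EpsGreedy/Core setup, and the prediction='sum' ensemble option are assumed from a typical MushroomRL installation and may differ between library versions.

import numpy as np
from sklearn.ensemble import ExtraTreesRegressor

from mushroom_rl.algorithms.value.batch_td.boosted_fqi import BoostedFQI
from mushroom_rl.core import Core
from mushroom_rl.environments import CarOnHill
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import Parameter

n_iterations = 20

mdp = CarOnHill()
pi = EpsGreedy(epsilon=Parameter(value=1.))  # fully exploratory data collection

# One ExtraTrees model per boosting iteration, as required by the
# `n_models == n_iterations` assertion in BoostedFQI.__init__.
# prediction='sum' is assumed so that, when the policy queries Q, the
# ensemble sums the residual models instead of averaging them.
approximator_params = dict(input_shape=mdp.info.observation_space.shape,
                           n_actions=mdp.info.action_space.n,
                           n_models=n_iterations,
                           prediction='sum',
                           n_estimators=50,
                           min_samples_split=5,
                           min_samples_leaf=2)

agent = BoostedFQI(mdp.info, pi, ExtraTreesRegressor,
                   n_iterations=n_iterations,
                   approximator_params=approximator_params)

core = Core(agent, mdp)
core.learn(n_episodes=1000, n_episodes_per_fit=1000)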