Source code for mushroom_rl.algorithms.value.batch_td.fqi

import numpy as np
from tqdm import trange

from mushroom_rl.algorithms.value.batch_td import BatchTD
from mushroom_rl.utils.dataset import parse_dataset
from mushroom_rl.utils.parameters import to_parameter


class FQI(BatchTD):
    """
    Fitted Q-Iteration algorithm.

    "Tree-Based Batch Mode Reinforcement Learning", Ernst D. et al., 2005.

    """
    def __init__(self, mdp_info, policy, approximator, n_iterations,
                 approximator_params=None, fit_params=None, quiet=False):
        """
        Constructor.

        Args:
            n_iterations ([int, Parameter]): number of iterations to perform
                for training;
            quiet (bool, False): whether to hide the progress bar or not.

        """
        self._n_iterations = to_parameter(n_iterations)
        self._quiet = quiet
        self._target = None

        self._add_save_attr(
            _n_iterations='mushroom',
            _quiet='primitive',
            _target='pickle'
        )

        super().__init__(mdp_info, policy, approximator, approximator_params,
                         fit_params)
    def fit(self, dataset, **info):
        state, action, reward, next_state, absorbing, _ = parse_dataset(dataset)

        for _ in trange(self._n_iterations(), dynamic_ncols=True,
                        disable=self._quiet, leave=False):
            if self._target is None:
                # First iteration: Q_0 is identically zero, so the
                # regression target is just the immediate reward.
                self._target = reward
            else:
                q = self.approximator.predict(next_state)
                if np.any(absorbing):
                    # Zero the value of absorbing next states, so no
                    # future return is bootstrapped through them.
                    q *= 1 - absorbing.reshape(-1, 1)

                max_q = np.max(q, axis=1)
                self._target = reward + self.mdp_info.gamma * max_q

            self.approximator.fit(state, action, self._target,
                                  **self._fit_params)
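
The loop in fit implements the standard fitted Q-iteration backup over the fixed batch of transitions: at iteration k, the regression target for each transition (s_i, a_i, r_i, s'_i, d_i) is

    \hat{Q}_{k+1}(s_i, a_i) \leftarrow r_i + \gamma \, (1 - d_i) \max_{a'} \hat{Q}_k(s'_i, a')

where d_i is the absorbing flag. Since \hat{Q}_0 = 0, the first iteration's target reduces to r_i, which is why the code starts from self._target = reward.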
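
Below is a minimal usage sketch, not part of the library source, following the pattern of the MushroomRL car-on-hill tutorial: FQI regressing onto a scikit-learn extra-trees model, as in Ernst et al. (2005). The hyperparameter values and episode counts are illustrative assumptions.

    from sklearn.ensemble import ExtraTreesRegressor

    from mushroom_rl.algorithms.value import FQI
    from mushroom_rl.core import Core
    from mushroom_rl.environments import CarOnHill
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.utils.parameters import Parameter

    # Environment and fully exploratory epsilon-greedy policy for
    # collecting the batch of transitions.
    mdp = CarOnHill()
    pi = EpsGreedy(epsilon=Parameter(value=1.))

    # Extra-trees Q-function approximator; values are illustrative.
    approximator_params = dict(
        input_shape=mdp.info.observation_space.shape,
        n_actions=mdp.info.action_space.n,
        n_estimators=50,
        min_samples_split=5,
        min_samples_leaf=2)

    agent = FQI(mdp.info, pi, ExtraTreesRegressor, n_iterations=20,
                approximator_params=approximator_params)

    # Collect one batch of episodes, then run all FQI iterations on it.
    core = Core(agent, mdp)
    core.learn(n_episodes=1000, n_episodes_per_fit=1000)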