Source code for mushroom_rl.utils.table

import numpy as np

from mushroom_rl.core import Serializable
from mushroom_rl.approximators import Ensemble


[docs]class Table(Serializable): """ Table regressor. Used for discrete state and action spaces. """
[docs] def __init__(self, shape, initial_value=0., dtype=None): """ Constructor. Args: shape (tuple): the shape of the tabular regressor. initial_value (float, 0.): the initial value for each entry of the tabular regressor. dtype ([int, float], None): the dtype of the table array. """ self.table = np.ones(shape, dtype=dtype) * initial_value self._add_save_attr(table='numpy')
def __getitem__(self, args): if self.table.size == 1: return self.table[0] else: idx = tuple([ a[0] if isinstance(a, np.ndarray) else a for a in args]) return self.table[idx] def __setitem__(self, args, value): if self.table.size == 1: self.table[0] = value else: idx = tuple([ a[0] if isinstance(a, np.ndarray) else a for a in args]) self.table[idx] = value
[docs] def fit(self, x, y): """ Args: x (int): index of the table to be filled; y (float): value to fill in the table. """ self[x] = y
[docs] def predict(self, *z): """ Predict the output of the table given an input. Args: *z (list): list of input of the model. If the table is a Q-table, this list may contain states or states and actions depending on whether the call requires to predict all q-values or only one q-value corresponding to the provided action; Returns: The table prediction. """ if z[0].ndim == 1: z = [np.expand_dims(z_i, axis=0) for z_i in z] state = z[0] values = list() if len(z) == 2: action = z[1] for i in range(len(state)): val = self[state[i], action[i]] values.append(val) else: for i in range(len(state)): val = self[state[i], :] values.append(val) if len(values) == 1: return values[0] else: return np.array(values)
@property def n_actions(self): """ Returns: The number of actions considered by the table. """ return self.table.shape[-1] @property def shape(self): """ Returns: The shape of the table. """ return self.table.shape
[docs]class EnsembleTable(Ensemble): """ This class implements functions to manage table ensembles. """
[docs] def __init__(self, n_models, shape, **params): """ Constructor. Args: n_models (int): number of models in the ensemble; shape (np.ndarray): shape of each table in the ensemble. **params: parameters dictionary to create each regressor. """ params['shape'] = shape super(EnsembleTable, self).__init__(Table, n_models, **params)
@property def n_actions(self): return self._model[0].shape[-1]