The features in Mushroom are 1-D arrays computed by applying a specified function to a raw input, e.g. polynomial features of the state of an MDP. Mushroom supports three types of features:

  • basis functions;
  • tensor basis functions;
  • tiles.

The GPU-accelerated basis functions are a PyTorch implementation of the standard basis functions. They are less straightforward to use than the standard ones, but they are faster to compute because they can exploit parallel hardware such as GPUs and multi-core systems.

All the types of features are exposed through a single factory method, Features, which builds the type requested by the user.

mushroom.features.features.Features(basis_list=None, tilings=None, tensor_list=None, device=None)

Factory method to build the requested type of features. The types are mutually exclusive.

The difference between basis_list and tensor_list is that the former is a list of Python classes, each evaluating a single element of the feature vector, while the latter consists of a list of PyTorch modules that can be used to build a PyTorch network. Using tensor_list is faster than using basis_list and is suggested when computing the requested features is slow (see the Gaussian radial basis function implementation for an example).

Parameters:

  • basis_list (list, None) – list of basis functions;
  • tilings ([object, list], None) – single object or list of tilings;
  • tensor_list (list, None) – list of dictionaries containing the instructions to build the requested tensors;
  • device (int, None) – where to run the group of tensors. Only needed when using a list of tensors.

Returns: the class implementing the requested type of features.
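To make the contrast between per-element basis functions and a batched computation concrete, here is a minimal sketch. It uses NumPy as a stand-in for PyTorch so that it runs without extra dependencies. The class GaussianBasisSketch, both helper functions, and the form exp(-||x - c||^2 / scale) are illustrative assumptions, not Mushroom's actual implementation.

```python
import numpy as np


class GaussianBasisSketch:
    """Hypothetical per-element basis: evaluates ONE entry of the feature vector."""

    def __init__(self, center, scale):
        self.center = np.asarray(center, dtype=float)
        self.scale = float(scale)

    def __call__(self, x):
        d = np.asarray(x, dtype=float) - self.center
        # Assumed Gaussian form, chosen for illustration only.
        return np.exp(-np.dot(d, d) / self.scale)


def evaluate_basis_list(basis_list, x):
    # One Python call per feature: simple, but slow when there are many features.
    return np.array([b(x) for b in basis_list])


def evaluate_batched(centers, scale, x):
    # All features in a single array operation, which is the pattern a
    # tensor backend can run in parallel.
    d = np.asarray(x, dtype=float) - centers  # shape (n_features, dim)
    return np.exp(-np.sum(d * d, axis=1) / scale)


centers = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
scale = 0.5
basis_list = [GaussianBasisSketch(c, scale) for c in centers]

x = np.array([0.2, 0.7])
phi_loop = evaluate_basis_list(basis_list, x)
phi_batch = evaluate_batched(centers, scale, x)
```

Both paths yield the same 1-D feature array. Only the way the work is organised differs, which is why the batched form is the better fit when feature evaluation dominates the running time.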

mushroom.features.features.get_action_features(phi_state, action, n_actions)

Compute an array of size len(phi_state) * n_actions filled with zeros, except for the elements from index len(phi_state) * action up to, but not including, len(phi_state) * (action + 1), which are filled with phi_state. This is used to compute state-action features.

Parameters:

  • phi_state (np.ndarray) – the features of the state;
  • action (np.ndarray) – the action whose features have to be computed;
  • n_actions (int) – the number of actions.

Returns: the state-action features.
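The block layout described above can be reproduced in a few lines of NumPy. This is a sketch of the documented behaviour, not the library's source. The name action_features_sketch is hypothetical, and it assumes that action is an array holding a single integer index.

```python
import numpy as np


def action_features_sketch(phi_state, action, n_actions):
    """Copy phi_state into the block of a zero array that belongs to `action`."""
    phi_state = np.asarray(phi_state, dtype=float)
    n = phi_state.size
    # Assumption: `action` is an array containing one integer action index.
    a = int(np.asarray(action).ravel()[0])
    phi = np.zeros(n * n_actions)
    # Block for action a spans indices n * a (inclusive) to n * (a + 1) (exclusive).
    phi[n * a:n * (a + 1)] = phi_state
    return phi


phi = action_features_sketch(np.array([0.5, -1.0, 2.0]), np.array([1]), n_actions=3)
```

With three state features and three actions, the result has nine entries, and the state features for action 1 occupy indices 3, 4 and 5. All other entries stay at zero.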

The factory method returns a class that extends the abstract class FeatureImplementation.
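The factory pattern described here can be sketched as follows. Every name in this block (FeatureInterfaceSketch, BasisFeaturesSketch, features_sketch) is hypothetical, and the interface shown, a single __call__ method, is an assumption. The actual methods of FeatureImplementation are not listed in this section.

```python
from abc import ABC, abstractmethod

import numpy as np


class FeatureInterfaceSketch(ABC):
    """Hypothetical stand-in for the interface shared by all feature types."""

    @abstractmethod
    def __call__(self, x):
        """Return the 1-D feature array for the raw input x."""


class BasisFeaturesSketch(FeatureInterfaceSketch):
    """Features built from a list of callables, one per feature element."""

    def __init__(self, basis_list):
        self._basis_list = basis_list

    def __call__(self, x):
        return np.array([b(x) for b in self._basis_list])


def features_sketch(basis_list=None, tilings=None, tensor_list=None):
    # The feature types are mutually exclusive: exactly one argument may be set.
    given = [arg is not None for arg in (basis_list, tilings, tensor_list)]
    if sum(given) != 1:
        raise ValueError('specify exactly one of basis_list, tilings, tensor_list')
    if basis_list is not None:
        return BasisFeaturesSketch(basis_list)
    # Tilings and tensor features are left out of this sketch.
    raise NotImplementedError('only basis_list is covered by this sketch')


# Polynomial features 1, s, s^2 of a one-dimensional state.
phi = features_sketch(basis_list=[lambda s: 1.0, lambda s: s[0], lambda s: s[0] ** 2])
values = phi(np.array([3.0]))
```

Callers depend only on the shared interface, so the code that consumes the features does not need to know which of the three types was requested.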