The features in Mushroom are 1-D arrays computed by applying a specified function to a raw input, e.g. polynomial features of the state of an MDP. Mushroom supports three types of features:
- basis functions;
- tensor basis functions;
- tiles.
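As a rough illustration of a feature vector built from basis functions, the sketch below evaluates a list of callables on a raw state and stacks the results into a 1-D NumPy array. The helpers `make_polynomial_basis` and `compute_features` are hypothetical names chosen for this example, not part of Mushroom's API.

```python
import numpy as np


def make_polynomial_basis(powers):
    """Hypothetical helper: return a callable computing prod_i x[i] ** powers[i].

    Each basis function maps a raw input vector to one scalar, i.e. one
    element of the feature vector.
    """
    powers = np.asarray(powers)
    return lambda x: float(np.prod(np.asarray(x, dtype=float) ** powers))


def compute_features(basis_list, x):
    """Evaluate every basis function on x and stack them into a 1-D array."""
    return np.array([phi(x) for phi in basis_list])


# Polynomial features up to degree 2 of a 2-D state (s0, s1).
basis_list = [make_polynomial_basis(p)
              for p in [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (0, 2)]]

state = np.array([2.0, 3.0])
print(compute_features(basis_list, state))  # [1. 2. 3. 4. 6. 9.]
```

Here the six monomials of degree at most 2 in (s0, s1) form the feature vector; evaluated at the state (2, 3) they give 1, 2, 3, 4, 6 and 9.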
The tensor basis functions are a PyTorch implementation of the standard basis functions. They are less straightforward to use than the standard ones, but they are faster to compute because they can exploit parallel computing, e.g. GPU acceleration and multi-core systems.
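To see why evaluating features as array operations can pay off, the sketch below computes the same degree-2 polynomial features for a batch of 1000 states twice: once with nested Python loops and once with a single broadcasted NumPy expression. NumPy is used here only as a CPU stand-in for PyTorch tensors; the function names are hypothetical and this is not how Mushroom implements its tensor basis functions.

```python
import numpy as np

# Exponents of the degree-2 polynomial basis for a 2-D state:
# each row gives the powers of (s0, s1) for one feature.
POWERS = np.array([[0, 0], [1, 0], [0, 1], [2, 0], [1, 1], [0, 2]])


def features_loop(states):
    """One state and one feature at a time, as a list of Python callables would."""
    out = np.empty((len(states), len(POWERS)))
    for i, s in enumerate(states):
        for j, p in enumerate(POWERS):
            out[i, j] = np.prod(s ** p)
    return out


def features_batched(states):
    """All states and all features in one broadcasted array operation.

    states has shape (N, 2); the result has shape (N, n_features).
    """
    # (N, 1, 2) ** (1, F, 2) -> (N, F, 2), then multiply over the last axis.
    return np.prod(states[:, None, :] ** POWERS[None, :, :], axis=-1)


rng = np.random.default_rng(0)
states = rng.uniform(-1.0, 1.0, size=(1000, 2))
assert np.allclose(features_loop(states), features_batched(states))
```

The batched version hands the whole computation to compiled array code, which is the kind of work a tensor library can also dispatch to a GPU or spread across several cores.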
All types of features are exposed by a single factory method that builds the one requested by the user.
Features(basis_list=None, tilings=None, tensor_list=None, device=None)
Factory method to build the requested type of features. The types are mutually exclusive.
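A minimal sketch of a factory whose arguments are mutually exclusive follows. The classes `BasisFeatures`, `TilesFeatures`, `TensorFeatures` and the function `make_features` are hypothetical placeholders; in particular, raising `ValueError` on an ambiguous request is a choice made for this sketch, not documented Mushroom behavior.

```python
class BasisFeatures:
    """Placeholder for features computed from a list of basis functions."""

    def __init__(self, basis_list):
        self.basis_list = basis_list


class TilesFeatures:
    """Placeholder for tile-coding features."""

    def __init__(self, tilings):
        # Accept either a single tiling object or a list of tilings.
        self.tilings = tilings if isinstance(tilings, list) else [tilings]


class TensorFeatures:
    """Placeholder for features computed by a group of tensor modules."""

    def __init__(self, tensor_list, device=None):
        self.tensor_list = tensor_list
        self.device = device


def make_features(basis_list=None, tilings=None, tensor_list=None, device=None):
    """Build exactly one type of features; reject ambiguous or empty requests."""
    given = [arg is not None for arg in (basis_list, tilings, tensor_list)]
    if sum(given) != 1:
        raise ValueError("Specify exactly one of basis_list, tilings, tensor_list")
    if basis_list is not None:
        return BasisFeatures(basis_list)
    if tilings is not None:
        return TilesFeatures(tilings)
    return TensorFeatures(tensor_list, device)
```

For example, `make_features(basis_list=[...])` returns a `BasisFeatures` instance, while passing both `basis_list` and `tilings` raises an error instead of silently picking one.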
The difference between basis_list and tensor_list is that the former is a list of Python classes, each evaluating a single element of the feature vector, while the latter consists of a list of PyTorch modules that can be used to build a PyTorch network. Using tensor_list is a faster way to compute features than basis_list and is recommended when computing the requested features is slow (see the Gaussian radial basis function implementation as an example).
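The contrast between the two styles can be sketched with a Gaussian radial basis function. `GaussianRBF` mimics the per-element style of basis_list (one object per entry of the feature vector), while `GaussianRBFLayer` computes all entries in one vectorized call, as a tensor module would. Both are hypothetical NumPy classes, and the parametrization `exp(-sum(((x - mean) / scale) ** 2))` is an assumption; Mushroom's own Gaussian implementation may scale the distance differently.

```python
import numpy as np


class GaussianRBF:
    """Per-element style: one object computes one entry of the feature vector."""

    def __init__(self, mean, scale):
        self.mean = np.asarray(mean, dtype=float)
        self.scale = np.asarray(scale, dtype=float)

    def __call__(self, x):
        d = (np.asarray(x, dtype=float) - self.mean) / self.scale
        return np.exp(-np.sum(d ** 2))


class GaussianRBFLayer:
    """Vectorized style: one object computes all entries at once."""

    def __init__(self, means, scales):
        self.means = np.asarray(means, dtype=float)    # shape (F, D)
        self.scales = np.asarray(scales, dtype=float)  # shape (F, D)

    def __call__(self, x):
        d = (np.asarray(x, dtype=float)[None, :] - self.means) / self.scales
        return np.exp(-np.sum(d ** 2, axis=1))          # shape (F,)


# Five centres evenly spaced on [0, 1] for a 1-D input.
means = np.linspace(0.0, 1.0, 5)[:, None]
scales = np.full_like(means, 0.25)

basis_list = [GaussianRBF(m, s) for m, s in zip(means, scales)]
layer = GaussianRBFLayer(means, scales)

x = np.array([0.3])
per_element = np.array([phi(x) for phi in basis_list])
assert np.allclose(per_element, layer(x))
```

Both styles yield the same vector; the layer simply replaces a Python loop over F objects with one array expression, which is what makes the tensor approach attractive when features are expensive.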
Parameters:
- basis_list (list, None) – list of basis functions;
- tilings ([object, list], None) – single object or list of tilings;
- tensor_list (list, None) – list of dictionaries containing the instructions to build the requested tensors;
- device (int, None) – the device on which to run the group of tensors; only needed when using a list of tensors.
Returns: the class implementing the requested type of features.
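The tilings argument accepts either a single tiling or a list of them. To make the idea concrete, the sketch below implements basic tile coding: each tiling partitions the input space into a grid, an input activates exactly one tile per tiling, and the binary feature vector concatenates one block per tiling. `UniformTiling` and `tile_features` are hypothetical, not Mushroom's tiling classes.

```python
import numpy as np


class UniformTiling:
    """Hypothetical tiling: a uniform grid over a box, optionally shifted."""

    def __init__(self, low, high, n_tiles, offset=0.0):
        # Shifting both bounds moves the whole grid by `offset`.
        self.low = np.asarray(low, dtype=float) + offset
        self.high = np.asarray(high, dtype=float) + offset
        self.n_tiles = np.asarray(n_tiles, dtype=int)
        self.size = int(np.prod(self.n_tiles))

    def index(self, x):
        """Return the flat index of the tile containing x."""
        # Normalize each coordinate to [0, 1) relative to the grid bounds.
        u = (np.asarray(x, dtype=float) - self.low) / (self.high - self.low)
        # Map to integer grid coordinates; inputs outside the box are clipped
        # to the nearest boundary tile.
        coords = np.clip(np.floor(u * self.n_tiles).astype(int), 0, self.n_tiles - 1)
        return int(np.ravel_multi_index(tuple(coords), tuple(self.n_tiles)))


def tile_features(tilings, x):
    """Binary feature vector with exactly one active entry per tiling."""
    phi = np.zeros(sum(t.size for t in tilings))
    start = 0
    for t in tilings:
        phi[start + t.index(x)] = 1.0
        start += t.size
    return phi


# Two 4x4 tilings over [0, 1]^2; the second is shifted by half a tile width.
tilings = [UniformTiling([0, 0], [1, 1], [4, 4]),
           UniformTiling([0, 0], [1, 1], [4, 4], offset=-0.125)]
phi = tile_features(tilings, [0.3, 0.6])
print(phi.shape, phi.sum())  # (32,) 2.0
```

Offsetting the tilings relative to each other lets nearby inputs share some active tiles but not all, which is what gives tile coding its generalization.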
get_action_features(phi_state, action, n_actions)
Compute an array of size len(phi_state) * n_actions filled with zeros, except for the elements from len(phi_state) * action to len(phi_state) * (action + 1), which are filled with phi_state. This is used to compute state-action features.
Parameters:
- phi_state (np.ndarray) – the features of the state;
- action (np.ndarray) – the action whose features have to be computed;
- n_actions (int) – the number of actions.
Returns: the state-action features.
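The block layout described above can be sketched in a few lines of NumPy for a single state. This is an illustrative reimplementation under the assumption that action is a size-1 array such as `np.array([1])`; it is not the library's source and does not handle batches of states.

```python
import numpy as np


def get_action_features(phi_state, action, n_actions):
    """Sketch: copy phi_state into the block of a zero vector selected by action.

    The output has length len(phi_state) * n_actions; only the entries from
    len(phi_state) * a to len(phi_state) * (a + 1) are non-zero.
    """
    phi_state = np.asarray(phi_state, dtype=float)
    a = int(np.asarray(action).item())  # action as a size-1 array, e.g. np.array([1])
    n = phi_state.size
    phi = np.zeros(n * n_actions)
    phi[n * a:n * (a + 1)] = phi_state
    return phi


print(get_action_features(np.array([1.0, 2.0]), np.array([1]), 3))
# [0. 0. 1. 2. 0. 0.]
```

With this layout, a linear action-value function w · φ(s, a) effectively uses a separate block of weights for each action.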
The factory method returns a class that extends the abstract class