Approximators

MushroomRL exposes the high-level class Regressor, which can manage any type of function approximator. This class is a wrapper around the underlying model, e.g. a scikit-learn regressor or a PyTorch neural network.

Regressor

class Regressor(approximator, input_shape, output_shape=None, n_actions=None, n_models=None, **params)

Bases: Serializable

This class manages a function approximator. It selects the appropriate kind of regressor according to the parameters provided by the user, so users only need this class, whatever the task. The implementation is inferred from the value of the n_actions parameter. If n_actions is provided, the user wants to approximate a Q-function: if n_actions equals the size of the output (i.e. output_shape is (n_actions,)), a QRegressor is created; otherwise (in which case output_shape should be (1,)) an ActionRegressor is created. If n_actions is not provided, a GenericRegressor is created. Any of these implementations can be turned into an ensemble by passing an n_models value greater than 1.
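The selection rule described above can be condensed into a few lines of plain Python. The function select_regressor_kind below is hypothetical and is not part of MushroomRL: it does not build any model, it only returns the name of the implementation that the rule would pick, which makes the three cases and the ensemble option easy to compare.

```python
from typing import Optional, Tuple


def select_regressor_kind(output_shape: Tuple[int, ...],
                          n_actions: Optional[int] = None,
                          n_models: Optional[int] = None) -> str:
    """Hypothetical helper mirroring the rule used by Regressor.

    Returns the name of the implementation that would be selected,
    prefixed with 'Ensemble of' when more than one model is requested.
    """
    if n_actions is not None:
        # Q-function approximation.
        if output_shape == (n_actions,):
            # One output per action.
            kind = "QRegressor"
        else:
            # Scalar output: the action selects which model is used.
            assert output_shape == (1,), "ActionRegressor expects output_shape == (1,)"
            kind = "ActionRegressor"
    else:
        kind = "GenericRegressor"

    if n_models is not None and n_models > 1:
        kind = f"Ensemble of {n_models} x {kind}"
    return kind


# A Q-function over 3 actions with one output per action.
print(select_regressor_kind((3,), n_actions=3))   # QRegressor
# A Q-function over 3 actions with a scalar output.
print(select_regressor_kind((1,), n_actions=3))   # ActionRegressor
# A plain regression problem, solved with an ensemble of 5 models.
print(select_regressor_kind((2,), n_models=5))    # Ensemble of 5 x GenericRegressor
```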

__init__(approximator, input_shape, output_shape=None, n_actions=None, n_models=None, **params)

Constructor.

Parameters:
  • approximator (class) – the approximator class to use to create the model;

  • input_shape (tuple) – the shape of the input of the model;

  • output_shape (tuple, None) – the shape of the output of the model;

  • n_actions (int, None) – number of actions considered to create a QRegressor or an ActionRegressor;

  • n_models (int, None) – number of models to create; if None, a single model is created, while a value greater than 1 creates an ensemble;

  • **params – other parameters to create each model.

__call__(*z, **predict_params)

Call self as a function; equivalent to predict(*z, **predict_params).

fit(*z, **fit_params)

Fit the model.

Parameters:
  • *z – list of inputs of the model;

  • **fit_params – parameters to use to fit the model.

predict(*z, **predict_params)

Predict the output of the model given an input.

Parameters:
  • *z – list of inputs of the model;

  • **predict_params – parameters to use to predict with the model.

Returns:

The model prediction.

property model

Returns: The model object.

reset()

Reset the model parameters.

property input_shape

Returns: The shape of the input of the model.

property output_shape

Returns: The shape of the output of the model.

property weights_size

Returns: The number of weights of the model.

get_weights()
Returns:

The weights of the model.

set_weights(w)
Parameters:

w (list) – list of weights to be set in the model.

diff(*z)
Parameters:

*z – the input of the model.

Returns:

The derivative of the model output, as computed by the diff method of the underlying approximator.

_add_save_attr(**attr_dict)

Add attributes that should be saved for the object. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All the other methods use the library of the same name. If a “!” character is appended to the method name, the field is saved only if full_save is set to True.

Parameters:

**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.
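The interaction between the “!” suffix, the none method and the full_save flag can be summarized with a small sketch. The function attributes_to_save is hypothetical: it is not MushroomRL's serialization code and it ignores how each attribute is actually written; it only lists which attributes the rules above would store. The attribute names in the example are made up.

```python
from typing import Dict, List


def attributes_to_save(attr_dict: Dict[str, str], full_save: bool) -> List[str]:
    """Hypothetical helper: names of the attributes that would be stored.

    A method name ending with '!' marks an attribute that is saved only
    when full_save is True; the 'none' method never stores the attribute.
    """
    selected = []
    for name, method in attr_dict.items():
        is_full_only = method.endswith("!")
        base_method = method.rstrip("!")
        if base_method == "none":
            continue  # never stored; reset to None after loading
        if is_full_only and not full_save:
            continue  # skipped unless a full save is requested
        selected.append(name)
    return selected


attrs = {
    "_w": "numpy",                  # always saved with numpy
    "_alpha": "primitive",          # plain Python value
    "_replay_memory": "mushroom!",  # saved only when full_save=True
    "_cache": "none",               # never saved
}
print(attributes_to_save(attrs, full_save=False))  # ['_w', '_alpha']
print(attributes_to_save(attrs, full_save=True))   # ['_w', '_alpha', '_replay_memory']
```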

_post_load()

This method can be overridden to implement logic that is executed after the object is loaded.

copy()
Returns:

A deep copy of the object.

classmethod load(path)

Load and deserialize the object from the given location on disk.

Parameters:

path (Path, str) – Relative or absolute path to the object's save location.

Returns:

The loaded object.

save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;

  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.

save_zip(zip_file, full_save, folder='')

Serialize and save the object into the given zip file.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;

  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;

  • folder (string, '') – subfolder to be used by the save method.

set_logger(logger, loss_filename=None)

Setter that can be used to pass a logger to the regressor.

Parameters:
  • logger (Logger) – the logger to be used by the regressor;

  • loss_filename (str, None) – optional string to specify the loss filename.

Approximator

Linear

class LinearApproximator(weights=None, input_shape=None, output_shape=(1,), **kwargs)

Bases: Serializable

This class implements a linear approximator, whose output is a linear combination of the input features.

__init__(weights=None, input_shape=None, output_shape=(1,), **kwargs)

Constructor.

Parameters:
  • weights (np.ndarray) – array of weights to initialize the weights of the approximator;

  • input_shape (tuple, None) – the shape of the input of the model;

  • output_shape (tuple, (1,)) – the shape of the output of the model;

  • **kwargs – other params of the approximator.

fit(x, y, **fit_params)

Fit the model.

Parameters:
  • x (np.ndarray) – input;

  • y (np.ndarray) – target;

  • **fit_params – other parameters used by the fit method of the regressor.

predict(x, **predict_params)

Predict.

Parameters:
  • x (np.ndarray) – input;

  • **predict_params – other parameters used by the predict method of the regressor.

Returns:

The predictions of the model.
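To make the fit/predict contract concrete, here is a NumPy sketch of a linear approximator. The class TinyLinearApproximator is hypothetical and only mimics the interface documented above; storing the weights as an (n_outputs, n_features) matrix and fitting them by ordinary least squares are assumptions of this sketch, not a description of LinearApproximator's internals.

```python
import numpy as np


class TinyLinearApproximator:
    """Hypothetical linear model y = W x, used only to illustrate the interface.

    Assumptions (not taken from MushroomRL): weights are stored as an
    (n_outputs, n_features) matrix and fit() solves a least-squares problem.
    """

    def __init__(self, input_shape, output_shape=(1,)):
        self._w = np.zeros((output_shape[0], input_shape[0]))

    def fit(self, x, y):
        # Least-squares solution of x @ W.T ~= y, one column per output.
        y = np.reshape(y, (x.shape[0], -1))
        w_t, *_ = np.linalg.lstsq(x, y, rcond=None)
        self._w = w_t.T

    def predict(self, x):
        prediction = x @ self._w.T
        # Return a flat vector for single-output models.
        return prediction.ravel() if self._w.shape[0] == 1 else prediction

    @property
    def weights_size(self):
        return self._w.size

    def get_weights(self):
        return self._w.flatten()

    def set_weights(self, w):
        self._w = np.reshape(w, self._w.shape)


rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = x @ true_w

approx = TinyLinearApproximator(input_shape=(3,))
approx.fit(x, y)
print(np.round(approx.get_weights(), 3))  # approximately [1.0, -2.0, 0.5]
```

Because the targets are noise-free and the model is linear in its weights, the least-squares fit recovers true_w up to floating-point error.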

property weights_size

Returns: The size of the array of weights.

get_weights()

Getter.

Returns:

The set of weights of the approximator.

set_weights(w)

Setter.

Parameters:

w (np.ndarray) – the set of weights to set.

_add_save_attr(**attr_dict)

Add attributes that should be saved for the object. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All the other methods use the library of the same name. If a “!” character is appended to the method name, the field is saved only if full_save is set to True.

Parameters:

**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.

_post_load()

This method can be overridden to implement logic that is executed after the object is loaded.

copy()
Returns:

A deep copy of the object.

diff(state, action=None)

Compute the derivative of the output w.r.t. the weights of the approximator, evaluated at state (and action, if provided).

Parameters:
  • state (np.ndarray) – the state;

  • action (np.ndarray, None) – the action.

Returns:

The derivative of the output w.r.t. the weights, evaluated at state (and action, if provided).
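For a linear model y = w · x both derivatives have a closed form: the derivative with respect to the weights is the input x, and the derivative with respect to the input is w. The sketch below checks both facts numerically with central finite differences. It illustrates the math only and does not call the library; the helpers linear_output and finite_difference are hypothetical.

```python
import numpy as np


def linear_output(w, x):
    """Scalar output of a linear model y = w . x."""
    return float(np.dot(w, x))


def finite_difference(f, z, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at z."""
    grad = np.zeros_like(z)
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = eps
        grad[i] = (f(z + step) - f(z - step)) / (2 * eps)
    return grad


w = np.array([0.5, -1.0, 2.0])
x = np.array([1.0, 3.0, -2.0])

grad_w = finite_difference(lambda v: linear_output(v, x), w)  # dy/dw
grad_x = finite_difference(lambda v: linear_output(w, v), x)  # dy/dx

print(np.allclose(grad_w, x))  # True: the gradient w.r.t. the weights is the input
print(np.allclose(grad_x, w))  # True: the gradient w.r.t. the input is the weights
```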

classmethod load(path)

Load and deserialize the object from the given location on disk.

Parameters:

path (Path, str) – Relative or absolute path to the object's save location.

Returns:

The loaded object.

save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;

  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.

save_zip(zip_file, full_save, folder='')

Serialize and save the object into the given zip file.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;

  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;

  • folder (string, '') – subfolder to be used by the save method.

CMAC

class CMAC(tilings, weights=None, output_shape=(1,), **kwargs)

Bases: LinearApproximator

This class implements a Cerebellar Model Arithmetic Computer (CMAC), i.e. a linear approximator over the binary features produced by the given tilings.

__init__(tilings, weights=None, output_shape=(1,), **kwargs)

Constructor.

Parameters:
  • tilings (list) – list of tilings to discretize the input space;

  • weights (np.ndarray) – array of weights to initialize the weights of the approximator;

  • input_shape (tuple, None) – the shape of the input of the model;

  • output_shape (tuple, (1,)) – the shape of the output of the model;

  • **kwargs – other params of the approximator.
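To make the role of the tilings concrete, here is a small NumPy sketch of CMAC-style tile coding. The class UniformTiling and the function cmac_features are hypothetical and do not reproduce MushroomRL's tiling utilities; they assume axis-aligned uniform grids, each shifted by a fixed offset. Each tiling activates exactly one tile, and the output is the sum of the weights of the active tiles, i.e. a linear function of binary features.

```python
import numpy as np


class UniformTiling:
    """Hypothetical axis-aligned grid tiling (not MushroomRL's tiling class).

    Splits the box [low, high] into n_tiles cells per dimension, shifted by
    offset, and maps a point to the flat index of the cell containing it.
    """

    def __init__(self, low, high, n_tiles, offset):
        low = np.asarray(low, dtype=float)
        high = np.asarray(high, dtype=float)
        self.low = low + offset
        self.width = (high - low) / n_tiles
        self.n_tiles = n_tiles
        self.size = n_tiles ** low.size

    def index(self, x):
        cell = np.floor((np.asarray(x, dtype=float) - self.low) / self.width).astype(int)
        cell = np.clip(cell, 0, self.n_tiles - 1)  # points outside fall in border cells
        return int(np.ravel_multi_index(tuple(cell), (self.n_tiles,) * cell.size))


def cmac_features(tilings, x):
    """Binary feature vector with exactly one active tile per tiling."""
    phi = np.zeros(sum(t.size for t in tilings))
    start = 0
    for t in tilings:
        phi[start + t.index(x)] = 1.0
        start += t.size
    return phi


# Three 2-D tilings over [0, 1]^2, each shifted by a third of a tile width.
tilings = [UniformTiling([0, 0], [1, 1], n_tiles=4, offset=-k * 0.25 / 3) for k in range(3)]
phi = cmac_features(tilings, [0.3, 0.7])
weights = np.arange(phi.size, dtype=float) * 0.01

print(int(phi.sum()))        # 3 active features, one per tiling
print(float(weights @ phi))  # CMAC output: sum of the active tiles' weights
```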

fit(x, y, alpha=1.0, **kwargs)

Fit the model.

Parameters:
  • x (np.ndarray) – input;

  • y (np.ndarray) – target;

  • alpha (float, 1.0) – learning rate;

  • **kwargs – other parameters used by the fit method of the regressor.
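The role of the learning rate can be illustrated with a common form of the tile-coding update, in which the prediction error is spread evenly over the active tiles. This is a textbook sketch, not necessarily the exact normalization used by this fit method. With n tilings, i_k(x) the active tile of tiling k for input x, and w_{k,i} the weight of tile i in tiling k:

```latex
\hat{y}(x) = \sum_{k=1}^{n} w_{k,\,i_k(x)}, \qquad
w_{k,\,i_k(x)} \leftarrow w_{k,\,i_k(x)} + \frac{\alpha}{n}\bigl(y - \hat{y}(x)\bigr), \quad k = 1,\dots,n.
```

Summing the n increments shows that the prediction at x moves by \alpha (y - \hat{y}(x)); with alpha = 1.0 and a single sample, the updated prediction at x equals the target exactly.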

predict(x, **predict_params)

Predict.

Parameters:
  • x (np.ndarray) – input;

  • **predict_params – other parameters used by the predict method of the regressor.

Returns:

The predictions of the model.

diff(state, action=None)

Compute the derivative of the output w.r.t. the weights of the approximator, evaluated at state (and action, if provided).

Parameters:
  • state (np.ndarray) – the state;

  • action (np.ndarray, None) – the action.

Returns:

The derivative of the output w.r.t. the weights, evaluated at state (and action, if provided).

_add_save_attr(**attr_dict)

Add attributes that should be saved for the object. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All the other methods use the library of the same name. If a “!” character is appended to the method name, the field is saved only if full_save is set to True.

Parameters:

**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.

_post_load()

This method can be overridden to implement logic that is executed after the object is loaded.

copy()
Returns:

A deep copy of the object.

get_weights()

Getter.

Returns:

The set of weights of the approximator.

classmethod load(path)

Load and deserialize the object from the given location on disk.

Parameters:

path (Path, str) – Relative or absolute path to the object's save location.

Returns:

The loaded object.

save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;

  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.

save_zip(zip_file, full_save, folder='')

Serialize and save the object into the given zip file.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;

  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;

  • folder (string, '') – subfolder to be used by the save method.

set_weights(w)

Setter.

Parameters:

w (np.ndarray) – the set of weights to set.

property weights_size

Returns: The size of the array of weights.

Torch Approximator

class TorchApproximator(input_shape, output_shape, network, optimizer=None, loss=None, batch_size=0, n_fit_targets=1, use_cuda=False, reinitialize=False, dropout=False, quiet=True, **params)

Bases: Serializable

Class to interface a PyTorch model with the MushroomRL Regressor interface. It implements everything needed to use a generic PyTorch model and train it with a specified optimizer and objective function. Minibatch training is also supported.

__init__(input_shape, output_shape, network, optimizer=None, loss=None, batch_size=0, n_fit_targets=1, use_cuda=False, reinitialize=False, dropout=False, quiet=True, **params)

Constructor.

Parameters:
  • input_shape (tuple) – shape of the input of the network;

  • output_shape (tuple) – shape of the output of the network;

  • network (torch.nn.Module) – the network class to use;

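  • optimizer (dict) – the optimizer used for every fit step, given as a dictionary with the keys 'class' (the torch.optim optimizer class) and 'params' (the keyword arguments passed to its constructor);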

  • loss (torch.nn.functional) – the loss function to optimize in the fit method;

  • batch_size (int, 0) – the size of each minibatch. If 0, the whole dataset is fed to the optimizer at each epoch;

  • n_fit_targets (int, 1) – the number of fit targets used by the fit method of the network;

  • use_cuda (bool, False) – if True, runs the network on the GPU;

  • reinitialize (bool, False) – if True, the approximator is reinitialized at every fit call. To perform the initialization, the selected network must properly define the weights_init method;

  • dropout (bool, False) – if True, dropout is applied only during training;

  • quiet (bool, True) – if False, shows two progress bars, one for epochs and one for the minibatches;

  • **params – dictionary of parameters needed to construct the network.
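The interplay of batch_size and n_fit_targets can be sketched without PyTorch. The helpers split_inputs_targets and minibatches are hypothetical, not MushroomRL functions; they assume that the last n_fit_targets arrays passed to fit are the targets and that samples are reshuffled at every epoch, which may differ from the library's actual batching.

```python
import numpy as np


def split_inputs_targets(args, n_fit_targets=1):
    """Treat the last n_fit_targets arrays as targets and the rest as inputs."""
    return args[:-n_fit_targets], args[-n_fit_targets:]


def minibatches(arrays, batch_size, rng):
    """Yield shuffled minibatches; batch_size=0 yields the whole dataset once."""
    n_samples = len(arrays[0])
    if batch_size == 0:
        batch_size = n_samples
    order = rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        idx = order[start:start + batch_size]
        yield [a[idx] for a in arrays]


rng = np.random.default_rng(0)
state = np.zeros((10, 4))
action = np.zeros((10, 1))
target = np.ones((10, 1))

inputs, targets = split_inputs_targets((state, action, target), n_fit_targets=1)
print(len(inputs), len(targets))  # 2 input arrays, 1 target array

print([len(b[0]) for b in minibatches(inputs + targets, batch_size=4, rng=rng)])  # [4, 4, 2]
print([len(b[0]) for b in minibatches(inputs + targets, batch_size=0, rng=rng)])  # [10]
```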

predict(*args, output_tensor=False, **kwargs)

Predict.

Parameters:
  • *args – input;

  • output_tensor (bool, False) – if True, return the output as a torch tensor instead of a NumPy array;

  • **kwargs – other parameters used by the predict method of the regressor.

Returns:

The predictions of the model.

fit(*args, n_epochs=None, weights=None, epsilon=None, patience=1, validation_split=1.0, **kwargs)

Fit the model.

Parameters:
  • *args – input, where the last n_fit_targets elements are considered as the target, while the others are considered as input;

  • n_epochs (int, None) – the number of training epochs;

  • weights (np.ndarray, None) – the weights of each sample in the computation of the loss;

  • epsilon (float, None) – the coefficient used for early stopping;

  • patience (float, 1.) – the number of epochs without improvement to wait before stopping the learning;

  • validation_split (float, 1.) – the fraction of the dataset to use as training set; the remaining samples are used for validation;

  • **kwargs – other parameters used by the fit method of the regressor.
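The stopping options can be sketched in plain Python. The functions below are hypothetical and are not MushroomRL code; they assume that epsilon is the minimum decrease of the validation loss that counts as an improvement and that patience counts consecutive epochs without such an improvement, which is one plausible reading of the parameter descriptions above. The loss sequence is made up.

```python
import numpy as np


def train_with_early_stopping(epoch_loss, n_epochs=None, epsilon=None, patience=1):
    """Hypothetical early-stopping loop (not MushroomRL's fit implementation).

    epoch_loss(epoch) stands in for one training epoch and returns the
    validation loss. Training stops after n_epochs epochs or, when epsilon
    is given, once the loss has failed to decrease by more than epsilon for
    `patience` consecutive epochs. Returns the number of epochs run.
    """
    if n_epochs is None and epsilon is None:
        raise ValueError("set n_epochs, epsilon, or both")
    max_epochs = np.inf if n_epochs is None else n_epochs
    best_loss = np.inf
    epochs_without_improvement = 0
    epoch = 0
    while epoch < max_epochs:
        loss = epoch_loss(epoch)
        epoch += 1
        if epsilon is None:
            continue  # fixed number of epochs, no early stopping
        if best_loss - loss > epsilon:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return epoch


def split_train_validation(n_samples, validation_split, rng):
    """Keep a fraction validation_split of shuffled indices for training, the rest for validation."""
    order = rng.permutation(n_samples)
    n_train = int(n_samples * validation_split)
    return order[:n_train], order[n_train:]


losses = [1.0, 0.5, 0.3, 0.29, 0.289, 0.2889, 0.28889]
print(train_with_early_stopping(lambda e: losses[e], epsilon=0.05, patience=2))  # 5
print(train_with_early_stopping(lambda e: losses[e], n_epochs=3))  # 3

train_idx, val_idx = split_train_validation(100, 0.8, np.random.default_rng(0))
print(len(train_idx), len(val_idx))  # 80 20
```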

set_weights(weights)

Setter.

Parameters:

weights (np.ndarray) – the set of weights to set.

get_weights()

Getter.

Returns:

The set of weights of the approximator.

property weights_size

Returns: The size of the array of weights.
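A network usually holds several parameter tensors (one weight matrix and one bias vector per layer), while the getter and setter above exchange a single np.ndarray. The NumPy sketch below shows one way to flatten such a list of parameters and restore it; the helpers flatten_weights and unflatten_weights are hypothetical, the layer shapes are made up, and the exact ordering TorchApproximator uses is an assumption not covered here.

```python
import numpy as np


def flatten_weights(params):
    """Concatenate a list of parameter arrays into one flat vector."""
    return np.concatenate([p.ravel() for p in params])


def unflatten_weights(vector, shapes):
    """Split a flat vector back into arrays with the given shapes."""
    params, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        params.append(vector[start:start + size].reshape(shape))
        start += size
    assert start == vector.size, "vector length does not match the shapes"
    return params


# Shapes of a hypothetical 4 -> 8 -> 2 fully connected network (weights and biases).
shapes = [(8, 4), (8,), (2, 8), (2,)]
rng = np.random.default_rng(0)
params = [rng.normal(size=s) for s in shapes]

w = flatten_weights(params)
print(w.size)  # 58 = 32 + 8 + 16 + 2 parameters in total
restored = unflatten_weights(w, shapes)
print(all(np.array_equal(a, b) for a, b in zip(params, restored)))  # True
```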

diff(*args, **kwargs)

Compute the derivative of the network output w.r.t. the network weights, evaluated at the given input.

Parameters:
  • *args – the input of the network (e.g. the state, and the action if provided);

  • **kwargs – other parameters passed to the network.

Returns:

The derivative of the network output w.r.t. the network weights.

property loss_fit

Returns: The average loss of the last epoch of the last fit call.

_post_load()

This method can be overridden to implement logic that is executed after the object is loaded.

_add_save_attr(**attr_dict)

Add attributes that should be saved for the object. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All the other methods use the library of the same name. If a “!” character is appended to the method name, the field is saved only if full_save is set to True.

Parameters:

**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.

copy()
Returns:

A deep copy of the object.

classmethod load(path)

Load and deserialize the object from the given location on disk.

Parameters:

path (Path, str) – Relative or absolute path to the object's save location.

Returns:

The loaded object.

save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;

  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.

save_zip(zip_file, full_save, folder='')

Serialize and save the object into the given zip file.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;

  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;

  • folder (string, '') – subfolder to be used by the save method.