How to create a regressor

MushroomRL offers a high-level interface to build function regressors. It transparently manages both regressors for generic functions and Q-function regressors, so the user does not need to deal with their low-level implementation and should only use the Regressor interface. This interface creates either a Q-function regressor or a GenericRegressor, depending on whether the n_actions parameter is provided to the constructor.

Usage of the Regressor interface

When the action space of an RL problem is finite and the adopted approach is value-based,
we want to compute the Q-value of each action. In MushroomRL, this is possible using:
  • a Q-function regressor with a different approximator for each action (ActionRegressor);
  • a single Q-function regressor with a different output for each action (QRegressor).
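For linear approximators, the difference between the two layouts can be illustrated with plain NumPy. This is a conceptual sketch, not MushroomRL code: the names action_models and W are illustrative. An ActionRegressor-style layout keeps one scalar-output model per action, while a QRegressor-style layout keeps a single model with one output per action.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_actions = 4, 3
phi = rng.standard_normal(n_features)  # feature vector of a single state

# ActionRegressor-style: one independent linear model per action,
# each mapping phi(s) to the scalar Q(s, a).
action_models = [rng.standard_normal(n_features) for _ in range(n_actions)]
q_action = np.array([w @ phi for w in action_models])

# QRegressor-style: a single model whose weight matrix has one row
# (i.e. one output) per action, so all Q(s, .) come from one call.
W = np.stack(action_models)  # shape (n_actions, n_features)
q_single = W @ phi           # shape (n_actions,)

print(q_action)
print(q_single)
```

With identical weights both layouts produce the same Q-values; what changes is how many separate approximators must be stored and updated. With neural networks, for instance, a single multi-output network shares its hidden layers across actions instead of duplicating them.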

The QRegressor is recommended when the number of discrete actions is large, for memory reasons.

The user can create a QRegressor or an ActionRegressor by setting the output_shape parameter of the Regressor interface (with n_actions also provided). If it is set to (1,), an ActionRegressor is created; if it is set to the number of discrete actions, a QRegressor is created.
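The selection rule described in this section can be summarized as a small decision function. The helper regressor_kind below is hypothetical and only mirrors the rule as stated in the text; it is not part of MushroomRL and does not try to reproduce the library's handling of edge cases (for example, n_actions equal to 1).

```python
def regressor_kind(output_shape, n_actions=None):
    """Hypothetical helper mirroring the Regressor selection rule described above."""
    if n_actions is None:
        # No n_actions: a general-purpose regressor
        return "GenericRegressor"
    if tuple(output_shape) == (1,):
        # One scalar-output approximator per action
        return "ActionRegressor"
    if tuple(output_shape) == (n_actions,):
        # A single approximator with one output per action
        return "QRegressor"
    raise ValueError("with n_actions given, output_shape should be (1,) or (n_actions,)")


print(regressor_kind((1,)))               # GenericRegressor
print(regressor_kind((1,), n_actions=3))  # ActionRegressor
print(regressor_kind((3,), n_actions=3))  # QRegressor
```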


Initially, the MDP, the policy and the features are created:

import numpy as np

from mushroom_rl.algorithms.value import SARSALambdaContinuous
from mushroom_rl.approximators.parametric import LinearApproximator
from mushroom_rl.core import Core
from mushroom_rl.environments import Gym
from mushroom_rl.features import Features
from mushroom_rl.features.tiles import Tiles
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.callbacks import CollectDataset
from mushroom_rl.utils.parameters import Parameter

mdp = Gym(name='MountainCar-v0', horizon=np.inf, gamma=1.)

# Policy
epsilon = Parameter(value=0.)
pi = EpsGreedy(epsilon=epsilon)

# Q-function approximator
n_tilings = 10
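# 10 tilings, each a 10 x 10 grid over the observation space
tilings = Tiles.generate(n_tilings, [10, 10],
                         mdp.info.observation_space.low,
                         mdp.info.observation_space.high)
features = Features(tilings=tilings)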

# Agent
learning_rate = Parameter(.1 / n_tilings)
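Tile coding is what turns a continuous MountainCar state into the binary feature vector of length features.size. The NumPy sketch below is illustrative only: the function tile_features is hypothetical, the tilings are shifted by uniform offsets (an assumption; Tiles.generate may place them differently), and the bounds are those of the MountainCar observation space (position in [-1.2, 0.6], velocity in [-0.07, 0.07]).

```python
import numpy as np


def tile_features(state, low, high, n_tilings=10, n_tiles=(10, 10)):
    """Binary tile-coding features for a 2-D state (hypothetical sketch)."""
    state = np.asarray(state, dtype=float)
    low = np.asarray(low, dtype=float)
    high = np.asarray(high, dtype=float)
    shape = np.asarray(n_tiles)
    width = (high - low) / shape           # size of one tile along each dimension
    per_tiling = int(np.prod(shape))       # number of tiles in one tiling
    phi = np.zeros(n_tilings * per_tiling)
    for t in range(n_tilings):
        # Assumption: tiling t is shifted by t / n_tilings of a tile width
        offset = (t / n_tilings) * width
        idx = np.floor((state - low + offset) / width).astype(int)
        idx = np.clip(idx, 0, shape - 1)   # keep shifted indices inside the grid
        flat = np.ravel_multi_index(tuple(idx), n_tiles)
        phi[t * per_tiling + flat] = 1.0   # exactly one active tile per tiling
    return phi


low, high = [-1.2, -0.07], [0.6, 0.07]
phi = tile_features([-0.5, 0.0], low, high)
print(phi.shape, phi.sum())
```

Since exactly one tile per tiling is active, every state activates n_tilings features. A gradient step on a linear model therefore moves the prediction by roughly n_tilings times the per-weight step, which is the usual motivation for dividing the learning rate by n_tilings, as in Parameter(.1 / n_tilings).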

The following snippet sets the output shape of the regressor to the number of actions, creating a QRegressor:

approximator_params = dict(input_shape=(features.size,),
                           output_shape=(mdp.info.action_space.n,),
                           n_actions=mdp.info.action_space.n)

If you prefer to use an ActionRegressor, simply set the output shape to (1,):

approximator_params = dict(input_shape=(features.size,),
                           output_shape=(1,),
                           n_actions=mdp.info.action_space.n)

The rest of the code creates the agent, trains it, and evaluates it while rendering its behaviour:

agent = SARSALambdaContinuous(mdp.info, pi, LinearApproximator,
                              approximator_params=approximator_params,
                              learning_rate=learning_rate,
                              lambda_coeff=.9, features=features)

# Algorithm
collect_dataset = CollectDataset()
callbacks = [collect_dataset]
core = Core(agent, mdp, callbacks_fit=callbacks)

# Train
core.learn(n_episodes=100, n_steps_per_fit=1)

# Evaluate
core.evaluate(n_episodes=1, render=True)
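To give an idea of what happens at each fit when n_steps_per_fit=1, the sketch below performs one SARSA(λ) update for a linear Q-function using the standard textbook form with accumulating eligibility traces. It is not MushroomRL's implementation, which may use a different trace variant; sarsa_lambda_step and its arguments are hypothetical names, and phi_sa stands for the features of the current state-action pair.

```python
import numpy as np


def sarsa_lambda_step(w, e, phi_sa, phi_next_sa, reward, alpha, gamma, lam,
                      terminal=False):
    """One textbook SARSA(lambda) update for a linear Q-function (sketch)."""
    q = w @ phi_sa                                   # Q(s, a)
    q_next = 0.0 if terminal else w @ phi_next_sa   # Q(s', a'), zero at episode end
    delta = reward + gamma * q_next - q              # TD error
    # Accumulating trace: the gradient of w @ phi_sa w.r.t. w is phi_sa
    e = gamma * lam * e + phi_sa
    w = w + alpha * delta * e
    return w, e


w, e = sarsa_lambda_step(w=np.zeros(4), e=np.zeros(4),
                         phi_sa=np.array([1.0, 0.0, 0.0, 0.0]),
                         phi_next_sa=np.array([0.0, 1.0, 0.0, 0.0]),
                         reward=-1.0, alpha=0.1, gamma=1.0, lam=0.9)
print(w, e)
```

Here lam plays the role of lambda_coeff: it controls how quickly the traces of past state-action pairs decay, and therefore how far back each TD error is propagated.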

Generic regressor

When the n_actions parameter is not provided, the Regressor interface creates a GenericRegressor. This regressor is meant for general-purpose function approximation and is more flexible; it is commonly used in policy search algorithms.


First, a dataset of points lying on a line and corrupted by Gaussian noise is created:

import numpy as np
from matplotlib import pyplot as plt

from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator

x = np.arange(10).reshape(-1, 1)

intercept = 10
noise = np.random.randn(10, 1) * 1
y = 2 * x + intercept + noise

To fit the intercept, polynomial features of degree 1 are created by hand:

phi = np.concatenate((np.ones(10).reshape(-1, 1), x), axis=1)
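The same construction generalizes to higher degrees. The helper poly_features below is hypothetical (not part of MushroomRL) and relies on NumPy's np.vander; with degree=1 it reproduces the [1, x] features built by hand above.

```python
import numpy as np


def poly_features(x, degree):
    """Columns [1, x, x**2, ..., x**degree] for a 1-D input (hypothetical helper)."""
    return np.vander(np.ravel(x), N=degree + 1, increasing=True)


x = np.arange(10).reshape(-1, 1)
phi1 = poly_features(x, degree=1)  # same columns as the hand-built features
phi2 = poly_features(x, degree=2)  # adds an x**2 column
print(phi1[:3])
print(phi2[:3])
```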

The regressor is then created and fit (note that n_actions is not provided):

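regressor = Regressor(LinearApproximator,
                      input_shape=(2,),
                      output_shape=(1,))

# Fit the linear model on the features phi and the noisy targets y
regressor.fit(phi, y)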
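The weights a linear regressor should recover can be cross-checked with NumPy's least-squares solver. This sketch uses noise-free targets so that the expected result is exact; we assume LinearApproximator.fit computes an equivalent least-squares solution, which the sketch itself does not verify.

```python
import numpy as np

x = np.arange(10).reshape(-1, 1)
y = 2 * x + 10  # noise-free targets: slope 2, intercept 10
phi = np.concatenate((np.ones((10, 1)), x), axis=1)  # features [1, x]

# Least-squares solution of phi @ w = y
w, *_ = np.linalg.lstsq(phi, y, rcond=None)
print(w.ravel())  # close to [10, 2]: intercept first, then slope
```

With the noisy targets used in the tutorial, the recovered weights will only be close to 10 and 2, and they change from run to run because the noise is not seeded.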

Finally, the function approximated by the regressor is plotted together with the target points. The weights of the linear approximator and its gradient at x = 5 are also printed.

print('Weights: ' + str(regressor.get_weights()))
# Gradient w.r.t. the weights at x = 5, i.e. at the feature vector [1, 5]
print('Gradient: ' + str(regressor.diff(np.array([1., 5.]))))

plt.scatter(x, y)
plt.plot(x, regressor.predict(phi))
plt.show()
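For a linear model f(w) = w · φ(x), the gradient with respect to the weights is the feature vector φ(x) itself, so at x = 5 with features [1, x] it is [1, 5], independently of the weights. The sketch below checks this with central finite differences; the weight values are arbitrary examples.

```python
import numpy as np

w = np.array([10.0, 2.0])    # example weights: intercept, slope
phi5 = np.array([1.0, 5.0])  # features [1, x] at x = 5


def f(weights):
    """Linear model output at x = 5 for the given weights."""
    return weights @ phi5


# Central finite differences of f with respect to each weight
eps = 1e-6
basis = np.eye(2)
grad = np.array([(f(w + eps * basis[i]) - f(w - eps * basis[i])) / (2 * eps)
                 for i in range(2)])
print(grad)  # close to [1, 5], i.e. the feature vector itself
```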