How to create a regressor
Mushroom offers a high-level interface to build function regressors. It transparently manages both regressors for generic functions and Q-function regressors, so the user does not need to deal with their low-level implementation and should only use the Regressor interface. This interface creates a Q-function regressor if the n_actions parameter is provided to the constructor, and a GenericRegressor otherwise.
Usage of the Regressor interface
When the action space of the RL problem is finite and the adopted approach is value-based, we want to compute the Q-function of each action. In Mushroom, this is possible using:
- a Q-function regressor with a different approximator for each action (ActionRegressor);
- a single Q-function regressor with a different output for each action (QRegressor).
The QRegressor is suggested when the number of discrete actions is high, due to memory reasons.
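To make the difference between the two layouts concrete, here is a minimal pure-Python sketch with linear models. The classes PerActionQ and MultiOutputQ are hypothetical stand-ins written for illustration, not Mushroom classes: the first keeps one weight vector per action (the ActionRegressor idea), the second keeps one weight matrix whose rows are the per-action outputs (the QRegressor idea). Given the same weights, both return the same Q-values.

```python
from typing import List, Sequence


def dot(w: Sequence[float], phi: Sequence[float]) -> float:
    """Inner product of a weight vector and a feature vector."""
    return sum(wi * fi for wi, fi in zip(w, phi))


class PerActionQ:
    """Hypothetical ActionRegressor-like layout: one linear model per action."""

    def __init__(self, weights_per_action: List[List[float]]):
        self.models = [list(w) for w in weights_per_action]

    def predict(self, phi: Sequence[float], action: int) -> float:
        # Only the model associated with the requested action is evaluated.
        return dot(self.models[action], phi)


class MultiOutputQ:
    """Hypothetical QRegressor-like layout: one model with one output per action."""

    def __init__(self, weight_matrix: List[List[float]]):
        self.w = [list(row) for row in weight_matrix]

    def predict_all(self, phi: Sequence[float]) -> List[float]:
        # A single evaluation returns the Q-value of every action.
        return [dot(row, phi) for row in self.w]

    def predict(self, phi: Sequence[float], action: int) -> float:
        return self.predict_all(phi)[action]


weights = [[1.0, 0.5], [0.0, 2.0], [-1.0, 1.0]]  # 3 actions, 2 features
phi = [1.0, 3.0]

per_action = PerActionQ(weights)
multi_output = MultiOutputQ(weights)
q_values = multi_output.predict_all(phi)
print(q_values)  # [2.5, 6.0, 2.0]
```

With purely linear models both layouts hold the same number of weights; how much the choice matters in practice depends on the approximator being used.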
The user can create a QRegressor or an ActionRegressor by setting the output_shape parameter of the Regressor interface (n_actions must be provided in both cases). If output_shape is set to (1,), an ActionRegressor is created; otherwise, if it is set to the number of discrete actions, a QRegressor is created.
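The selection rule described above can be summarised as a small decision function. This is a hypothetical sketch written for illustration only: choose_regressor is not part of Mushroom, it returns class names as strings instead of building regressors, and the real constructor may perform additional checks.

```python
from typing import Optional, Tuple


def choose_regressor(output_shape: Tuple[int, ...] = (1,),
                     n_actions: Optional[int] = None) -> str:
    """Hypothetical restatement of the rule used by the Regressor interface."""
    if n_actions is None:
        # No action information: a general-purpose regressor.
        return 'GenericRegressor'
    if output_shape == (1,):
        # Scalar output: one approximator per action.
        return 'ActionRegressor'
    if output_shape == (n_actions,):
        # One output per action inside a single approximator.
        return 'QRegressor'
    raise ValueError('output_shape must be (1,) or (n_actions,) when n_actions is given')


print(choose_regressor((3,), n_actions=3))  # QRegressor
print(choose_regressor((1,), n_actions=3))  # ActionRegressor
print(choose_regressor((1,)))               # GenericRegressor
```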
Initially, the MDP, the policy and the features are created:
import numpy as np

from mushroom.algorithms.value import SARSALambdaContinuous
from mushroom.approximators.parametric import LinearApproximator
from mushroom.core import Core
from mushroom.environments import Gym
from mushroom.features import Features
from mushroom.features.tiles import Tiles
from mushroom.policy import EpsGreedy
from mushroom.utils.callbacks import CollectDataset
from mushroom.utils.parameters import Parameter

# MDP
mdp = Gym(name='MountainCar-v0', horizon=np.inf, gamma=1.)

# Policy: epsilon = 0 makes the epsilon-greedy policy purely greedy
epsilon = Parameter(value=0.)
pi = EpsGreedy(epsilon=epsilon)

# Q-function approximator: 10 tilings, each a 10x10 grid over the observation space
n_tilings = 10
tilings = Tiles.generate(n_tilings, [10, 10],
                         mdp.info.observation_space.low,
                         mdp.info.observation_space.high)
features = Features(tilings=tilings)

# Agent: the learning rate is scaled by the number of tilings
learning_rate = Parameter(.1 / n_tilings)
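To illustrate what the tile-coding features look like, here is a hypothetical pure-Python tile coder for a 2-D state. It is not Mushroom's Tiles implementation: in particular, the offset scheme (tiling t shifted by t / n_tilings of a tile width) is an assumption made for this sketch. The bounds used below are the MountainCar observation limits (position in [-1.2, 0.6], velocity in [-0.07, 0.07]).

```python
from typing import List, Sequence


def tile_features(state: Sequence[float], low: Sequence[float],
                  high: Sequence[float], n_tilings: int,
                  n_tiles: Sequence[int]) -> List[float]:
    """Binary tile-coding features for a 2-D state (illustrative sketch).

    Each tiling is an n_tiles[0] x n_tiles[1] grid covering [low, high];
    tiling t is shifted by t / n_tilings of a tile width (assumed scheme).
    """
    tiles_per_tiling = n_tiles[0] * n_tiles[1]
    phi = [0.0] * (n_tilings * tiles_per_tiling)
    for t in range(n_tilings):
        idx = []
        for d in range(2):
            width = (high[d] - low[d]) / n_tiles[d]
            offset = t / n_tilings * width
            # Tile containing the shifted coordinate, clipped to the grid.
            i = int((state[d] - low[d] + offset) // width)
            idx.append(min(max(i, 0), n_tiles[d] - 1))
        # Exactly one active feature per tiling.
        phi[t * tiles_per_tiling + idx[0] * n_tiles[1] + idx[1]] = 1.0
    return phi


low, high = [-1.2, -0.07], [0.6, 0.07]
phi = tile_features([-0.5, 0.0], low, high, n_tilings=10, n_tiles=[10, 10])
print(len(phi), sum(phi))  # 1000 10.0
```

Under these assumptions every state activates exactly n_tilings features, so a linear prediction is a sum of n_tilings weights. A common reason for dividing the step size by n_tilings, as done for learning_rate above, is to keep the overall change of that sum per update comparable to the undivided rate.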
The following snippet sets the output shape of the regressor to the number of actions, creating a QRegressor:
approximator_params = dict(input_shape=(features.size,),
                           output_shape=(mdp.info.action_space.n,),
                           n_actions=mdp.info.action_space.n)
If you prefer to use an ActionRegressor, simply set the output shape to (1,):
approximator_params = dict(input_shape=(features.size,),
                           output_shape=(1,),
                           n_actions=mdp.info.action_space.n)
Then, the rest of the code creates the agent, trains it and runs an evaluation episode that renders the behaviour of the agent:
agent = SARSALambdaContinuous(LinearApproximator, pi, mdp.info,
                              approximator_params=approximator_params,
                              learning_rate=learning_rate,
                              lambda_coeff=.9, features=features)

# Algorithm: the callback stores every transition seen by the core
collect_dataset = CollectDataset()
callbacks = [collect_dataset]
core = Core(agent, mdp, callbacks=callbacks)

# Train: the agent is updated after every step
core.learn(n_episodes=100, n_steps_per_fit=1)

# Evaluate: run one episode and render it
core.evaluate(n_episodes=1, render=True)
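For reference, the per-step update of SARSA(λ) with a linear approximator can be sketched as follows. This is a hypothetical pure-Python helper following the textbook algorithm with accumulating eligibility traces, not Mushroom's source; phi and phi_next are assumed to be the state-action features of (s, a) and (s', a'), with a' already drawn from the policy, and the traces would be reset to zero at the start of each episode.

```python
from typing import List, Sequence


def sarsa_lambda_step(w: List[float], e: List[float],
                      phi: Sequence[float], reward: float,
                      phi_next: Sequence[float], absorbing: bool,
                      alpha: float, gamma: float, lam: float) -> float:
    """One in-place SARSA(lambda) update with accumulating traces.

    Returns the TD error delta.
    """
    q = sum(wi * fi for wi, fi in zip(w, phi))
    # No bootstrapping from an absorbing state.
    q_next = 0.0 if absorbing else sum(wi * fi for wi, fi in zip(w, phi_next))
    # Decay the traces and add the gradient of Q(s, a), which is phi for a linear model.
    for i, fi in enumerate(phi):
        e[i] = gamma * lam * e[i] + fi
    delta = reward + gamma * q_next - q
    for i in range(len(w)):
        w[i] += alpha * delta * e[i]
    return delta


w, e = [0.0, 0.0], [0.0, 0.0]
sarsa_lambda_step(w, e, [1.0, 0.0], 1.0, [0.0, 1.0], False, alpha=0.5, gamma=1.0, lam=0.9)
print(w)  # [0.5, 0.0]
```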
Generic regressor
Whenever the n_actions parameter is not provided, the Regressor interface creates a GenericRegressor. This regressor can be used for general-purpose function approximation and is more flexible; it is commonly used in policy search algorithms.
Create a dataset of points distributed along a line with random Gaussian noise:
import numpy as np
from matplotlib import pyplot as plt

from mushroom.approximators import Regressor
from mushroom.approximators.parametric import LinearApproximator

x = np.arange(10).reshape(-1, 1)

intercept = 10
noise = np.random.randn(10, 1) * 1
y = 2 * x + intercept + noise
To fit the intercept, polynomial features of degree 1 are created by hand:
phi = np.concatenate((np.ones(10).reshape(-1, 1), x), axis=1)
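The same construction extends to higher degrees. The helper polynomial_features below is a hypothetical pure-Python sketch, not a Mushroom function; with degree=1 it reproduces the rows of phi, as lists instead of a NumPy array.

```python
from typing import List, Sequence


def polynomial_features(xs: Sequence[float], degree: int) -> List[List[float]]:
    """Rows [1, x, x**2, ..., x**degree], one per scalar input x."""
    return [[x ** k for k in range(degree + 1)] for x in xs]


phi_rows = polynomial_features(range(10), degree=1)
print(phi_rows[:3])  # [[1, 0], [1, 1], [1, 2]]
```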
The regressor is then created and fit (note that n_actions is not provided):
regressor = Regressor(LinearApproximator, input_shape=(2,), output_shape=(1,))
regressor.fit(phi, y)
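For intuition about the result, here is the ordinary least-squares solution for an intercept and a slope, written in plain Python. It assumes that fitting a linear model on the features [1, x] amounts to solving a least-squares problem; this is an assumption about the approximator, and fit_line is a hypothetical helper used only to show what the two weights represent.

```python
from typing import Sequence, Tuple


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares intercept and slope for y = w0 + w1 * x."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return intercept, slope


xs = list(range(10))
ys = [2 * x + 10 for x in xs]  # noise-free version of the dataset above
print(fit_line(xs, ys))  # (10.0, 2.0)
```

With the noisy targets, the fitted weights are expected to be close to, but not exactly, an intercept of 10 and a slope of 2.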
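Finally, the function approximated by the regressor is plotted together with the target points. Moreover, the weights of the linear approximator and its gradient with respect to the weights at x = 5 are printed; since the regressor takes the features [1, x] as input, the gradient is evaluated on the feature vector [1, 5]:
print('Weights: ' + str(regressor.get_weights()))
# For a linear model, the gradient w.r.t. the weights is the input feature vector [1, 5]
print('Gradient: ' + str(regressor.diff(np.array([1., 5.]))))

plt.scatter(x, y)
plt.plot(x, regressor.predict(phi))
plt.show()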
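Because the model is linear in its weights, its gradient with respect to the weights equals the feature vector at which it is evaluated. A central finite-difference check in plain Python makes this explicit; the helper names are hypothetical and the weights are arbitrary illustrative values, not the output of the fit above.

```python
from typing import List, Sequence


def predict(w: Sequence[float], phi: Sequence[float]) -> float:
    """Linear model output w . phi."""
    return sum(wi * fi for wi, fi in zip(w, phi))


def numerical_gradient(w: Sequence[float], phi: Sequence[float],
                       eps: float = 1e-6) -> List[float]:
    """Central finite differences of predict with respect to each weight."""
    grad = []
    for i in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[i] += eps
        w_minus[i] -= eps
        grad.append((predict(w_plus, phi) - predict(w_minus, phi)) / (2 * eps))
    return grad


w = [10.3, 1.9]     # arbitrary illustrative weights
phi_5 = [1.0, 5.0]  # features [1, x] at x = 5
grad = numerical_gradient(w, phi_5)
print(grad)  # approximately [1.0, 5.0]
```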