MushroomRL¶
Reinforcement Learning python library¶
MushroomRL is a Reinforcement Learning (RL) library that aims to be a simple, yet powerful way to make RL and deep RL experiments. The idea behind Mushroom consists in offering the majority of RL algorithms providing a common interface in order to run them without excessive effort. Moreover, it is designed in such a way that new algorithms and other stuff can generally be added transparently without the need of editing other parts of the code. MushroomRL makes a large use of the environments provided by OpenAI Gym, DeepMind Control Suite and MuJoCo libraries, and the PyTorch library for tensor computation.
With MushroomRL you can:
- solve RL problems simply writing a single small script;
- add custom algorithms and other stuff transparently;
- use all RL environments offered by well-known libraries and build customized environments as well;
- exploit regression models offered by Scikit-Learn or build a customized one with PyTorch;
- run experiments on GPU.
Basic run example¶
Solve a discrete MDP in few a lines. Firstly, create a MDP:
from mushroom_rl.environments import GridWorld
mdp = GridWorld(width=3, height=3, goal=(2, 2), start=(0, 0))
Then, an epsilon-greedy policy with:
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import Parameter
epsilon = Parameter(value=1.)
policy = EpsGreedy(epsilon=epsilon)
Eventually, the agent is:
from mushroom_rl.algorithms.value import QLearning
learning_rate = Parameter(value=.6)
agent = QLearning(policy, mdp.info, learning_rate)
Learn:
from mushroom_rl.core.core import Core
core = Core(agent, mdp)
core.learn(n_steps=10000, n_steps_per_fit=1)
Print final Q-table:
import numpy as np
shape = agent.approximator.shape
q = np.zeros(shape)
for i in range(shape[0]):
for j in range(shape[1]):
state = np.array([i])
action = np.array([j])
q[i, j] = agent.approximator.predict(state, action)
print(q)
Results in:
[[ 6.561 7.29 6.561 7.29 ]
[ 7.29 8.1 6.561 8.1 ]
[ 8.1 9. 7.29 8.1 ]
[ 6.561 8.1 7.29 8.1 ]
[ 7.29 9. 7.29 9. ]
[ 8.1 10. 8.1 9. ]
[ 7.29 8.1 8.1 9. ]
[ 8.1 9. 8.1 10. ]
[ 0. 0. 0. 0. ]]
where the Q-values of each action of the MDP are stored for each rows representing a state of the MDP.
Download and installation¶
MushroomRL can be downloaded from the GitHub repository. Installation can be done running
pip3 install mushroom_rl
To compile the documentation:
cd mushroom_rl/docs
make html
or to compile the pdf version:
cd mushroom_rl/docs
make latexpdf
To launch MushroomRL test suite:
pytest