# Value-Based

## TD

class mushroom.algorithms.value.td.SARSA(policy, mdp_info, learning_rate)

Bases: mushroom.algorithms.value.td.td.TD

SARSA algorithm.

__init__(policy, mdp_info, learning_rate)

Constructor.

Parameters: learning_rate (Parameter) – the learning rate.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
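To illustrate what `_update` computes for SARSA, the following is a minimal sketch of the tabular rule Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a)) for a single transition. It assumes a dictionary-backed Q-table, a constant learning rate `alpha`, a discount `gamma`, and a next action `next_action` already drawn from the policy (in line with the `draw_action` note that SARSA sets the next action itself). The function `sarsa_update` is hypothetical and not part of the mushroom API.

```python
from collections import defaultdict


def sarsa_update(q, state, action, reward, next_state, next_action,
                 absorbing, alpha=0.1, gamma=0.9):
    """Apply one on-policy SARSA update to ``q`` in place and return the TD error.

    ``q`` maps (state, action) pairs to values; missing entries default to 0.
    """
    # Bootstrap from the action actually chosen in next_state;
    # an absorbing next state contributes no future value.
    q_next = 0.0 if absorbing else q[(next_state, next_action)]
    td_error = reward + gamma * q_next - q[(state, action)]
    q[(state, action)] += alpha * td_error
    return td_error


# Usage: one transition 0 --(action 1, reward 1.0)--> 1, next action 0.
q_table = defaultdict(float)
q_table[(1, 0)] = 2.0
sarsa_update(q_table, 0, 1, 1.0, 1, 0, absorbing=False)
# q_table[(0, 1)] is now 0.1 * (1.0 + 0.9 * 2.0), i.e. about 0.28
```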
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.SARSALambda(policy, mdp_info, learning_rate, lambda_coeff, trace='replacing')

Bases: mushroom.algorithms.value.td.td.TD

The SARSA(lambda) algorithm for finite MDPs.

__init__(policy, mdp_info, learning_rate, lambda_coeff, trace='replacing')

Constructor.

Parameters: lambda_coeff (float) – eligibility trace coefficient; trace (str, 'replacing') – type of eligibility trace to use.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
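The role of the eligibility trace can be seen in a minimal tabular sketch of one SARSA(lambda) step. It assumes dictionary-backed Q and trace tables, constant `alpha` and `gamma`, and a next action already chosen by the policy; the hypothetical `replacing` flag switches between a replacing trace (the default suggested by `trace='replacing'`) and an accumulating one. In this sketch the trace table would be cleared at the start of each episode. `sarsa_lambda_update` is illustrative only, not the library's implementation.

```python
from collections import defaultdict


def sarsa_lambda_update(q, trace, state, action, reward, next_state,
                        next_action, absorbing, alpha=0.1, gamma=0.9,
                        lambda_coeff=0.9, replacing=True):
    """One tabular SARSA(lambda) step; ``q`` and ``trace`` are updated in place."""
    q_next = 0.0 if absorbing else q[(next_state, next_action)]
    delta = reward + gamma * q_next - q[(state, action)]

    # Mark the visited pair: a replacing trace resets it to 1,
    # an accumulating trace adds 1 to its current value.
    if replacing:
        trace[(state, action)] = 1.0
    else:
        trace[(state, action)] += 1.0

    # Every pair with a trace receives a share of the TD error,
    # then all traces decay by gamma * lambda.
    for key in list(trace):
        q[key] += alpha * delta * trace[key]
        trace[key] *= gamma * lambda_coeff
    return delta
```

With `lambda_coeff=0` only the visited pair is updated and the rule reduces to one-step SARSA.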
episode_start()

Called by the agent when a new episode starts.

static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.ExpectedSARSA(policy, mdp_info, learning_rate)

Bases: mushroom.algorithms.value.td.td.TD

Expected SARSA algorithm. “A theoretical and empirical analysis of Expected Sarsa”. Seijen H. V. et al., 2009.

__init__(policy, mdp_info, learning_rate)

Constructor.

Parameters: learning_rate (Parameter) – the learning rate.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
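Expected SARSA replaces the sampled Q(s', a') of SARSA with its expectation under the current policy. A minimal sketch, assuming a list-of-lists Q-table `q[state][action]` and, for concreteness, an epsilon-greedy policy (the policy type is an assumption; any policy that exposes action probabilities would work). The helpers `epsilon_greedy_probs` and `expected_sarsa_update` are hypothetical.

```python
def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities of an epsilon-greedy policy (ties go to the first max)."""
    n_actions = len(q_values)
    greedy = max(range(n_actions), key=lambda a: q_values[a])
    probs = [epsilon / n_actions] * n_actions
    probs[greedy] += 1.0 - epsilon
    return probs


def expected_sarsa_update(q, state, action, reward, next_state, absorbing,
                          alpha=0.1, gamma=0.9, epsilon=0.1):
    """One Expected SARSA step on a list-of-lists Q-table ``q[state][action]``."""
    if absorbing:
        q_next = 0.0
    else:
        # Expectation of Q(next_state, .) under the policy, instead of the
        # value of a single sampled next action as in SARSA.
        probs = epsilon_greedy_probs(q[next_state], epsilon)
        q_next = sum(p * v for p, v in zip(probs, q[next_state]))
    q[state][action] += alpha * (reward + gamma * q_next - q[state][action])
```

Averaging over actions removes the variance due to sampling a', at the cost of one pass over the action values per update.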
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.QLearning(policy, mdp_info, learning_rate)

Bases: mushroom.algorithms.value.td.td.TD

Q-Learning algorithm. “Learning from Delayed Rewards”. Watkins C.J.C.H., 1989.

__init__(policy, mdp_info, learning_rate)

Constructor.

Parameters: learning_rate (Parameter) – the learning rate.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
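Q-Learning differs from SARSA only in the bootstrap term, which uses the greedy value of the next state regardless of the action the behaviour policy takes next. A minimal sketch, assuming a list-of-lists Q-table and constant `alpha` and `gamma`; `q_learning_update` is a hypothetical helper, not the library's `_update`.

```python
def q_learning_update(q, state, action, reward, next_state, absorbing,
                      alpha=0.1, gamma=0.9):
    """One off-policy Q-Learning step on ``q[state][action]``; returns the TD error."""
    # Bootstrap from the greedy value of next_state, whatever action the
    # behaviour policy will actually take there.
    q_next = 0.0 if absorbing else max(q[next_state])
    td_error = reward + gamma * q_next - q[state][action]
    q[state][action] += alpha * td_error
    return td_error
```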
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.DoubleQLearning(policy, mdp_info, learning_rate)

Bases: mushroom.algorithms.value.td.td.TD

Double Q-Learning algorithm. “Double Q-Learning”. Hasselt H. V., 2010.

__init__(policy, mdp_info, learning_rate)

Constructor.

Parameters: learning_rate (Parameter) – the learning rate.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
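Double Q-Learning keeps two tables and, at each step, updates one of them at random: the updated table selects the greedy next action and the other table evaluates it. A minimal sketch following Hasselt's description, assuming list-of-lists tables and an injectable random source `rng` (a hypothetical parameter added to make the choice testable); `double_q_update` is not part of the mushroom API.

```python
import random


def double_q_update(q_a, q_b, state, action, reward, next_state, absorbing,
                    alpha=0.1, gamma=0.9, rng=random):
    """One Double Q-Learning step: update one of the two tables, chosen at random."""
    # With probability 0.5 the roles of the two estimators are swapped.
    if rng.random() < 0.5:
        q_a, q_b = q_b, q_a
    if absorbing:
        q_next = 0.0
    else:
        # Select the next action with the table being updated...
        best = max(range(len(q_a[next_state])), key=lambda a: q_a[next_state][a])
        # ...but evaluate it with the other table, which reduces the
        # overestimation bias of a single max operator.
        q_next = q_b[next_state][best]
    q_a[state][action] += alpha * (reward + gamma * q_next - q_a[state][action])
```

In the original paper the behaviour policy acts on the sum (or average) of the two tables.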
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.SpeedyQLearning(policy, mdp_info, learning_rate)

Bases: mushroom.algorithms.value.td.td.TD

Speedy Q-Learning algorithm. “Speedy Q-Learning”. Ghavamzadeh et al., 2011.

__init__(policy, mdp_info, learning_rate)

Constructor.

Parameters: learning_rate (Parameter) – the learning rate.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
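Speedy Q-Learning uses both the current estimate Q_k and the previous one Q_{k-1}: Q_{k+1}(s, a) = Q_k(s, a) + alpha_k (T Q_{k-1} - Q_k(s, a)) + (1 - alpha_k)(T Q_k - T Q_{k-1}), where T Q = r + gamma * max_b Q(s', b) is the empirical Bellman operator. The sketch below transcribes that formula for one transition, assuming list-of-lists tables and a caller-supplied step size `alpha` (the paper uses alpha_k = 1/(k+1)). `speedy_q_update` is a hypothetical helper, not mushroom's code.

```python
import copy


def speedy_q_update(q, q_old, state, action, reward, next_state, absorbing,
                    alpha, gamma=0.9):
    """One Speedy Q-Learning step on ``q`` (Q_k) using ``q_old`` (Q_{k-1}).

    Returns a copy of Q_k taken before the update, to pass as ``q_old``
    on the next call.
    """
    snapshot = copy.deepcopy(q)  # becomes Q_{k-1} for the next step
    max_cur = 0.0 if absorbing else max(q[next_state])
    max_old = 0.0 if absorbing else max(q_old[next_state])
    target_cur = reward + gamma * max_cur  # empirical Bellman operator on Q_k
    target_old = reward + gamma * max_old  # empirical Bellman operator on Q_{k-1}
    q_sa = q[state][action]
    q[state][action] = (q_sa + alpha * (target_old - q_sa)
                        + (1.0 - alpha) * (target_cur - target_old))
    return snapshot
```

Copying the whole table each step is the simplest correct choice for a sketch; a practical implementation would avoid the full copy.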
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.RLearning(policy, mdp_info, learning_rate, beta)

Bases: mushroom.algorithms.value.td.td.TD

R-Learning algorithm. “A Reinforcement Learning Method for Maximizing Undiscounted Rewards”. Schwartz A., 1993.

__init__(policy, mdp_info, learning_rate, beta)

Constructor.

Parameters: beta (Parameter) – beta coefficient.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
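R-Learning targets the undiscounted, average-reward setting: rewards are measured relative to an estimate `rho` of the average reward, and `rho` itself is updated with step size `beta` only when the action taken is greedy. The sketch below follows the textbook formulation (Sutton and Barto, first edition) for a continuing task, so it ignores the absorbing flag; the list-of-lists table and the helper name `r_learning_update` are assumptions.

```python
def r_learning_update(q, rho, state, action, reward, next_state,
                      alpha=0.1, beta=0.01):
    """One R-Learning step; updates ``q`` in place and returns the new ``rho``."""
    q_next = max(q[next_state])
    # Relative (undiscounted) TD error: the reward is compared against rho.
    q[state][action] += alpha * (reward - rho + q_next - q[state][action])
    # Update the average-reward estimate only for greedy actions.
    q_max = max(q[state])
    if q[state][action] == q_max:
        rho += beta * (reward - rho + q_next - q_max)
    return rho
```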
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.WeightedQLearning(policy, mdp_info, learning_rate, sampling=True, precision=1000, weighted_policy=False)

Bases: mushroom.algorithms.value.td.td.TD

Weighted Q-Learning algorithm. “Estimating the Maximum Expected Value through Gaussian Approximation”. D’Eramo C. et al., 2016.

__init__(policy, mdp_info, learning_rate, sampling=True, precision=1000, weighted_policy=False)

Constructor.

Parameters: sampling (bool, True) – use the approximated version to speed up the computation; precision (int, 1000) – number of samples to use in the approximated version.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
_next_q(next_state)
Parameters: next_state (np.ndarray) – the state where the next action has to be evaluated. Returns: the weighted estimator value in next_state.
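The weighted estimator replaces max_a Q(s', a) with a weighted average of the action values, where each weight approximates the probability that the action is the best one under a Gaussian model of its estimate. A minimal sketch of the sampling approximation (the `sampling=True` case, with `precision` samples), assuming the caller supplies per-action means and standard errors; the function `weighted_next_q` and its `rng` argument are hypothetical.

```python
import random


def weighted_next_q(means, stds, precision=1000, rng=random):
    """Sampling-based weighted estimate of max_a E[Q(next_state, a)].

    ``means[a]`` and ``stds[a]`` are the mean and standard error of the
    value estimate of action ``a`` in the next state.
    """
    n_actions = len(means)
    counts = [0] * n_actions
    for _ in range(precision):
        # Draw one sample per action from its Gaussian approximation and
        # record which action attains the maximum.
        samples = [rng.gauss(m, s) for m, s in zip(means, stds)]
        counts[max(range(n_actions), key=samples.__getitem__)] += 1
    weights = [c / precision for c in counts]
    # Weighted average of the means instead of the single largest mean.
    return sum(w * m for w, m in zip(weights, means))
```

When one action clearly dominates, the estimate approaches the plain maximum; when estimates overlap, it falls between the competing means, which is how the estimator mitigates the overestimation of the max operator.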
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.RQLearning(policy, mdp_info, learning_rate, off_policy=False, beta=None, delta=None)

Bases: mushroom.algorithms.value.td.td.TD

RQ-Learning algorithm. “Exploiting Structure and Uncertainty of Bellman Updates in Markov Decision Processes”. Tateo D. et al., 2017.

__init__(policy, mdp_info, learning_rate, off_policy=False, beta=None, delta=None)

Constructor.

Parameters: off_policy (bool, False) – whether to use the off-policy setting or the online one; beta (Parameter, None) – beta coefficient; delta (Parameter, None) – delta coefficient.
static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
_update(state, action, reward, next_state, absorbing)

Update the Q-table.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

_next_q(next_state)
Parameters: next_state (np.ndarray) – the state where the next action has to be evaluated. Returns: the weighted estimator value in next_state.

class mushroom.algorithms.value.td.SARSALambdaContinuous(approximator, policy, mdp_info, learning_rate, lambda_coeff, features, approximator_params=None)

Bases: mushroom.algorithms.value.td.td.TD

Continuous version of SARSA(lambda) algorithm.

__init__(approximator, policy, mdp_info, learning_rate, lambda_coeff, features, approximator_params=None)

Constructor.

Parameters: lambda_coeff (float) – eligibility trace coefficient.
_update(state, action, reward, next_state, absorbing)

Update the Q-function approximation.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
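With function approximation the trace lives over the weights rather than over table entries. A minimal sketch of one SARSA(lambda) step for a linear Q-function Q(s, a) = w . phi(s, a) with an accumulating trace, assuming the caller supplies the feature vectors of the current and next state-action pairs as plain lists; `linear_sarsa_lambda_step` is a hypothetical helper and the accumulating trace is an assumption about the variant.

```python
def linear_sarsa_lambda_step(weights, trace, features, next_features, reward,
                             absorbing, alpha=0.1, gamma=0.9, lambda_coeff=0.9):
    """One SARSA(lambda) step for a linear Q-function Q(s, a) = w . phi(s, a).

    ``features`` is phi(s, a), ``next_features`` is phi(s', a');
    ``weights`` and ``trace`` are lists updated in place.
    """
    q = sum(w * f for w, f in zip(weights, features))
    q_next = 0.0 if absorbing else sum(w * f for w, f in zip(weights, next_features))
    delta = reward + gamma * q_next - q
    for i, f in enumerate(features):
        # Accumulating trace: decay by gamma * lambda, then add the gradient phi_i.
        trace[i] = gamma * lambda_coeff * trace[i] + f
        weights[i] += alpha * delta * trace[i]
    return delta
```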
episode_start()

Called by the agent when a new episode starts.

static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.td.TrueOnlineSARSALambda(policy, mdp_info, learning_rate, lambda_coeff, features, approximator_params=None)

Bases: mushroom.algorithms.value.td.td.TD

True Online SARSA(lambda) with linear function approximation. “True Online TD(lambda)”. Seijen H. V. et al., 2014.

__init__(policy, mdp_info, learning_rate, lambda_coeff, features, approximator_params=None)

Constructor.

Parameters: lambda_coeff (float) – eligibility trace coefficient.
_update(state, action, reward, next_state, absorbing)

Update the Q-function approximation.

Parameters: state (np.ndarray) – state; action (np.ndarray) – action; reward (np.ndarray) – reward; next_state (np.ndarray) – next state; absorbing (np.ndarray) – absorbing flag.
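True Online SARSA(lambda) uses a "dutch" trace and an extra correction term so that the online updates match the forward view exactly. The sketch below transcribes one step of the linear version as presented in Sutton and Barto (second edition), assuming plain-list feature vectors for (s, a) and (s', a'); `true_online_sarsa_step` is a hypothetical helper, and in this sketch `q_old` and the trace would be reset to zero at the start of each episode.

```python
def true_online_sarsa_step(weights, trace, x, x_next, reward, absorbing, q_old,
                           alpha=0.1, gamma=0.9, lambda_coeff=0.9):
    """One True Online SARSA(lambda) step for a linear Q-function.

    ``x`` and ``x_next`` are the features of (s, a) and (s', a'); ``weights``
    and ``trace`` are updated in place. Returns the ``q_old`` for the next call.
    """
    q = sum(w * f for w, f in zip(weights, x))
    q_next = 0.0 if absorbing else sum(w * f for w, f in zip(weights, x_next))
    delta = reward + gamma * q_next - q
    # Dutch trace: z <- gamma*lambda*z + (1 - alpha*gamma*lambda*(z . x)) x
    z_dot_x = sum(z * f for z, f in zip(trace, x))
    scale = 1.0 - alpha * gamma * lambda_coeff * z_dot_x
    for i, f in enumerate(x):
        trace[i] = gamma * lambda_coeff * trace[i] + scale * f
    # w <- w + alpha*(delta + Q - Q_old) z - alpha*(Q - Q_old) x
    for i, f in enumerate(x):
        weights[i] += alpha * (delta + q - q_old) * trace[i] - alpha * (q - q_old) * f
    return q_next
```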
episode_start()

Called by the agent when a new episode starts.

static _parse(dataset)

Utility to parse a dataset that is supposed to contain only a single sample.

Parameters: dataset (list) – the current episode step, i.e. a list holding one tuple with the state, action, reward, next state, absorbing flag and last flag.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

## Batch TD

class mushroom.algorithms.value.batch_td.FQI(approximator, policy, mdp_info, n_iterations, fit_params=None, approximator_params=None, quiet=False, boosted=False)

Bases: mushroom.algorithms.value.batch_td.batch_td.BatchTD

Fitted Q-Iteration algorithm. “Tree-Based Batch Mode Reinforcement Learning”. Ernst D. et al., 2005.

__init__(approximator, policy, mdp_info, n_iterations, fit_params=None, approximator_params=None, quiet=False, boosted=False)

Constructor.

Parameters: n_iterations (int) – number of iterations to perform for training; quiet (bool, False) – if True, the progress bar is not shown; boosted (bool, False) – whether to use boosted FQI or not.
fit(dataset)

Fit loop.

_fit(x)

Single fit iteration.

Parameters: x (list) – the dataset.
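The fit loop repeats one regression per iteration: targets r + gamma * max_b Q_k(s', b) are built from the previous estimate and a regressor is fitted on them. A minimal sketch of the non-boosted loop, assuming discrete hashable states and a lookup table that averages the targets of each (state, action) pair as a stand-in for the generic `approximator` (the original paper uses tree-based regressors). `fqi` is a hypothetical function, not the library's `fit`.

```python
def fqi(dataset, n_actions, n_iterations, gamma=0.9):
    """Fitted Q-Iteration with a tabular averaging 'regressor'.

    ``dataset`` is a list of (state, action, reward, next_state, absorbing)
    tuples; returns a dict mapping (state, action) to the fitted Q-value.
    """
    q = {}
    for _ in range(n_iterations):
        # Build regression targets from the previous Q estimate.
        sums, counts = {}, {}
        for s, a, r, s_next, absorbing in dataset:
            if absorbing:
                max_q_next = 0.0
            else:
                max_q_next = max(q.get((s_next, b), 0.0) for b in range(n_actions))
            y = r + gamma * max_q_next
            sums[(s, a)] = sums.get((s, a), 0.0) + y
            counts[(s, a)] = counts.get((s, a), 0) + 1
        # "Fit": the least-squares solution for a lookup table is the mean target.
        q = {key: sums[key] / counts[key] for key in sums}
    return q
```

Each iteration propagates reward information one more step backwards, so `n_iterations` bounds the effective horizon of the fitted Q-function.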
_fit_boosted(x)

Single fit iteration for boosted FQI.

Parameters: x (list) – the dataset.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.batch_td.DoubleFQI(approximator, policy, mdp_info, n_iterations, fit_params=None, approximator_params=None, quiet=False)

Bases: mushroom.algorithms.value.batch_td.fqi.FQI

Double Fitted Q-Iteration algorithm. “Estimating the Maximum Expected Value in Continuous Reinforcement Learning Problems”. D’Eramo C. et al., 2017.

__init__(approximator, policy, mdp_info, n_iterations, fit_params=None, approximator_params=None, quiet=False)

Constructor.

Parameters: n_iterations (int) – number of iterations to perform for training; quiet (bool, False) – if True, the progress bar is not shown.
_fit(x)

Single fit iteration.

Parameters: x (list) – the dataset.
_fit_boosted(x)

Single fit iteration for boosted FQI.

Parameters: x (list) – the dataset.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit loop.

stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.batch_td.LSPI(policy, mdp_info, epsilon=0.01, fit_params=None, approximator_params=None, features=None)

Bases: mushroom.algorithms.value.batch_td.batch_td.BatchTD

Least-Squares Policy Iteration algorithm. “Least-Squares Policy Iteration”. Lagoudakis M. G. and Parr R., 2003.

__init__(policy, mdp_info, epsilon=0.01, fit_params=None, approximator_params=None, features=None)

Constructor.

Parameters: epsilon (float, 1e-2) – termination coefficient.
fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
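Each LSPI iteration evaluates the current greedy policy with LSTDQ, solving A w = b with A = sum phi(s, a)(phi(s, a) - gamma * phi(s', pi(s')))^T and b = sum phi(s, a) r, and stops once the weights change by less than `epsilon`. A minimal pure-Python sketch, assuming a user-supplied feature function `phi(state, action)` returning a list; `solve`, `lspi`, `max_iterations` and the small ridge term `reg` (added to keep A invertible) are hypothetical names, not the library's API.

```python
def solve(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def lspi(dataset, phi, n_actions, n_features, gamma=0.9, epsilon=1e-2,
         max_iterations=50, reg=1e-6):
    """LSPI with LSTDQ policy evaluation; returns the weight vector."""
    w = [0.0] * n_features
    for _ in range(max_iterations):
        a_mat = [[reg if i == j else 0.0 for j in range(n_features)]
                 for i in range(n_features)]
        b_vec = [0.0] * n_features
        for s, a, r, s_next, absorbing in dataset:
            f = phi(s, a)
            if absorbing:
                f_next = [0.0] * n_features
            else:
                # Greedy action of the current policy in s_next.
                q = [sum(wi * xi for wi, xi in zip(w, phi(s_next, b)))
                     for b in range(n_actions)]
                f_next = phi(s_next, q.index(max(q)))
            for i in range(n_features):
                b_vec[i] += f[i] * r
                for j in range(n_features):
                    a_mat[i][j] += f[i] * (f[j] - gamma * f_next[j])
        w_new = solve(a_mat, b_vec)
        # Terminate when the weight change falls below epsilon.
        if sum((x - y) ** 2 for x, y in zip(w_new, w)) ** 0.5 < epsilon:
            return w_new
        w = w_new
    return w
```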
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

## DQN

class mushroom.algorithms.value.dqn.DQN(approximator, policy, mdp_info, batch_size, approximator_params, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)

Deep Q-Network algorithm. “Human-Level Control Through Deep Reinforcement Learning”. Mnih V. et al., 2015.

__init__(approximator, policy, mdp_info, batch_size, approximator_params, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)

Constructor.

Parameters: approximator (object) – the approximator to use to fit the Q-function; batch_size (int) – the number of samples in a batch; approximator_params (dict) – parameters of the approximator to build; target_update_frequency (int) – the number of samples collected between consecutive updates of the target network; replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the replay memory object to use; if None, a default replay memory is created; initial_replay_size (int, 500) – the number of samples to collect before learning starts; max_replay_size (int, 5000) – the maximum number of samples in the replay memory; fit_params (dict, None) – parameters of the fitting algorithm of the approximator; n_approximators (int, 1) – the number of approximators to use in AveragedDQN; clip_reward (bool, True) – whether to clip the reward or not.
fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
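The fit step combines a bounded replay memory, which only starts serving batches once `initial_replay_size` samples are stored, with regression targets computed from a separate target network. A minimal pure-Python sketch of those two pieces, assuming transitions are (s, a, r, s', absorbing) tuples, a Q-network is any callable returning a list of action values, and clipping means clamping rewards to [-1, 1] as in Mnih et al.; `ReplayMemorySketch` and `dqn_targets` are hypothetical. The periodic copy of the online weights into the target network every `target_update_frequency` samples is not shown.

```python
import random
from collections import deque


class ReplayMemorySketch:
    """Fixed-size FIFO buffer of (s, a, r, s', absorbing) transitions."""

    def __init__(self, initial_size, max_size):
        self.initial_size = initial_size
        self._buffer = deque(maxlen=max_size)  # oldest samples are dropped

    def add(self, transitions):
        self._buffer.extend(transitions)

    @property
    def initialized(self):
        # Learning starts only once enough samples have been collected.
        return len(self._buffer) >= self.initial_size

    def sample(self, batch_size, rng=random):
        return rng.sample(list(self._buffer), batch_size)


def dqn_targets(batch, target_q, gamma=0.99, clip_reward=True):
    """Regression targets y = r + gamma * max_a' Q_target(s', a') for a batch."""
    targets = []
    for _, _, r, s_next, absorbing in batch:
        if clip_reward:
            r = max(-1.0, min(1.0, r))  # clamp rewards to [-1, 1]
        q_next = 0.0 if absorbing else max(target_q(s_next))
        targets.append(r + gamma * q_next)
    return targets
```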
_update_target()

Update the target network.

_next_q(next_state, absorbing)
Parameters: next_state (np.ndarray) – the states where the next action has to be evaluated; absorbing (np.ndarray) – the absorbing flag for the states in next_state. Returns: the maximum action-value for each state in next_state.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.dqn.DoubleDQN(approximator, policy, mdp_info, batch_size, approximator_params, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)

Bases: mushroom.algorithms.value.dqn.dqn.DQN

Double DQN algorithm. “Deep Reinforcement Learning with Double Q-Learning”. Hasselt H. V. et al., 2016.

_next_q(next_state, absorbing)
Parameters: next_state (np.ndarray) – the states where the next action has to be evaluated; absorbing (np.ndarray) – the absorbing flag for the states in next_state. Returns: the maximum action-value for each state in next_state.
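Double DQN changes only the bootstrap value: the online network selects the greedy next action and the target network evaluates it, instead of the target network doing both. A minimal sketch, assuming both networks are callables returning a list of action values for a state; `double_dqn_targets` is a hypothetical helper, not the library's `_next_q`.

```python
def double_dqn_targets(batch, online_q, target_q, gamma=0.99):
    """Double DQN targets: the online net selects a', the target net evaluates it."""
    targets = []
    for _, _, r, s_next, absorbing in batch:
        if absorbing:
            q_next = 0.0
        else:
            q_online = online_q(s_next)
            best = q_online.index(max(q_online))  # action selection
            q_next = target_q(s_next)[best]       # action evaluation
        targets.append(r + gamma * q_next)
    return targets
```

In the test-style example below the online network prefers action 0, so the target uses the target network's value 0.5 for that action rather than its own maximum of 10.0.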
__init__(approximator, policy, mdp_info, batch_size, approximator_params, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)

Constructor.

Parameters: approximator (object) – the approximator to use to fit the Q-function; batch_size (int) – the number of samples in a batch; approximator_params (dict) – parameters of the approximator to build; target_update_frequency (int) – the number of samples collected between consecutive updates of the target network; replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the replay memory object to use; if None, a default replay memory is created; initial_replay_size (int, 500) – the number of samples to collect before learning starts; max_replay_size (int, 5000) – the maximum number of samples in the replay memory; fit_params (dict, None) – parameters of the fitting algorithm of the approximator; n_approximators (int, 1) – the number of approximators to use in AveragedDQN; clip_reward (bool, True) – whether to clip the reward or not.
_update_target()

Update the target network.

draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

class mushroom.algorithms.value.dqn.AveragedDQN(approximator, policy, mdp_info, **params)

Bases: mushroom.algorithms.value.dqn.dqn.DQN

Averaged-DQN algorithm. “Averaged-DQN: Variance Reduction and Stabilization for Deep Reinforcement Learning”. Anschel O. et al., 2017.

__init__(approximator, policy, mdp_info, **params)

Constructor.

Parameters: approximator (object) – the approximator to use to fit the Q-function; batch_size (int) – the number of samples in a batch; approximator_params (dict) – parameters of the approximator to build; target_update_frequency (int) – the number of samples collected between consecutive updates of the target network; replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the replay memory object to use; if None, a default replay memory is created; initial_replay_size (int) – the number of samples to collect before learning starts; max_replay_size (int) – the maximum number of samples in the replay memory; fit_params (dict, None) – parameters of the fitting algorithm of the approximator; n_approximators (int, 1) – the number of approximators to use in AveragedDQN; clip_reward (bool, True) – whether to clip the reward or not.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
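Averaged-DQN computes the bootstrap value from the average of the K most recently learned Q-networks rather than from a single target network (K plays the role of `n_approximators`). A minimal sketch, assuming each stored network is a callable returning a list of action values and that the snapshots are kept in a `collections.deque(maxlen=K)` so that each target update drops the oldest one; this storage choice and the helper `averaged_dqn_targets` are assumptions, not mushroom's implementation.

```python
from collections import deque


def averaged_dqn_targets(batch, target_qs, gamma=0.99):
    """Averaged-DQN targets: max over actions of the mean of K target networks."""
    targets = []
    for _, _, r, s_next, absorbing in batch:
        if absorbing:
            q_next = 0.0
        else:
            per_net = [q(s_next) for q in target_qs]
            n_actions = len(per_net[0])
            mean_q = [sum(v[a] for v in per_net) / len(per_net)
                      for a in range(n_actions)]
            q_next = max(mean_q)
        targets.append(r + gamma * q_next)
    return targets


# Usage: keep the last K = 2 snapshots; appending a third would evict the first.
snapshots = deque(maxlen=2)
snapshots.append(lambda s: [1.0, 4.0])
snapshots.append(lambda s: [3.0, 0.0])
averaged_dqn_targets([(0, 0, 0.0, 1, False)], snapshots, gamma=1.0)
```

Averaging before taking the max lowers the target here to 2.0, whereas either network alone would give 4.0 or 3.0, which illustrates the variance reduction the method aims for.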
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.

_update_target()

Update the target network.

_next_q(next_state, absorbing)
Parameters: next_state (np.ndarray) – the states where the next action has to be evaluated; absorbing (np.ndarray) – the absorbing flag for the states in next_state. Returns: the maximum action-value for each state in next_state.

class mushroom.algorithms.value.dqn.CategoricalDQN(policy, mdp_info, n_atoms, v_min, v_max, approximator_params, **params)

Bases: mushroom.algorithms.value.dqn.dqn.DQN

Categorical DQN algorithm. “A Distributional Perspective on Reinforcement Learning”. Bellemare M. et al., 2017.

__init__(policy, mdp_info, n_atoms, v_min, v_max, approximator_params, **params)

Constructor.

Parameters: n_atoms (int) – number of atoms; v_min (float) – minimum value of the value function, i.e. the lowest atom of the support; v_max (float) – maximum value of the value function, i.e. the highest atom of the support.
_next_q(next_state, absorbing)
Parameters: next_state (np.ndarray) – the states where the next action has to be evaluated; absorbing (np.ndarray) – the absorbing flag for the states in next_state. Returns: the maximum action-value for each state in next_state.
_update_target()

Update the target network.

draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters: state (np.ndarray) – the state where the agent is. Returns: the action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters: dataset (list) – the dataset.
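The central step of the categorical fit is projecting the shifted and scaled distribution r + gamma * Z(s', a*) back onto the fixed support of `n_atoms` equally spaced atoms between `v_min` and `v_max`, splitting each atom's probability between its two neighbouring atoms. A minimal sketch of that projection for one transition, following the procedure in Bellemare et al., assuming `next_probs` already holds the atom probabilities of the selected next action a*; `categorical_projection` is a hypothetical helper.

```python
import math


def categorical_projection(next_probs, reward, absorbing, v_min, v_max,
                           gamma=0.99):
    """Project the distribution of r + gamma * Z(s', a*) onto the fixed support."""
    n_atoms = len(next_probs)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    projected = [0.0] * n_atoms
    for j, p in enumerate(next_probs):
        z_j = v_min + j * delta_z
        # Bellman-update the atom (no bootstrap if absorbing), clip to the support.
        tz = reward if absorbing else reward + gamma * z_j
        tz = min(max(tz, v_min), v_max)
        # Fractional atom index, clamped to guard against rounding at the edges.
        b = min(max((tz - v_min) / delta_z, 0.0), n_atoms - 1)
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            projected[lower] += p
        else:
            # Split the mass between the two neighbouring atoms.
            projected[lower] += p * (upper - b)
            projected[upper] += p * (b - lower)
    return projected
```

The projected probabilities still sum to one, so they can serve directly as the cross-entropy target for the predicted distribution of (s, a).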
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.