Value-Based

TD

class mushroom_rl.algorithms.value.td.SARSA(mdp_info, policy, learning_rate)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

SARSA algorithm.

__init__(mdp_info, policy, learning_rate)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • learning_rate (Parameter) – the learning rate.
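
A minimal usage sketch for SARSA on a finite MDP (a hedged example; GridWorld, EpsGreedy, Parameter and Core are assumed to be importable from their usual MushroomRL locations, and the constructor arguments shown are illustrative):

    from mushroom_rl.algorithms.value import SARSA
    from mushroom_rl.core import Core
    from mushroom_rl.environments import GridWorld
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.utils.parameters import Parameter

    # Small grid world; the environment arguments are assumptions for this sketch.
    env = GridWorld(width=3, height=3, goal=(2, 2), start=(0, 0))
    policy = EpsGreedy(epsilon=Parameter(value=0.1))
    agent = SARSA(env.info, policy, learning_rate=Parameter(value=0.1))

    core = Core(agent, env)
    core.learn(n_steps=10000, n_steps_per_fit=1)  # SARSA updates after every step
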
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.SARSALambda(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

The SARSA(lambda) algorithm for finite MDPs.

__init__(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]

Constructor.

Parameters:
  • lambda_coeff (float) – eligibility trace coefficient;
  • trace (str, 'replacing') – type of eligibility trace to use.
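
Reusing the setup from the SARSA example above, the eligibility-trace variant only adds the two extra arguments (a sketch; 'accumulating' is assumed to be the other supported trace type):

    # lambda_coeff in [0, 1]; trace='replacing' or, presumably, 'accumulating'
    agent = SARSALambda(env.info, policy, learning_rate=Parameter(value=0.1),
                        lambda_coeff=0.9, trace='replacing')
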
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
episode_start()[source]

Called by the agent when a new episode starts.

_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.ExpectedSARSA(mdp_info, policy, learning_rate)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Expected SARSA algorithm. “A theoretical and empirical analysis of Expected Sarsa”. Seijen H. V. et al., 2009.

__init__(mdp_info, policy, learning_rate)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • learning_rate (Parameter) – the learning rate.
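
For reference, Expected SARSA replaces the sampled next action with the expectation of the Q-values under the current policy, so the tabular update performed by _update has the standard form

    Q(s, a) \leftarrow Q(s, a) + \alpha \Big[ r + \gamma \sum_{a'} \pi(a' \mid s')\, Q(s', a') - Q(s, a) \Big]

where the bootstrap term is dropped when next_state is absorbing.
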
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.QLearning(mdp_info, policy, learning_rate)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Q-Learning algorithm. “Learning from Delayed Rewards”. Watkins C.J.C.H., 1989.

__init__(mdp_info, policy, learning_rate)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • learning_rate (Parameter) – the learning rate.
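
For reference, the tabular update performed by _update is the standard Q-Learning rule

    Q(s, a) \leftarrow Q(s, a) + \alpha \Big[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \Big]

with the bootstrap term dropped when next_state is absorbing.
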
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.QLambda(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Q(Lambda) algorithm. “Learning from Delayed Rewards”. Watkins C.J.C.H., 1989.

__init__(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]

Constructor.

Parameters:
  • lambda_coeff (float) – eligibility trace coefficient;
  • trace (str, 'replacing') – type of eligibility trace to use.
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
episode_start()[source]

Called by the agent when a new episode starts.

_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.DoubleQLearning(mdp_info, policy, learning_rate)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Double Q-Learning algorithm. “Double Q-Learning”. Hasselt H. V., 2010.

__init__(mdp_info, policy, learning_rate)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • learning_rate (Parameter) – the learning rate.
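
For reference, Double Q-Learning maintains two tables Q^A and Q^B and, at each step, updates one of them (chosen at random) while using the other to evaluate the greedy action, e.g. for Q^A:

    Q^A(s, a) \leftarrow Q^A(s, a) + \alpha \Big[ r + \gamma\, Q^B\big(s', \arg\max_{a'} Q^A(s', a')\big) - Q^A(s, a) \Big]
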
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.SpeedyQLearning(mdp_info, policy, learning_rate)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Speedy Q-Learning algorithm. “Speedy Q-Learning”. Ghavamzadeh et al., 2011.

__init__(mdp_info, policy, learning_rate)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • learning_rate (Parameter) – the learning rate.
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.RLearning(mdp_info, policy, learning_rate, beta)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

R-Learning algorithm. “A Reinforcement Learning Method for Maximizing Undiscounted Rewards”. Schwartz A., 1993.

__init__(mdp_info, policy, learning_rate, beta)[source]

Constructor.

Parameters:beta (Parameter) – beta coefficient.
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.WeightedQLearning(mdp_info, policy, learning_rate, sampling=True, precision=1000)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Weighted Q-Learning algorithm. “Estimating the Maximum Expected Value through Gaussian Approximation”. D’Eramo C. et al., 2016.

__init__(mdp_info, policy, learning_rate, sampling=True, precision=1000)[source]

Constructor.

Parameters:
  • sampling (bool, True) – use the approximated version to speed up the computation;
  • precision (int, 1000) – number of samples to use in the approximated version.
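
A construction sketch, reusing the finite-MDP setup from the SARSA example above (argument values are illustrative):

    # sampling=True (default) estimates the weighted target from `precision`
    # samples; sampling=False computes it without the sampled approximation.
    agent = WeightedQLearning(env.info, policy, learning_rate=Parameter(value=0.1),
                              sampling=True, precision=1000)
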
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
_next_q(next_state)[source]
Parameters:next_state (np.ndarray) – the state where next action has to be evaluated.
Returns:The weighted estimator value in next_state.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.RQLearning(mdp_info, policy, learning_rate, off_policy=False, beta=None, delta=None)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

RQ-Learning algorithm. “Exploiting Structure and Uncertainty of Bellman Updates in Markov Decision Processes”. Tateo D. et al., 2017.

__init__(mdp_info, policy, learning_rate, off_policy=False, beta=None, delta=None)[source]

Constructor.

Parameters:
  • off_policy (bool, False) – whether to use the off-policy setting or the online one;
  • beta (Parameter, None) – beta coefficient;
  • delta (Parameter, None) – delta coefficient.
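
A construction sketch, again reusing the finite-MDP setup above; following the parameter descriptions, either beta or delta is assumed to be provided, not both:

    # online variant driven by a delta coefficient (values are illustrative)
    agent = RQLearning(env.info, policy, learning_rate=Parameter(value=0.1),
                       off_policy=False, delta=Parameter(value=0.1))
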
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

_next_q(next_state)[source]
Parameters:next_state (np.ndarray) – the state where next action has to be evaluated.
Returns:The weighted estimator value in ‘next_state’.
class mushroom_rl.algorithms.value.td.SARSALambdaContinuous(mdp_info, policy, approximator, learning_rate, lambda_coeff, features, approximator_params=None)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

Continuous version of the SARSA(lambda) algorithm.

__init__(mdp_info, policy, approximator, learning_rate, lambda_coeff, features, approximator_params=None)[source]

Constructor.

Parameters:lambda_coeff (float) – eligibility trace coefficient.
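
A usage sketch with a linear Q-function over tile-coded features (a hedged example; Tiles, Features and LinearApproximator are assumed to live at their usual MushroomRL paths, and the shapes below are assumptions for a generic continuous-state environment env):

    from mushroom_rl.approximators.parametric import LinearApproximator
    from mushroom_rl.features import Features
    from mushroom_rl.features.tiles import Tiles
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.utils.parameters import Parameter

    # env is a continuous-state MushroomRL environment created elsewhere
    tilings = Tiles.generate(10, [10, 10],
                             env.info.observation_space.low,
                             env.info.observation_space.high)
    features = Features(tilings=tilings)

    approximator_params = dict(input_shape=(features.size,),
                               output_shape=(env.info.action_space.n,),
                               n_actions=env.info.action_space.n)

    agent = SARSALambdaContinuous(env.info, EpsGreedy(epsilon=Parameter(value=0.1)),
                                  LinearApproximator, Parameter(value=0.1),
                                  lambda_coeff=0.9, features=features,
                                  approximator_params=approximator_params)
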
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
episode_start()[source]

Called by the agent when a new episode starts.

_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.td.TrueOnlineSARSALambda(mdp_info, policy, learning_rate, lambda_coeff, features, approximator_params=None)[source]

Bases: mushroom_rl.algorithms.value.td.td.TD

True Online SARSA(lambda) with linear function approximation. “True Online TD(lambda)”. Seijen H. V. et al., 2014.

__init__(mdp_info, policy, learning_rate, lambda_coeff, features, approximator_params=None)[source]

Constructor.

Parameters:lambda_coeff (float) – eligibility trace coefficient.
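
The constructor mirrors SARSALambdaContinuous except that the approximator is implicitly linear, so only the features and their parameters are passed (a sketch under the same assumptions as the tile-coding example above):

    agent = TrueOnlineSARSALambda(env.info, EpsGreedy(epsilon=Parameter(value=0.1)),
                                  Parameter(value=0.1), lambda_coeff=0.9,
                                  features=features,
                                  approximator_params=approximator_params)
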
_update(state, action, reward, next_state, absorbing)[source]

Update the Q-table.

Parameters:
  • state (np.ndarray) – state;
  • action (np.ndarray) – action;
  • reward (np.ndarray) – reward;
  • next_state (np.ndarray) – next state;
  • absorbing (np.ndarray) – absorbing flag.
episode_start()[source]

Called by the agent when a new episode starts.

_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
static _parse(dataset)

Utility to parse the dataset that is supposed to contain only a sample.

Parameters:dataset (list) – the current episode step.
Returns:A tuple containing state, action, reward, next state, absorbing and last flag.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

Batch TD

class mushroom_rl.algorithms.value.batch_td.FQI(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False, boosted=False)[source]

Bases: mushroom_rl.algorithms.value.batch_td.batch_td.BatchTD

Fitted Q-Iteration algorithm. “Tree-Based Batch Mode Reinforcement Learning”. Ernst D. et al., 2005.

__init__(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False, boosted=False)[source]

Constructor.

Parameters:
  • n_iterations (int) – number of iterations to perform for training;
  • quiet (bool, False) – whether to suppress the progress bar or not;
  • boosted (bool, False) – whether to use boosted FQI or not.
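
A batch-mode usage sketch with an extra-trees regressor as Q-function approximator (a hedged example; the scikit-learn regressor and the approximator_params keys below are assumptions in the style of the usual MushroomRL regressor interface):

    from sklearn.ensemble import ExtraTreesRegressor
    from mushroom_rl.algorithms.value import FQI
    from mushroom_rl.core import Core
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.utils.parameters import Parameter

    # env is a MushroomRL environment created elsewhere
    approximator_params = dict(input_shape=env.info.observation_space.shape,
                               n_actions=env.info.action_space.n,
                               n_estimators=50, min_samples_split=5)

    agent = FQI(env.info, EpsGreedy(epsilon=Parameter(value=1.)),
                ExtraTreesRegressor, n_iterations=20,
                approximator_params=approximator_params)

    core = Core(agent, env)
    # collect a batch of episodes, then run the n_iterations FQI loop once
    core.learn(n_episodes=100, n_episodes_per_fit=100)
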
fit(dataset)[source]

Fit loop.

_fit(x)[source]

Single fit iteration.

Parameters:x (list) – the dataset.
_fit_boosted(x)[source]

Single fit iteration for boosted FQI.

Parameters:x (list) – the dataset.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.batch_td.DoubleFQI(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False)[source]

Bases: mushroom_rl.algorithms.value.batch_td.fqi.FQI

Double Fitted Q-Iteration algorithm. “Estimating the Maximum Expected Value in Continuous Reinforcement Learning Problems”. D’Eramo C. et al., 2017.

__init__(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False)[source]

Constructor.

Parameters:
  • n_iterations (int) – number of iterations to perform for training;
  • quiet (bool, False) – whether to suppress the progress bar or not;
  • boosted (bool, False) – whether to use boosted FQI or not.
_fit(x)[source]

Single fit iteration.

Parameters:x (list) – the dataset.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
_fit_boosted(x)

Single fit iteration for boosted FQI.

Parameters:x (list) – the dataset.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit loop.

classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.batch_td.LSPI(mdp_info, policy, approximator_params=None, epsilon=0.01, fit_params=None, features=None)[source]

Bases: mushroom_rl.algorithms.value.batch_td.batch_td.BatchTD

Least-Squares Policy Iteration algorithm. “Least-Squares Policy Iteration”. Lagoudakis M. G. and Parr R., 2003.

__init__(mdp_info, policy, approximator_params=None, epsilon=0.01, fit_params=None, features=None)[source]

Constructor.

Parameters:epsilon (float, 1e-2) – termination coefficient.
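
A construction sketch (illustrative; the features object is built as in the tile-coding example above, and the approximator_params keys are assumptions):

    agent = LSPI(env.info, EpsGreedy(epsilon=Parameter(value=1.)),
                 approximator_params=dict(input_shape=(features.size,),
                                          output_shape=(env.info.action_space.n,),
                                          n_actions=env.info.action_space.n),
                 epsilon=0.01, features=features)
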
fit(dataset)[source]

Fit step.

Parameters:dataset (list) – the dataset.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

DQN

class mushroom_rl.algorithms.value.dqn.DQN(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)[source]

Bases: mushroom_rl.algorithms.agent.Agent

Deep Q-Network algorithm. “Human-Level Control Through Deep Reinforcement Learning”. Mnih V. et al., 2015.

__init__(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • approximator_params (dict) – parameters of the approximator to build;
  • batch_size (int) – the number of samples in a batch;
  • target_update_frequency (int) – the number of samples collected between each update of the target network;
  • replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the replay memory object to use; if None, a default replay memory is created;
  • initial_replay_size (int) – the number of samples to collect before starting the learning;
  • max_replay_size (int) – the maximum number of samples in the replay memory;
  • fit_params (dict, None) – parameters of the fitting algorithm of the approximator;
  • n_approximators (int, 1) – the number of approximators to use in AveragedDQN;
  • clip_reward (bool, True) – whether to clip the reward or not.
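
A usage sketch with a PyTorch Q-network (a hedged example; TorchApproximator is assumed to be available under mushroom_rl.approximators.parametric, MyQNetwork is a hypothetical user-defined torch.nn.Module following the MushroomRL (input_shape, output_shape, **kwargs) convention, and all sizes are placeholders):

    import torch.nn.functional as F
    import torch.optim as optim
    from mushroom_rl.algorithms.value import DQN
    from mushroom_rl.approximators.parametric import TorchApproximator

    approximator_params = dict(network=MyQNetwork,  # hypothetical network class
                               input_shape=env.info.observation_space.shape,
                               output_shape=(env.info.action_space.n,),
                               n_actions=env.info.action_space.n,
                               optimizer={'class': optim.Adam,
                                          'params': {'lr': 1e-4}},
                               loss=F.smooth_l1_loss)

    # policy is e.g. an EpsGreedy with a decaying epsilon, created elsewhere
    agent = DQN(env.info, policy, TorchApproximator,
                approximator_params=approximator_params, batch_size=32,
                target_update_frequency=250, initial_replay_size=500,
                max_replay_size=5000)
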
fit(dataset)[source]

Fit step.

Parameters:dataset (list) – the dataset.
_update_target()[source]

Update the target network.

_next_q(next_state, absorbing)[source]
Parameters:
  • next_state (np.ndarray) – the states where next action has to be evaluated;
  • absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns:Maximum action-value for each state in next_state.

draw_action(state)[source]

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
_post_load()[source]

This method can be overwritten to implement logic that is executed after the loading of the agent.

_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
copy()
Returns:A deepcopy of the agent.
episode_start()

Called by the agent when a new episode starts.

classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.dqn.DoubleDQN(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)[source]

Bases: mushroom_rl.algorithms.value.dqn.dqn.DQN

Double DQN algorithm. “Deep Reinforcement Learning with Double Q-Learning”. Hasselt H. V. et al., 2016.

_next_q(next_state, absorbing)[source]
Parameters:
  • next_state (np.ndarray) – the states where next action has to be evaluated;
  • absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns:Maximum action-value for each state in next_state.

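For reference, the difference from vanilla DQN lies in the target computed through _next_q: the greedy action is selected with the online network and evaluated with the target network,

    y = r + \gamma\, Q_{\text{target}}\big(s', \arg\max_{a'} Q_{\text{online}}(s', a')\big)
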
__init__(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • approximator_params (dict) – parameters of the approximator to build;
  • batch_size (int) – the number of samples in a batch;
  • target_update_frequency (int) – the number of samples collected between each update of the target network;
  • replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the replay memory object to use; if None, a default replay memory is created;
  • initial_replay_size (int) – the number of samples to collect before starting the learning;
  • max_replay_size (int) – the maximum number of samples in the replay memory;
  • fit_params (dict, None) – parameters of the fitting algorithm of the approximator;
  • n_approximators (int, 1) – the number of approximators to use in AveragedDQN;
  • clip_reward (bool, True) – whether to clip the reward or not.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

_update_target()

Update the target network.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

class mushroom_rl.algorithms.value.dqn.AveragedDQN(mdp_info, policy, approximator, **params)[source]

Bases: mushroom_rl.algorithms.value.dqn.dqn.DQN

Averaged-DQN algorithm. “Averaged-DQN: Variance Reduction and Stabilization for Deep Reinforcement Learning”. Anschel O. et al., 2017.

__init__(mdp_info, policy, approximator, **params)[source]

Constructor.

Parameters:
  • approximator (object) – the approximator to use to fit the Q-function;
  • approximator_params (dict) – parameters of the approximator to build;
  • batch_size (int) – the number of samples in a batch;
  • target_update_frequency (int) – the number of samples collected between each update of the target network;
  • replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the replay memory object to use; if None, a default replay memory is created;
  • initial_replay_size (int) – the number of samples to collect before starting the learning;
  • max_replay_size (int) – the maximum number of samples in the replay memory;
  • fit_params (dict, None) – parameters of the fitting algorithm of the approximator;
  • n_approximators (int, 1) – the number of approximators to use in AveragedDQN;
  • clip_reward (bool, True) – whether to clip the reward or not.
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.

_update_target()[source]

Update the target network.

_next_q(next_state, absorbing)[source]
Parameters:
  • next_state (np.ndarray) – the states where next action has to be evaluated;
  • absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns:Maximum action-value for each state in next_state.

class mushroom_rl.algorithms.value.dqn.CategoricalDQN(mdp_info, policy, approximator_params, n_atoms, v_min, v_max, **params)[source]

Bases: mushroom_rl.algorithms.value.dqn.dqn.DQN

Categorical DQN algorithm. “A Distributional Perspective on Reinforcement Learning”. Bellemare M. et al., 2017.

__init__(mdp_info, policy, approximator_params, n_atoms, v_min, v_max, **params)[source]

Constructor.

Parameters:
  • n_atoms (int) – number of atoms;
  • v_min (float) – minimum value of the value function;
  • v_max (float) – maximum value of the value function.
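
For reference, n_atoms, v_min and v_max define the fixed support on which the return distribution is represented, i.e. the atoms

    z_i = v_{\min} + i\, \frac{v_{\max} - v_{\min}}{N - 1}, \qquad i = 0, \dots, N - 1, \quad N = \texttt{n\_atoms}
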
_add_save_attr(**attr_dict)

Add attributes that should be saved for an agent.

Parameters:**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
_next_q(next_state, absorbing)
Parameters:
  • next_state (np.ndarray) – the states where next action has to be evaluated;
  • absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns:Maximum action-value for each state in next_state.

_post_load()

This method can be overwritten to implement logic that is executed after the loading of the agent.

_update_target()

Update the target network.

copy()
Returns:A deepcopy of the agent.
draw_action(state)

Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).

Parameters:state (np.ndarray) – the state where the agent is.
Returns:The action to be executed.
episode_start()

Called by the agent when a new episode starts.

fit(dataset)[source]

Fit step.

Parameters:dataset (list) – the dataset.
classmethod load(path)

Load and deserialize the agent from the given location on disk.

Parameters:path (Path, string) – Relative or absolute path to the agent's save location.
Returns:The loaded agent.
save(path, full_save=False)

Serialize and save the object to the given path on disk.

Parameters:
  • path (Path, str) – Relative or absolute path to the object save location;
  • full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
save_zip(zip_file, full_save, folder='')

Serialize and save the agent to the given path on disk.

Parameters:
  • zip_file (ZipFile) – ZipFile where the object needs to be saved;
  • full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
  • folder (string, '') – subfolder to be used by the save method.
stop()

Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate run, in order to enforce consistency.