Value-Based¶
TD¶
-
class
mushroom_rl.algorithms.value.td.
SARSA
(mdp_info, policy, learning_rate)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
SARSA algorithm.
-
__init__
(mdp_info, policy, learning_rate)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- learning_rate (Parameter) – the learning rate.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
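As a rough illustration of what a tabular SARSA update does, here is a minimal sketch in plain Python. It is not MushroomRL's code: the function name `sarsa_update`, the list-of-lists Q-table, and the explicit `a_next`, `alpha`, and `gamma` arguments are assumptions made for this example.

```python
def sarsa_update(Q, s, a, r, s_next, a_next, absorbing, alpha, gamma):
    """One tabular SARSA step on Q (a list of per-state action-value lists)."""
    # The bootstrap term is dropped when the next state is absorbing.
    q_next = 0.0 if absorbing else Q[s_next][a_next]
    # Move Q(s, a) towards the on-policy target r + gamma * Q(s', a').
    Q[s][a] += alpha * (r + gamma * q_next - Q[s][a])
    return Q[s][a]


Q = [[0.0, 0.0], [0.0, 1.0]]
new_value = sarsa_update(Q, s=0, a=1, r=1.0, s_next=1, a_next=1,
                         absorbing=False, alpha=0.5, gamma=0.9)
```

The next action is part of the target, which is why the description of draw_action mentions that SARSA can set the action itself.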
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
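The "!" convention described above can be pictured with the following sketch. The helper `fields_to_save`, the attribute names, and the method string `'pickle'` are hypothetical and chosen only for illustration; the library's own bookkeeping may differ.

```python
def fields_to_save(attr_dict, full_save):
    """Return {attribute: method} for the fields that would be written."""
    selected = {}
    for name, method in attr_dict.items():
        if method.endswith('!'):
            # Fields marked with "!" are kept only when full_save is True.
            if full_save:
                selected[name] = method[:-1]
        else:
            selected[name] = method
    return selected


# Hypothetical attributes: a Q-table that is always saved and a buffer
# that is saved only on a full save.
attrs = {'q_table': 'pickle', 'replay_buffer': 'pickle!'}
light = fields_to_save(attrs, full_save=False)
full = fields_to_save(attrs, full_save=True)
```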
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
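A minimal sketch of such a parsing step is shown below. It assumes that each sample is a 6-tuple laid out in the order given by the return description; the function name `parse` and the sample values are made up for the example.

```python
def parse(dataset):
    """Unpack a dataset that is expected to hold exactly one transition."""
    if len(dataset) != 1:
        raise ValueError("expected a dataset with a single sample")
    state, action, reward, next_state, absorbing, last = dataset[0]
    return state, action, reward, next_state, absorbing, last


# Assumed sample layout: (state, action, reward, next_state, absorbing, last).
sample = ([0], [1], 1.0, [2], False, True)
state, action, reward, next_state, absorbing, last = parse([sample])
```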
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
SARSALambda
(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
The SARSA(lambda) algorithm for finite MDPs.
-
__init__
(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]¶ Constructor.
Parameters: - lambda_coeff (float) – eligibility trace coefficient;
- trace (str, 'replacing') – type of eligibility trace to use.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
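To make the role of `lambda_coeff` and `trace` concrete, here is a hedged sketch of a tabular SARSA(lambda) step with eligibility traces. The function name, the explicit `alpha` and `gamma` arguments, and the `'accumulating'` option are assumptions for this example; only `'replacing'` appears in the signature above.

```python
def sarsa_lambda_update(Q, e, s, a, r, s_next, a_next, absorbing,
                        alpha, gamma, lam, trace='replacing'):
    """One tabular SARSA(lambda) step; Q and e are modified in place."""
    q_next = 0.0 if absorbing else Q[s_next][a_next]
    delta = r + gamma * q_next - Q[s][a]
    # Mark the visited pair as eligible.
    if trace == 'replacing':
        e[s][a] = 1.0
    elif trace == 'accumulating':
        e[s][a] += 1.0
    else:
        raise ValueError("unknown trace type: " + trace)
    # Every pair is updated in proportion to its trace, then traces decay.
    for i in range(len(Q)):
        for j in range(len(Q[i])):
            Q[i][j] += alpha * delta * e[i][j]
            e[i][j] *= gamma * lam


Q = [[0.0, 0.0], [0.0, 0.0]]
e = [[0.0, 0.0], [0.0, 0.0]]
# Step 1: no reward yet, so only the trace of (0, 0) changes.
sarsa_lambda_update(Q, e, 0, 0, 0.0, 1, 0, False, alpha=0.5, gamma=0.9, lam=0.5)
# Step 2: the reward also flows back to (0, 0) through its decayed trace.
sarsa_lambda_update(Q, e, 1, 0, 1.0, 0, 0, True, alpha=0.5, gamma=0.9, lam=0.5)
```

In an episodic setting the traces would normally be reset to zero when a new episode starts.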
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
ExpectedSARSA
(mdp_info, policy, learning_rate)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Expected SARSA algorithm. “A theoretical and empirical analysis of Expected Sarsa”. Seijen H. V. et al. 2009.
-
__init__
(mdp_info, policy, learning_rate)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- learning_rate (Parameter) – the learning rate.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
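The difference from SARSA is that the target averages the next-state values under the policy instead of using one sampled next action. The sketch below assumes an epsilon-greedy policy; the function names and the `epsilon`, `alpha`, and `gamma` arguments are illustrative and not part of the documented interface.

```python
def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities of an epsilon-greedy policy (ties split evenly)."""
    n = len(q_values)
    best = max(q_values)
    greedy = [i for i, q in enumerate(q_values) if q == best]
    probs = [epsilon / n] * n
    for i in greedy:
        probs[i] += (1.0 - epsilon) / len(greedy)
    return probs


def expected_sarsa_update(Q, s, a, r, s_next, absorbing, alpha, gamma, epsilon):
    """One tabular Expected SARSA step; Q is modified in place."""
    if absorbing:
        expected_q = 0.0
    else:
        probs = epsilon_greedy_probs(Q[s_next], epsilon)
        # Expectation of Q(s', .) under the policy's action distribution.
        expected_q = sum(p * q for p, q in zip(probs, Q[s_next]))
    Q[s][a] += alpha * (r + gamma * expected_q - Q[s][a])
    return Q[s][a]


Q = [[0.0, 0.0], [1.0, 3.0]]
value = expected_sarsa_update(Q, 0, 0, 1.0, 1, False,
                              alpha=0.5, gamma=1.0, epsilon=0.5)
```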
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
QLearning
(mdp_info, policy, learning_rate)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Q-Learning algorithm. “Learning from Delayed Rewards”. Watkins C.J.C.H. 1989.
-
__init__
(mdp_info, policy, learning_rate)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- learning_rate (Parameter) – the learning rate.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
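For reference, the classic tabular Q-Learning update can be sketched as follows. This is a generic textbook version written for this page, with an assumed list-of-lists Q-table and explicit `alpha` and `gamma` arguments, rather than the library's implementation.

```python
def q_learning_update(Q, s, a, r, s_next, absorbing, alpha, gamma):
    """One tabular Q-Learning step; Q is modified in place."""
    # Off-policy target: bootstrap from the best action in the next state.
    q_next = 0.0 if absorbing else max(Q[s_next])
    Q[s][a] += alpha * (r + gamma * q_next - Q[s][a])
    return Q[s][a]


Q = [[0.0, 0.0], [2.0, 5.0]]
value = q_learning_update(Q, s=0, a=1, r=1.0, s_next=1,
                          absorbing=False, alpha=0.1, gamma=0.9)
```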
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
QLambda
(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Q(lambda) algorithm. “Learning from Delayed Rewards”. Watkins C.J.C.H. 1989.
-
__init__
(mdp_info, policy, learning_rate, lambda_coeff, trace='replacing')[source]¶ Constructor.
Parameters: - lambda_coeff (float) – eligibility trace coefficient;
- trace (str, 'replacing') – type of eligibility trace to use.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
DoubleQLearning
(mdp_info, policy, learning_rate)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Double Q-Learning algorithm. “Double Q-Learning”. Hasselt H. V. 2010.
-
__init__
(mdp_info, policy, learning_rate)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- learning_rate (Parameter) – the learning rate.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
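The idea behind Double Q-Learning is to decouple action selection from action evaluation by keeping two tables. The sketch below follows the description in the cited paper; the function name, the two list-of-lists tables, and the injectable `coin` argument (used here only to make the example reproducible) are assumptions.

```python
import random


def double_q_update(Q_a, Q_b, s, a, r, s_next, absorbing, alpha, gamma,
                    coin=random.random):
    """One Double Q-Learning step; `coin` returns a float in [0, 1)."""
    # Choose at random which table is updated; the other one evaluates.
    if coin() < 0.5:
        update, evaluate = Q_a, Q_b
    else:
        update, evaluate = Q_b, Q_a
    if absorbing:
        q_next = 0.0
    else:
        # Greedy action under the updated table, valued by the other table.
        row = update[s_next]
        best = row.index(max(row))
        q_next = evaluate[s_next][best]
    update[s][a] += alpha * (r + gamma * q_next - update[s][a])


Q_a = [[0.0, 0.0], [1.0, 2.0]]
Q_b = [[0.0, 0.0], [4.0, 3.0]]
# A fixed coin makes the demo reproducible: Q_a is the table being updated.
double_q_update(Q_a, Q_b, s=0, a=0, r=0.0, s_next=1, absorbing=False,
                alpha=0.5, gamma=1.0, coin=lambda: 0.0)
```

Note that the target uses `Q_b[1][1]` (3.0) rather than the maximum of `Q_b[1]` (4.0), which is how the estimator avoids the overestimation of a single max.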
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
SpeedyQLearning
(mdp_info, policy, learning_rate)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Speedy Q-Learning algorithm. “Speedy Q-Learning”. Ghavamzadeh et al. 2011.
-
__init__
(mdp_info, policy, learning_rate)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- learning_rate (Parameter) – the learning rate.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
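Speedy Q-Learning combines the Bellman targets computed from the current and the previous Q-table. The following is a sketch of the asynchronous update as described in the cited paper; the function name, the returned snapshot, and the explicit `alpha` and `gamma` arguments are assumptions for this example.

```python
import copy


def speedy_q_update(Q, Q_old, s, a, r, s_next, absorbing, alpha, gamma):
    """One Speedy Q-Learning step; returns the table to use as Q_old next."""
    snapshot = copy.deepcopy(Q)  # current table, before it is changed
    max_cur = 0.0 if absorbing else max(Q[s_next])
    max_old = 0.0 if absorbing else max(Q_old[s_next])
    target_cur = r + gamma * max_cur   # empirical Bellman target on Q
    target_old = r + gamma * max_old   # empirical Bellman target on Q_old
    q = Q[s][a]
    # A conservative step towards the old target plus a more aggressive
    # correction given by the difference between the two targets.
    Q[s][a] = q + alpha * (target_old - q) \
        + (1.0 - alpha) * (target_cur - target_old)
    return snapshot


Q = [[0.0, 0.0], [2.0, 0.0]]
Q_old = [[0.0, 0.0], [1.0, 0.0]]
Q_old = speedy_q_update(Q, Q_old, s=0, a=0, r=1.0, s_next=1,
                        absorbing=False, alpha=0.5, gamma=1.0)
```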
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
RLearning
(mdp_info, policy, learning_rate, beta)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
R-Learning algorithm. “A Reinforcement Learning Method for Maximizing Undiscounted Rewards”. Schwartz A. 1993.
-
__init__
(mdp_info, policy, learning_rate, beta)[source]¶ Constructor.
Parameters: beta (Parameter) – beta coefficient.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
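R-Learning targets the average reward instead of a discounted return, so it tracks an estimate `rho` of the average reward alongside the Q-table. The sketch below is one common formulation, written under the assumption that `beta` is the step size for `rho`; the function name and the way `rho` is passed in and returned are choices made for this example.

```python
def r_learning_update(Q, rho, s, a, r, s_next, absorbing, alpha, beta):
    """One R-Learning step; Q is modified in place, the new rho is returned."""
    q_next = 0.0 if absorbing else max(Q[s_next])
    # Undiscounted update relative to the average-reward estimate rho.
    Q[s][a] += alpha * (r - rho + q_next - Q[s][a])
    q_max = max(Q[s])
    # rho is adjusted only when the action taken is greedy after the update.
    if Q[s][a] == q_max:
        rho += beta * (r + q_next - q_max - rho)
    return rho


Q = [[0.0, 0.0], [1.0, 0.0]]
rho = r_learning_update(Q, 0.0, s=0, a=0, r=2.0, s_next=1,
                        absorbing=False, alpha=0.5, beta=0.1)
```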
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
WeightedQLearning
(mdp_info, policy, learning_rate, sampling=True, precision=1000)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Weighted Q-Learning algorithm. “Estimating the Maximum Expected Value through Gaussian Approximation”. D’Eramo C. et al. 2016.
-
__init__
(mdp_info, policy, learning_rate, sampling=True, precision=1000)[source]¶ Constructor.
Parameters: - sampling (bool, True) – use the approximated version to speed up the computation;
- precision (int, 1000) – number of samples to use in the approximated version.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
-
_next_q
(next_state)[source]¶ Parameters: next_state (np.ndarray) – the state where the next action has to be evaluated. Returns: The weighted estimator value in next_state.
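One way to picture the sampling-based variant of the weighted estimator is the sketch below: each action value is weighted by the estimated probability that it is the maximum, with that probability approximated from `precision` Gaussian draws. The function name, the `means` and `sigmas` inputs (assumed to be supplied by the caller), and the `rng` argument are assumptions for this example.

```python
import random


def weighted_estimate(means, sigmas, precision=1000, rng=None):
    """Approximate the weighted estimator by sampling.

    Returns the estimate and the per-action weights.
    """
    rng = rng or random.Random()
    counts = [0] * len(means)
    for _ in range(precision):
        # Draw one value per action and record which action came out on top.
        draws = [rng.gauss(m, s) for m, s in zip(means, sigmas)]
        counts[draws.index(max(draws))] += 1
    weights = [c / precision for c in counts]
    return sum(w * m for w, m in zip(weights, means)), weights


# Two well separated actions: nearly all the weight goes to the second one.
value, weights = weighted_estimate([0.0, 10.0], [0.1, 0.1],
                                   precision=1000, rng=random.Random(0))
```

When the action values are close to each other the weights spread out, and the estimate falls between the sample means instead of jumping to the largest one.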
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
RQLearning
(mdp_info, policy, learning_rate, off_policy=False, beta=None, delta=None)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
RQ-Learning algorithm. “Exploiting Structure and Uncertainty of Bellman Updates in Markov Decision Processes”. Tateo D. et al. 2017.
-
__init__
(mdp_info, policy, learning_rate, off_policy=False, beta=None, delta=None)[source]¶ Constructor.
Parameters:
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
SARSALambdaContinuous
(mdp_info, policy, approximator, learning_rate, lambda_coeff, features, approximator_params=None)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
Continuous version of SARSA(lambda) algorithm.
-
__init__
(mdp_info, policy, approximator, learning_rate, lambda_coeff, features, approximator_params=None)[source]¶ Constructor.
Parameters: lambda_coeff (float) – eligibility trace coefficient.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.td.
TrueOnlineSARSALambda
(mdp_info, policy, learning_rate, lambda_coeff, features, approximator_params=None)[source]¶ Bases:
mushroom_rl.algorithms.value.td.td.TD
True Online SARSA(lambda) with linear function approximation. “True Online TD(lambda)”. Seijen H. V. et al. 2014.
-
__init__
(mdp_info, policy, learning_rate, lambda_coeff, features, approximator_params=None)[source]¶ Constructor.
Parameters: lambda_coeff (float) – eligibility trace coefficient.
-
_update
(state, action, reward, next_state, absorbing)[source]¶ Update the Q-table.
Parameters: - state (np.ndarray) – state;
- action (np.ndarray) – action;
- reward (np.ndarray) – reward;
- next_state (np.ndarray) – next state;
- absorbing (np.ndarray) – absorbing flag.
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
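The “!” convention described above can be pictured as a filter over the attribute dictionary. This is a sketch of the documented behaviour only, not the library's code; the function select_fields, the field names and the method names are all hypothetical.

```python
def select_fields(save_attributes, full_save):
    """Return {field: method} for the fields that would be written."""
    selected = {}
    for field, method in save_attributes.items():
        if method.endswith('!'):
            if not full_save:
                continue  # skipped unless a full save is requested
            method = method[:-1]  # drop the marker to get the method name
        selected[field] = method
    return selected


# Hypothetical attribute table: names and methods are illustrative.
attributes = {'weights': 'numpy', 'replay_buffer': 'pickle!'}
light = select_fields(attributes, full_save=False)
full = select_fields(attributes, full_save=True)
```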
-
static
_parse
(dataset)¶ Utility to parse the dataset that is supposed to contain only a sample.
Parameters: dataset (list) – the current episode step. Returns: A tuple containing state, action, reward, next state, absorbing and last flag.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
Batch TD¶
-
class
mushroom_rl.algorithms.value.batch_td.
FQI
(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False, boosted=False)[source]¶ Bases:
mushroom_rl.algorithms.value.batch_td.batch_td.BatchTD
Fitted Q-Iteration algorithm. “Tree-Based Batch Mode Reinforcement Learning”. Ernst D. et al. 2005.
-
__init__
(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False, boosted=False)[source]¶ Constructor.
Parameters: - n_iterations (int) – number of iterations to perform for training;
- quiet (bool, False) – whether to hide the progress bar;
- boosted (bool, False) – whether to use boosted FQI or not.
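The meaning of n_iterations can be seen in a bare-bones version of the fitted Q-iteration loop. This is a minimal sketch, assuming a tabular averaging "regressor" in place of a real approximator; fitted_q_iteration and the dataset layout are hypothetical and do not mirror the class above.

```python
def fitted_q_iteration(dataset, actions, gamma, n_iterations):
    """dataset: list of (state, action, reward, next_state, absorbing)."""
    q = {}  # tabular stand-in for the regressor; missing entries count as 0

    for _ in range(n_iterations):
        # Build regression targets from the previous iterate.
        targets = {}
        for s, a, r, s_next, absorbing in dataset:
            best_next = 0.0 if absorbing else max(
                q.get((s_next, b), 0.0) for b in actions)
            targets.setdefault((s, a), []).append(r + gamma * best_next)

        # "Fit": average the targets seen for each (state, action) pair.
        q = {key: sum(vals) / len(vals) for key, vals in targets.items()}
    return q


# Two-state chain: state 0 leads to state 1, which ends with reward 1.
data = [(0, 0, 0.0, 1, False), (1, 0, 1.0, 1, True)]
q = fitted_q_iteration(data, actions=[0], gamma=0.9, n_iterations=2)
```

Each iteration propagates reward information one step further back along the chain, which is why the number of iterations bounds the effective planning horizon in this toy setting.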
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.batch_td.
DoubleFQI
(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False)[source]¶ Bases:
mushroom_rl.algorithms.value.batch_td.fqi.FQI
Double Fitted Q-Iteration algorithm. “Estimating the Maximum Expected Value in Continuous Reinforcement Learning Problems”. D’Eramo C. et al. 2017.
-
__init__
(mdp_info, policy, approximator, n_iterations, approximator_params=None, fit_params=None, quiet=False)[source]¶ Constructor.
Parameters: - n_iterations (int) – number of iterations to perform for training;
- quiet (bool, False) – whether to hide the progress bar.
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
_fit_boosted
(x)¶ Single fit iteration for boosted FQI.
Parameters: x (list) – the dataset.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit loop.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.batch_td.
LSPI
(mdp_info, policy, approximator_params=None, epsilon=0.01, fit_params=None, features=None)[source]¶ Bases:
mushroom_rl.algorithms.value.batch_td.batch_td.BatchTD
Least-Squares Policy Iteration algorithm. “Least-Squares Policy Iteration”. Lagoudakis M. G. and Parr R. 2003.
-
__init__
(mdp_info, policy, approximator_params=None, epsilon=0.01, fit_params=None, features=None)[source]¶ Constructor.
Parameters: epsilon (float, 1e-2) – termination coefficient.
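In least-squares policy iteration, a termination coefficient of this kind is conventionally compared with the change in the weight vector between two successive policy-evaluation steps. The sketch below shows only that stopping rule, under that assumption; iterate_until_stable is hypothetical, and improve stands in for one least-squares solve.

```python
import math


def iterate_until_stable(w, improve, epsilon=1e-2, max_iterations=100):
    """Apply improve until the weights move by less than epsilon."""
    for n in range(1, max_iterations + 1):
        w_new = improve(w)
        change = math.sqrt(sum((a - b) ** 2 for a, b in zip(w_new, w)))
        w = w_new
        if change < epsilon:
            break
    return w, n


# Toy stand-in for a policy-evaluation step: halve every weight.
w_final, n_steps = iterate_until_stable(
    [1.0], lambda w: [0.5 * wi for wi in w])
```

A smaller epsilon asks for more iterations before the weights are considered stable; max_iterations is an extra safeguard added for the sketch.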
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
DQN¶
-
class
mushroom_rl.algorithms.value.dqn.
DQN
(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)[source]¶ Bases:
mushroom_rl.algorithms.agent.Agent
Deep Q-Network algorithm. “Human-Level Control Through Deep Reinforcement Learning”. Mnih V. et al. 2015.
-
__init__
(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- approximator_params (dict) – parameters of the approximator to build;
- batch_size (int) – the number of samples in a batch;
- target_update_frequency (int) – the number of samples collected between each update of the target network;
- replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the object of the replay memory to use; if None, a default replay memory is created;
- initial_replay_size (int) – the number of samples to collect before starting the learning;
- max_replay_size (int) – the maximum number of samples in the replay memory;
- fit_params (dict, None) – parameters of the fitting algorithm of the approximator;
- n_approximators (int, 1) – the number of approximators to use in AveragedDQN;
- clip_reward (bool, True) – whether to clip the reward or not.
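How batch_size, target_update_frequency, initial_replay_size, max_replay_size and clip_reward interact can be sketched with a plain bounded buffer. This is an illustration under stated assumptions, not the class above: one sample arrives per step, learning starts once the buffer holds initial_replay_size samples, rewards are clipped to [-1, 1], and the target network is refreshed every target_update_frequency fits. The names clip and run_schedule are hypothetical.

```python
import random
from collections import deque


def clip(reward, low=-1.0, high=1.0):
    """Clamp a reward to [low, high]."""
    return max(low, min(high, reward))


def run_schedule(rewards, batch_size, target_update_frequency,
                 initial_replay_size, max_replay_size,
                 clip_reward=True, seed=0):
    """Count fits and target updates for a stream of rewards."""
    rng = random.Random(seed)
    memory = deque(maxlen=max_replay_size)  # oldest samples are evicted
    n_fits = n_target_updates = 0

    for step, reward in enumerate(rewards):
        memory.append((step, clip(reward) if clip_reward else reward))
        if len(memory) < initial_replay_size:
            continue  # still warming up, no learning yet

        # A real agent would fit the online network on this batch.
        batch = rng.sample(list(memory), batch_size)
        n_fits += 1
        if n_fits % target_update_frequency == 0:
            n_target_updates += 1  # copy online weights to the target
    return memory, n_fits, n_target_updates


memory, n_fits, n_target_updates = run_schedule(
    [0.0] * 9 + [7.0], batch_size=2, target_update_frequency=4,
    initial_replay_size=3, max_replay_size=5)
```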
-
_next_q
(next_state, absorbing)[source]¶ Parameters: - next_state (np.ndarray) – the states where next action has to be evaluated;
- absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns: Maximum action-value for each state in next_state.
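The returned quantity can be sketched without any network: take the maximum action-value per next state and, by the usual convention assumed here, use zero for absorbing states. The functions next_q_max and td_targets are hypothetical, and q_next stands in for the target network's output.

```python
def next_q_max(q_next, absorbing):
    """q_next holds one list of action-values per next state."""
    return [0.0 if done else max(row)
            for row, done in zip(q_next, absorbing)]


def td_targets(rewards, q_next, absorbing, gamma):
    """One-step targets r + gamma * max_a Q(next_state, a)."""
    return [r + gamma * q
            for r, q in zip(rewards, next_q_max(q_next, absorbing))]


targets = td_targets(rewards=[1.0, 2.0],
                     q_next=[[0.0, 4.0], [3.0, 1.0]],
                     absorbing=[False, True], gamma=0.5)
```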
-
draw_action
(state)[source]¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
_post_load
()[source]¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
copy
()¶ Returns: A deepcopy of the agent.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.dqn.
DoubleDQN
(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)[source]¶ Bases:
mushroom_rl.algorithms.value.dqn.dqn.DQN
Double DQN algorithm. “Deep Reinforcement Learning with Double Q-Learning”. Hasselt H. V. et al. 2016.
-
_next_q
(next_state, absorbing)[source]¶ Parameters: - next_state (np.ndarray) – the states where next action has to be evaluated;
- absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns: Maximum action-value for each state in next_state.
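In Double DQN as described in the cited paper, the action is selected with the online network and evaluated with the target network. A minimal sketch of that decoupling follows, assuming absorbing states contribute zero; double_q_next and its arguments are hypothetical and the two inputs stand in for the two networks' outputs.

```python
def double_q_next(q_online_next, q_target_next, absorbing):
    """Evaluate the online network's greedy action with the target one."""
    values = []
    for online_row, target_row, done in zip(
            q_online_next, q_target_next, absorbing):
        if done:
            values.append(0.0)
            continue
        # Index of the greedy action according to the online network.
        best = max(range(len(online_row)), key=online_row.__getitem__)
        values.append(target_row[best])
    return values


values = double_q_next(q_online_next=[[1.0, 5.0], [2.0, 0.0]],
                       q_target_next=[[3.0, 2.0], [9.0, 9.0]],
                       absorbing=[False, True])
```

For the first state a plain maximum over the target values would give 3.0, whereas the decoupled estimate uses the target value of the online network's preferred action.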
-
__init__
(mdp_info, policy, approximator, approximator_params, batch_size, target_update_frequency, replay_memory=None, initial_replay_size=500, max_replay_size=5000, fit_params=None, n_approximators=1, clip_reward=True)¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- approximator_params (dict) – parameters of the approximator to build;
- batch_size (int) – the number of samples in a batch;
- target_update_frequency (int) – the number of samples collected between each update of the target network;
- replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the object of the replay memory to use; if None, a default replay memory is created;
- initial_replay_size (int) – the number of samples to collect before starting the learning;
- max_replay_size (int) – the maximum number of samples in the replay memory;
- fit_params (dict, None) – parameters of the fitting algorithm of the approximator;
- n_approximators (int, 1) – the number of approximators to use in AveragedDQN;
- clip_reward (bool, True) – whether to clip the reward or not.
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
_update_target
()¶ Update the target network.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.dqn.
AveragedDQN
(mdp_info, policy, approximator, **params)[source]¶ Bases:
mushroom_rl.algorithms.value.dqn.dqn.DQN
Averaged-DQN algorithm. “Averaged-DQN: Variance Reduction and Stabilization for Deep Reinforcement Learning”. Anschel O. et al. 2017.
-
__init__
(mdp_info, policy, approximator, **params)[source]¶ Constructor.
Parameters: - approximator (object) – the approximator to use to fit the Q-function;
- approximator_params (dict) – parameters of the approximator to build;
- batch_size (int) – the number of samples in a batch;
- target_update_frequency (int) – the number of samples collected between each update of the target network;
- replay_memory ([ReplayMemory, PrioritizedReplayMemory], None) – the object of the replay memory to use; if None, a default replay memory is created;
- initial_replay_size (int) – the number of samples to collect before starting the learning;
- max_replay_size (int) – the maximum number of samples in the replay memory;
- fit_params (dict, None) – parameters of the fitting algorithm of the approximator;
- n_approximators (int, 1) – the number of approximators to use in AveragedDQN;
- clip_reward (bool, True) – whether to clip the reward or not.
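The effect of n_approximators can be illustrated with the averaging step from the cited paper: action-values from several previously learned networks are averaged first, and the maximum is taken over the averaged values. This is a sketch under that reading, assuming zero value for absorbing states; averaged_next_q and its input layout are hypothetical.

```python
def averaged_next_q(q_next_per_network, absorbing):
    """q_next_per_network[k][i][a]: network k, next state i, action a."""
    n_networks = len(q_next_per_network)
    values = []
    for i, done in enumerate(absorbing):
        if done:
            values.append(0.0)
            continue
        n_actions = len(q_next_per_network[0][i])
        # Average across networks first, then maximise over actions.
        mean_q = [sum(net[i][a] for net in q_next_per_network) / n_networks
                  for a in range(n_actions)]
        values.append(max(mean_q))
    return values


values = averaged_next_q(
    q_next_per_network=[[[4.0, 0.0]], [[0.0, 2.0]]],
    absorbing=[False])
```

Averaging before maximising gives 2.0 here, while averaging the per-network maxima would give 3.0; the former is less sensitive to a single network's overestimate.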
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
fit
(dataset)¶ Fit step.
Parameters: dataset (list) – the dataset.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-
-
class
mushroom_rl.algorithms.value.dqn.
CategoricalDQN
(mdp_info, policy, approximator_params, n_atoms, v_min, v_max, **params)[source]¶ Bases:
mushroom_rl.algorithms.value.dqn.dqn.DQN
Categorical DQN algorithm. “A Distributional Perspective on Reinforcement Learning”. Bellemare M. et al. 2017.
-
__init__
(mdp_info, policy, approximator_params, n_atoms, v_min, v_max, **params)[source]¶ Constructor.
Parameters: - n_atoms (int) – number of atoms;
- v_min (float) – minimum value of value-function;
- v_max (float) – maximum value of value-function.
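The three parameters define a fixed support of n_atoms evenly spaced values between v_min and v_max, over which a categorical return distribution is learned. The sketch below builds that support, computes the implied Q-value, and applies the projection step from the cited paper to one distribution. It is an illustration only; make_atoms, expected_value and project are hypothetical names.

```python
import math


def make_atoms(n_atoms, v_min, v_max):
    """Evenly spaced support between v_min and v_max (n_atoms >= 2)."""
    delta = (v_max - v_min) / (n_atoms - 1)
    return [v_min + i * delta for i in range(n_atoms)], delta


def expected_value(probs, atoms):
    """Q-value implied by a categorical return distribution."""
    return sum(p * z for p, z in zip(probs, atoms))


def project(probs, reward, gamma, atoms, delta, absorbing=False):
    """Project the distribution of r + gamma * z back onto the atoms."""
    v_min, v_max = atoms[0], atoms[-1]
    last = len(atoms) - 1
    projected = [0.0] * len(atoms)
    for p, z in zip(probs, atoms):
        tz = reward if absorbing else reward + gamma * z
        tz = min(v_max, max(v_min, tz))   # keep inside the support
        b = (tz - v_min) / delta          # fractional atom index
        low = min(math.floor(b), last)
        high = min(math.ceil(b), last)
        if low == high:                   # lands exactly on an atom
            projected[low] += p
        else:                             # split mass between neighbours
            projected[low] += p * (high - b)
            projected[high] += p * (b - low)
    return projected


atoms, delta = make_atoms(n_atoms=3, v_min=-1.0, v_max=1.0)
new_probs = project([0.0, 1.0, 0.0], reward=0.5, gamma=1.0,
                    atoms=atoms, delta=delta)
```

In the usage lines, all probability mass starts on the atom 0.0; after a reward of 0.5 the shifted return falls halfway between two atoms, so the mass is split evenly between them.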
-
_add_save_attr
(**attr_dict)¶ Add attributes that should be saved for an agent.
Parameters: **attr_dict – dictionary of attributes mapped to the method that should be used to save and load them. If a “!” character is added at the end of the method, the field will be saved only if full_save is set to True.
-
_next_q
(next_state, absorbing)¶ Parameters: - next_state (np.ndarray) – the states where next action has to be evaluated;
- absorbing (np.ndarray) – the absorbing flag for the states in next_state.
Returns: Maximum action-value for each state in next_state.
-
_post_load
()¶ This method can be overwritten to implement logic that is executed after the loading of the agent.
-
_update_target
()¶ Update the target network.
-
copy
()¶ Returns: A deepcopy of the agent.
-
draw_action
(state)¶ Return the action to execute in the given state. It is the action returned by the policy or the action set by the algorithm (e.g. in the case of SARSA).
Parameters: state (np.ndarray) – the state where the agent is. Returns: The action to be executed.
-
episode_start
()¶ Called by the agent when a new episode starts.
-
classmethod
load
(path)¶ Load and deserialize the agent from the given location on disk.
Parameters: path (Path, string) – Relative or absolute path to the agent's save location. Returns: The loaded agent.
-
save
(path, full_save=False)¶ Serialize and save the object to the given path on disk.
Parameters: - path (Path, str) – Relative or absolute path to the object save location;
- full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
-
save_zip
(zip_file, full_save, folder='')¶ Serialize and save the agent to the given path on disk.
Parameters: - zip_file (ZipFile) – ZipFile where the object needs to be saved;
- full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
- folder (string, '') – subfolder to be used by the save method.
-
stop
()¶ Method used to stop an agent. Useful when dealing with real-world environments or simulators, or to clean up environment internals after a core learn/evaluate call to enforce consistency.
-