Source code for mushroom_rl.core.agent

from mushroom_rl.core.serialization import Serializable


class Agent(Serializable):
    """
    Base class collecting the operations needed to drive an agent, such as
    selecting the action to execute according to its policy.

    """

    def __init__(self, mdp_info, policy, features=None):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information about the MDP;
            policy (Policy): the policy followed by the agent;
            features (object, None): callable applied to the state before
                it is handed to the policy; ``None`` leaves the state as is.

        """
        self.mdp_info = mdp_info
        self.policy = policy
        self.phi = features

        # Action forced by the algorithm for the next call to
        # ``draw_action``; ``None`` means "ask the policy".
        self.next_action = None

        self._preprocessors = []
        self._logger = None

        # Attribute name -> serialization method tag handed to
        # ``_add_save_attr``.
        save_attributes = {
            'mdp_info': 'pickle',
            'policy': 'mushroom',
            'phi': 'pickle',
            'next_action': 'numpy',
            '_preprocessors': 'mushroom',
            '_logger': 'none',
        }
        self._add_save_attr(**save_attributes)

    def fit(self, dataset, **info):
        """
        Fit step. Must be overridden: this base implementation always
        raises.

        Args:
            dataset (list): the dataset;
            **info: additional keyword arguments, unused here.

        Raises:
            NotImplementedError: always, since this class is abstract.

        """
        raise NotImplementedError('Agent is an abstract class')

    def draw_action(self, state):
        """
        Return the action to execute in the given state. This is either the
        action stored in ``next_action`` by the algorithm (e.g. in the case
        of SARSA), which is consumed by this call, or the action drawn from
        the policy.

        Args:
            state (np.ndarray): the state where the agent is.

        Returns:
            The action to be executed.

        """
        # Features are extracted first, whichever branch is taken below.
        if self.phi is not None:
            state = self.phi(state)

        pending_action = self.next_action
        if pending_action is None:
            return self.policy.draw_action(state)

        # A stored action is used once and then cleared.
        self.next_action = None

        return pending_action

    def episode_start(self):
        """
        Hook run when a new episode starts. Resets the policy.

        """
        self.policy.reset()

    def stop(self):
        """
        Method used to stop an agent. Useful when dealing with real world
        environments, simulators, or to cleanup environments internals after
        a core learn/evaluate to enforce consistency. The default
        implementation does nothing.

        """
        pass

    def set_logger(self, logger):
        """
        Store the logger to be used by the algorithm.

        Args:
            logger (Logger): the logger to be used by the algorithm.

        """
        self._logger = logger

    def add_preprocessor(self, preprocessor):
        """
        Append a preprocessor to the preprocessor list. The preprocessors
        are applied in the order in which they were added.

        Args:
            preprocessor (object): state preprocessor to be applied to state
                variables before feeding them to the agent.

        """
        self._preprocessors.append(preprocessor)

    @property
    def preprocessors(self):
        """
        The list of state preprocessors stored in the agent.

        """
        return self._preprocessors