Source code for mushroom_rl.policy.policy

from mushroom_rl.core import Serializable


class Policy(Serializable):
    """
    Interface representing a generic policy.
    A policy is a probability distribution that gives the probability of
    taking an action given a specified state.
    A policy is used by MushroomRL agents to interact with the environment.

    """

    def __call__(self, *args):
        """
        Compute the probability of taking an action in a certain state
        following the policy.

        Args:
            *args (list): list containing a state or a state and an action.

        Returns:
            The probability of all actions following the policy in the given
            state if the list contains only the state, else the probability
            of the given action in the given state following the policy. If
            the action space is continuous, both state and action must be
            provided.

        """
        raise NotImplementedError

    def draw_action(self, state):
        """
        Sample an action in ``state`` using the policy.

        Args:
            state (np.ndarray): the state where the agent is.

        Returns:
            The action sampled from the policy.

        """
        raise NotImplementedError

    def reset(self):
        """
        Useful when the policy needs a special initialization at the beginning
        of an episode.

        """
        pass
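# Illustrative sketch (not part of mushroom_rl): a minimal, assumed subclass of
# Policy implementing a uniform random policy over a discrete action space with
# ``n_actions`` actions. The class name and constructor argument are
# hypothetical and only show how the interface above can be implemented;
# serialization setup is omitted.
import numpy as np


class _ExampleUniformPolicy(Policy):
    def __init__(self, n_actions):
        # Number of discrete actions the policy can sample from.
        self._n_actions = n_actions

    def __call__(self, *args):
        # Only the state provided: return the probability of every action;
        # state and action provided: return the probability of that action.
        if len(args) == 1:
            return np.ones(self._n_actions) / self._n_actions
        return 1. / self._n_actions

    def draw_action(self, state):
        # Sample an action uniformly at random, independently of the state.
        return np.array([np.random.randint(self._n_actions)])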
class ParametricPolicy(Policy):
    """
    Interface for a generic parametric policy.
    A parametric policy is a policy that depends on a set of parameters,
    called the policy weights.
    If the policy is differentiable, the derivative of the probability for a
    specified state-action pair can be provided.

    """

    def diff_log(self, state, action):
        """
        Compute the gradient of the logarithm of the probability density
        function, in the specified state and action pair, i.e.:

        .. math::
            \\nabla_{\\theta}\\log p(s,a)

        Args:
            state (np.ndarray): the state where the gradient is computed;
            action (np.ndarray): the action where the gradient is computed.

        Returns:
            The gradient of the logarithm of the pdf w.r.t. the policy
            weights.

        """
        raise RuntimeError('The policy is not differentiable')

    def diff(self, state, action):
        """
        Compute the derivative of the probability density function, in the
        specified state and action pair. Normally it is computed from the
        derivative of the logarithm of the probability density function,
        exploiting the likelihood ratio trick, i.e.:

        .. math::
            \\nabla_{\\theta}p(s,a)=p(s,a)\\nabla_{\\theta}\\log p(s,a)

        Args:
            state (np.ndarray): the state where the derivative is computed;
            action (np.ndarray): the action where the derivative is computed.

        Returns:
            The derivative w.r.t. the policy weights.

        """
        return self(state, action) * self.diff_log(state, action)

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the vector of the new weights to be used by
                the policy.

        """
        raise NotImplementedError

    def get_weights(self):
        """
        Getter.

        Returns:
            The current policy weights.

        """
        raise NotImplementedError

    @property
    def weights_size(self):
        """
        Property.

        Returns:
            The size of the policy weights.

        """
        raise NotImplementedError
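# Illustrative sketch (not part of mushroom_rl): an assumed ParametricPolicy
# whose distribution over a scalar action is a Gaussian with fixed standard
# deviation and a mean that is linear in the state. The class name, constructor
# arguments and the closed-form gradient below are assumptions used only to
# show how the interface above fits together.
class _ExampleLinearGaussianPolicy(ParametricPolicy):
    def __init__(self, n_features, sigma=1.):
        # One policy weight per state feature; the Gaussian mean is the dot
        # product between the weights and the state.
        self._theta = np.zeros(n_features)
        self._sigma = sigma

    def __call__(self, state, action):
        # Probability density of ``action`` in ``state`` under the Gaussian.
        mean = self._theta.dot(state)
        return np.exp(-.5 * ((action - mean) / self._sigma) ** 2) / (
            self._sigma * np.sqrt(2 * np.pi))

    def draw_action(self, state):
        mean = self._theta.dot(state)
        return np.random.normal(mean, self._sigma, size=(1,))

    def diff_log(self, state, action):
        # Gradient of the log-pdf w.r.t. the weights for a Gaussian with
        # linear mean: (a - theta^T s) * s / sigma^2.
        mean = self._theta.dot(state)
        return (action - mean) * state / self._sigma ** 2

    def set_weights(self, weights):
        self._theta = weights

    def get_weights(self):
        return self._theta

    @property
    def weights_size(self):
        return self._theta.size

# With such a policy, ``diff(state, action)`` from the base class combines the
# density and ``diff_log`` via the likelihood ratio trick shown above.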