Source code for mushroom_rl.policy.torch_policy

import numpy as np

import torch
import torch.nn as nn

from mushroom_rl.policy import Policy
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils, CategoricalWrapper
from mushroom_rl.rl_utils.parameters import to_parameter

from itertools import chain


class TorchPolicy(Policy):
    """
    Interface for a generic PyTorch policy.
    A PyTorch policy is a policy implemented as a neural network using PyTorch.
    Functions ending with '_t' use tensors as input, and also as output when required.

    """
    # TODO: remove TorchUtils.to_float_tensor(array) and update the docstring to replace np.ndarray.
    def __init__(self, policy_state_shape=None):
        """
        Constructor.

        """
        super().__init__(policy_state_shape)

    def __call__(self, state, action, policy_state=None):
        s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
        a = TorchUtils.to_float_tensor(torch.atleast_2d(action))

        return torch.exp(self.log_prob_t(s, a))

    def draw_action(self, state, policy_state=None):
        with torch.no_grad():
            s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
            a = self.draw_action_t(s)

        return torch.squeeze(a, dim=0).detach(), None

    def distribution(self, state):
        """
        Compute the policy distribution in the given states.

        Args:
            state (np.ndarray): the set of states where the distribution is
                computed.

        Returns:
            The torch distribution for the provided states.

        """
        s = TorchUtils.to_float_tensor(state)

        return self.distribution_t(s)

    def entropy(self, state=None):
        """
        Compute the entropy of the policy.

        Args:
            state (np.ndarray, None): the set of states to consider. If the
                entropy of the policy can be computed in closed form, then
                ``state`` can be None.

        Returns:
            The value of the entropy of the policy.

        """
        s = TorchUtils.to_float_tensor(state) if state is not None else None

        return self.entropy_t(s).detach()

    def draw_action_t(self, state):
        """
        Draw an action given a tensor.

        Args:
            state (torch.Tensor): set of states.

        Returns:
            The tensor of the actions to perform in each state.

        """
        raise NotImplementedError

    def log_prob_t(self, state, action):
        """
        Compute the logarithm of the probability of taking ``action`` in
        ``state``.

        Args:
            state (torch.Tensor): set of states;
            action (torch.Tensor): set of actions.

        Returns:
            The tensor of log-probability.

        """
        raise NotImplementedError

    def entropy_t(self, state):
        """
        Compute the entropy of the policy.

        Args:
            state (torch.Tensor): the set of states to consider. If the
                entropy of the policy can be computed in closed form, then
                ``state`` can be None.

        Returns:
            The tensor value of the entropy of the policy.

        """
        raise NotImplementedError

    def distribution_t(self, state):
        """
        Compute the policy distribution in the given states.

        Args:
            state (torch.Tensor): the set of states where the distribution is
                computed.

        Returns:
            The torch distribution for the provided states.

        """
        raise NotImplementedError

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the vector of the new weights to be used by
                the policy.

        """
        raise NotImplementedError

    def get_weights(self):
        """
        Getter.

        Returns:
            The current policy weights.

        """
        raise NotImplementedError

    def parameters(self):
        """
        Returns the trainable policy parameters, as expected by torch
        optimizers.

        Returns:
            List of parameters to be optimized.

        """
        raise NotImplementedError
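
# Illustrative sketch (assumption, not part of the module above): how calling code
# typically combines this interface in a policy-gradient update. ``policy`` can be
# any concrete subclass (e.g. the GaussianTorchPolicy below); ``states``, ``actions``
# and ``advantages`` are assumed to be float tensors with matching first dimension,
# and ``optimizer`` a torch optimizer built from policy.parameters().
def _example_policy_gradient_step(policy, states, actions, advantages, optimizer):
    # log_prob_t returns a column tensor of shape (N, 1); keep advantages shaped
    # (N, 1) so the product is element-wise.
    log_prob = policy.log_prob_t(states, actions)
    loss = -(log_prob * advantages.detach()).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()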
class GaussianTorchPolicy(TorchPolicy):
    """
    Torch policy implementing a Gaussian policy with trainable standard
    deviation. The standard deviation is not state-dependent.

    """
    def __init__(self, network, input_shape, output_shape, std_0=1., policy_state_shape=None, **params):
        """
        Constructor.

        Args:
            network (object): the network class used to implement the mean
                regressor;
            input_shape (tuple): the shape of the state space;
            output_shape (tuple): the shape of the action space;
            std_0 (float, 1.): initial standard deviation;
            params (dict): parameters used by the network constructor.

        """
        super().__init__(policy_state_shape)

        self._action_dim = output_shape[0]

        self._mu = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
        self._predict_params = dict()

        log_sigma_init = torch.ones(self._action_dim, device=TorchUtils.get_device()) \
            * torch.log(TorchUtils.to_float_tensor(std_0))
        self._log_sigma = nn.Parameter(log_sigma_init)

        self._add_save_attr(
            _action_dim='primitive',
            _mu='mushroom',
            _predict_params='pickle',
            _log_sigma='torch'
        )

    def draw_action_t(self, state):
        return self.distribution_t(state).sample().detach()

    def log_prob_t(self, state, action):
        return self.distribution_t(state).log_prob(action)[:, None]

    def entropy_t(self, state=None):
        # Closed-form entropy of a Gaussian with diagonal covariance:
        # H = k/2 * log(2*pi*e) + sum_i log(sigma_i), with k the action dimension.
        return self._action_dim / 2 * torch.log(TorchUtils.to_float_tensor(2 * np.pi * np.e)) \
            + torch.sum(self._log_sigma)

    def distribution_t(self, state):
        mu, chol_sigma = self.get_mean_and_chol(state)

        return torch.distributions.MultivariateNormal(loc=mu, scale_tril=chol_sigma, validate_args=False)

    def get_mean_and_chol(self, state):
        assert torch.all(torch.exp(self._log_sigma) > 0)

        return self._mu(state, **self._predict_params), torch.diag(torch.exp(self._log_sigma))

    def set_weights(self, weights):
        log_sigma_data = TorchUtils.to_float_tensor(weights[-self._action_dim:])
        self._log_sigma.data = log_sigma_data

        self._mu.set_weights(weights[:-self._action_dim])

    def get_weights(self):
        mu_weights = self._mu.get_weights()
        sigma_weights = self._log_sigma.data.detach()

        return torch.concatenate([mu_weights, sigma_weights])

    def parameters(self):
        return chain(self._mu.model.network.parameters(), [self._log_sigma])
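
# Example sketch of constructing and querying a GaussianTorchPolicy. The network class
# below is hypothetical and assumes the usual MushroomRL network convention (constructor
# taking input_shape, output_shape and extra keyword parameters; forward taking the batch
# of states); names such as ``_ExampleMLP`` and ``n_features`` are illustrative only.
class _ExampleMLP(nn.Module):
    def __init__(self, input_shape, output_shape, n_features=32, **kwargs):
        super().__init__()
        self._h1 = nn.Linear(input_shape[-1], n_features)
        self._h2 = nn.Linear(n_features, output_shape[0])

    def forward(self, state, **kwargs):
        return self._h2(torch.relu(self._h1(state.float())))


def _example_gaussian_usage():
    # 3-dimensional state, 2-dimensional action, initial standard deviation of 1.
    pi = GaussianTorchPolicy(_ExampleMLP, input_shape=(3,), output_shape=(2,), std_0=1., n_features=32)

    state = torch.rand(3)
    action, _ = pi.draw_action(state)  # sample an action for a single state
    entropy = pi.entropy()             # closed-form entropy, no state needed
    log_prob = pi.log_prob_t(torch.atleast_2d(state), torch.atleast_2d(action))

    return action, entropy, log_prob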
class BoltzmannTorchPolicy(TorchPolicy):
    """
    Torch policy implementing a Boltzmann policy.

    """
    def __init__(self, network, input_shape, output_shape, beta, policy_state_shape=None, **params):
        """
        Constructor.

        Args:
            network (object): the network class used to implement the logits
                regressor;
            input_shape (tuple): the shape of the state space;
            output_shape (tuple): the shape of the action space;
            beta ([float, Parameter]): the inverse of the temperature of the
                distribution. As the temperature approaches infinity, the
                policy becomes more and more random. As the temperature
                approaches 0.0, the policy becomes more and more greedy;
            **params: parameters used by the network constructor.

        """
        super().__init__(policy_state_shape)

        self._action_dim = output_shape[0]
        self._predict_params = dict()

        self._logits = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
        self._beta = to_parameter(beta)

        self._add_save_attr(
            _action_dim='primitive',
            _predict_params='pickle',
            _beta='mushroom',
            _logits='mushroom'
        )

    def draw_action_t(self, state):
        action = self.distribution_t(state).sample().detach()

        if len(action.shape) > 1:
            return action
        else:
            return action.unsqueeze(0)

    def log_prob_t(self, state, action):
        return self.distribution_t(state).log_prob(action)[:, None]

    def entropy_t(self, state):
        return torch.mean(self.distribution_t(state).entropy())

    def distribution_t(self, state):
        logits = self._logits(state, **self._predict_params) * self._beta(state.numpy())

        return CategoricalWrapper(logits)

    def set_weights(self, weights):
        self._logits.set_weights(weights)

    def get_weights(self):
        return self._logits.get_weights()

    def parameters(self):
        return self._logits.model.network.parameters()

    def set_beta(self, beta):
        self._beta = to_parameter(beta)
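
# Example sketch of a BoltzmannTorchPolicy over a discrete action space. It reuses the
# hypothetical _ExampleMLP defined above as the logits network; here output_shape is
# the number of discrete actions and ``beta`` is the inverse temperature (a plain float
# is wrapped by to_parameter in the constructor).
def _example_boltzmann_usage():
    n_actions = 4
    pi = BoltzmannTorchPolicy(_ExampleMLP, input_shape=(3,), output_shape=(n_actions,), beta=1.0, n_features=32)

    state = torch.rand(3)
    action, _ = pi.draw_action(state)             # sampled discrete action index
    entropy = pi.entropy(torch.atleast_2d(state))  # mean entropy of the Boltzmann distribution

    return action, entropy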