import numpy as np
import torch
import torch.nn as nn
from mushroom_rl.policy import Policy
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import TorchUtils, CategoricalWrapper
from mushroom_rl.rl_utils.parameters import to_parameter
from itertools import chain
class TorchPolicy(Policy):
    """
    Interface for a generic PyTorch policy.
    A PyTorch policy is a policy implemented as a neural network using PyTorch.
    Functions ending with '_t' use tensors as input, and also as output when
    required.

    """
    # TODO: remove TorchUtils.to_float_tensor(array) and update the docstring to replace np.ndarray.
    def __init__(self, policy_state_shape=None):
        """
        Constructor.

        """
        super().__init__(policy_state_shape)

    def __call__(self, state, action, policy_state=None):
        s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
        a = TorchUtils.to_float_tensor(torch.atleast_2d(action))

        return torch.exp(self.log_prob_t(s, a))

    def draw_action(self, state, policy_state=None):
        with torch.no_grad():
            s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
            a = self.draw_action_t(s)

        return torch.squeeze(a, dim=0).detach(), None

    def distribution(self, state):
        """
        Compute the policy distribution in the given states.

        Args:
            state (np.ndarray): the set of states where the distribution is
                computed.

        Returns:
            The torch distribution for the provided states.

        """
        s = TorchUtils.to_float_tensor(state)

        return self.distribution_t(s)

    def entropy(self, state=None):
        """
        Compute the entropy of the policy.

        Args:
            state (np.ndarray, None): the set of states to consider. If the
                entropy of the policy can be computed in closed form, then
                ``state`` can be None.

        Returns:
            The value of the entropy of the policy.

        """
        s = TorchUtils.to_float_tensor(state) if state is not None else None

        return self.entropy_t(s).detach()

    def draw_action_t(self, state):
        """
        Draw an action given a tensor.

        Args:
            state (torch.Tensor): set of states.

        Returns:
            The tensor of the actions to perform in each state.

        """
        raise NotImplementedError

    def log_prob_t(self, state, action):
        """
        Compute the logarithm of the probability of taking ``action`` in
        ``state``.

        Args:
            state (torch.Tensor): set of states;
            action (torch.Tensor): set of actions.

        Returns:
            The tensor of log-probability.

        """
        raise NotImplementedError

    def entropy_t(self, state):
        """
        Compute the entropy of the policy.

        Args:
            state (torch.Tensor): the set of states to consider. If the
                entropy of the policy can be computed in closed form, then
                ``state`` can be None.

        Returns:
            The tensor value of the entropy of the policy.

        """
        raise NotImplementedError

    def distribution_t(self, state):
        """
        Compute the policy distribution in the given states.

        Args:
            state (torch.Tensor): the set of states where the distribution is
                computed.

        Returns:
            The torch distribution for the provided states.

        """
        raise NotImplementedError

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the vector of the new weights to be used by
                the policy.

        """
        raise NotImplementedError

    def get_weights(self):
        """
        Getter.

        Returns:
            The current policy weights.

        """
        raise NotImplementedError

    def parameters(self):
        """
        Returns the trainable policy parameters, as expected by torch
        optimizers.

        Returns:
            List of parameters to be optimized.

        """
        raise NotImplementedError
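
# --- Example (illustration only, not part of the library source): a minimal
# sketch of a concrete subclass, showing which '_t' hooks TorchPolicy expects.
# It models a fixed, state-independent standard normal over a 1-D action;
# ``_FixedGaussianPolicy`` is a hypothetical name. A trainable policy would
# also override set_weights/get_weights/parameters, as the two classes below do.
class _FixedGaussianPolicy(TorchPolicy):
    def draw_action_t(self, state):
        return self.distribution_t(state).sample()

    def log_prob_t(self, state, action):
        # Summing over the action dimension yields one log-probability per state.
        return self.distribution_t(state).log_prob(action).sum(-1, keepdim=True)

    def entropy_t(self, state=None):
        # The entropy of a standard normal is a constant, so ``state`` may be None.
        return torch.tensor(0.5 * np.log(2 * np.pi * np.e))

    def distribution_t(self, state):
        mu = torch.zeros(state.shape[0], 1)
        return torch.distributions.Normal(mu, torch.ones_like(mu))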
class GaussianTorchPolicy(TorchPolicy):
    """
    Torch policy implementing a Gaussian policy with trainable standard
    deviation. The standard deviation is not state-dependent.

    """
    def __init__(self, network, input_shape, output_shape, std_0=1., policy_state_shape=None, **params):
        """
        Constructor.

        Args:
            network (object): the network class used to implement the mean
                regressor;
            input_shape (tuple): the shape of the state space;
            output_shape (tuple): the shape of the action space;
            std_0 (float, 1.): initial standard deviation;
            params (dict): parameters used by the network constructor.

        """
        super().__init__(policy_state_shape)

        self._action_dim = output_shape[0]

        self._mu = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
        self._predict_params = dict()

        log_sigma_init = torch.ones(self._action_dim, device=TorchUtils.get_device()) * torch.log(TorchUtils.to_float_tensor(std_0))
        self._log_sigma = nn.Parameter(log_sigma_init)

        self._add_save_attr(
            _action_dim='primitive',
            _mu='mushroom',
            _predict_params='pickle',
            _log_sigma='torch'
        )

    def draw_action_t(self, state):
        return self.distribution_t(state).sample().detach()

    def log_prob_t(self, state, action):
        return self.distribution_t(state).log_prob(action)[:, None]

    def entropy_t(self, state=None):
        # Closed-form entropy of a Gaussian with diagonal covariance:
        # k/2 * log(2*pi*e) + sum_i log(sigma_i), with k the action dimension.
        return self._action_dim / 2 * torch.log(TorchUtils.to_float_tensor(2 * np.pi * np.e)) \
            + torch.sum(self._log_sigma)

    def distribution_t(self, state):
        mu, chol_sigma = self.get_mean_and_chol(state)

        return torch.distributions.MultivariateNormal(loc=mu, scale_tril=chol_sigma, validate_args=False)

    def get_mean_and_chol(self, state):
        # The covariance is diagonal, so its Cholesky factor is simply diag(sigma).
        assert torch.all(torch.exp(self._log_sigma) > 0)
        return self._mu(state, **self._predict_params), torch.diag(torch.exp(self._log_sigma))

    def set_weights(self, weights):
        # The last action_dim entries of the flat weight vector hold log_sigma;
        # the remaining entries belong to the mean regressor.
        log_sigma_data = TorchUtils.to_float_tensor(weights[-self._action_dim:])
        self._log_sigma.data = log_sigma_data

        self._mu.set_weights(weights[:-self._action_dim])

    def get_weights(self):
        mu_weights = self._mu.get_weights()
        sigma_weights = self._log_sigma.data.detach()

        return torch.concatenate([mu_weights, sigma_weights])

    def parameters(self):
        return chain(self._mu.model.network.parameters(), [self._log_sigma])
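
# --- Example (illustration only, not part of the library source): building a
# GaussianTorchPolicy around a small feed-forward network and drawing an action.
# ``_ExampleMLP`` is an assumption for the sketch; any nn.Module whose
# constructor accepts (input_shape, output_shape, **params), as TorchApproximator
# instantiates it, would do.
class _ExampleMLP(nn.Module):
    def __init__(self, input_shape, output_shape, **kwargs):
        super().__init__()
        self._net = nn.Sequential(nn.Linear(input_shape[0], 32), nn.ReLU(),
                                  nn.Linear(32, output_shape[0]))

    def forward(self, state, **kwargs):
        return self._net(state.float())


def _example_gaussian_usage():
    # 4-D state, 2-D action, initial standard deviation of 1.
    pi = GaussianTorchPolicy(_ExampleMLP, input_shape=(4,), output_shape=(2,), std_0=1.)

    state = torch.zeros(4)
    action, _ = pi.draw_action(state)  # sampled from N(mu(state), diag(sigma^2))
    density = pi(state, action)        # probability density of that action

    return action, density, pi.entropy()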
class BoltzmannTorchPolicy(TorchPolicy):
    """
    Torch policy implementing a Boltzmann policy.

    """
    def __init__(self, network, input_shape, output_shape, beta, policy_state_shape=None, **params):
        """
        Constructor.

        Args:
            network (object): the network class used to implement the logits
                regressor;
            input_shape (tuple): the shape of the state space;
            output_shape (tuple): the shape of the action space;
            beta ([float, Parameter]): the inverse of the temperature of the
                distribution. As the temperature approaches infinity, the
                policy becomes more and more random. As the temperature
                approaches 0.0, the policy becomes more and more greedy;
            **params: parameters used by the network constructor.

        """
        super().__init__(policy_state_shape)

        self._action_dim = output_shape[0]
        self._predict_params = dict()

        self._logits = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
        self._beta = to_parameter(beta)

        self._add_save_attr(
            _action_dim='primitive',
            _predict_params='pickle',
            _beta='mushroom',
            _logits='mushroom'
        )

    def draw_action_t(self, state):
        action = self.distribution_t(state).sample().detach()

        if len(action.shape) > 1:
            return action
        else:
            return action.unsqueeze(0)

    def log_prob_t(self, state, action):
        return self.distribution_t(state).log_prob(action)[:, None]

    def entropy_t(self, state):
        return torch.mean(self.distribution_t(state).entropy())

    def distribution_t(self, state):
        # Scaling the logits by beta implements the Boltzmann distribution
        # pi(a|s) proportional to exp(beta * logits(s, a)).
        logits = self._logits(state, **self._predict_params) * self._beta(state.numpy())

        return CategoricalWrapper(logits)

    def set_weights(self, weights):
        self._logits.set_weights(weights)

    def get_weights(self):
        return self._logits.get_weights()

    def parameters(self):
        return self._logits.model.network.parameters()

    def set_beta(self, beta):
        self._beta = to_parameter(beta)
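
# --- Example (illustration only, not part of the library source): using
# BoltzmannTorchPolicy over a discrete action set, reusing the hypothetical
# ``_ExampleMLP`` sketch defined above as the logits network.
def _example_boltzmann_usage():
    # 4-D state, 3 discrete actions, fixed inverse temperature beta = 1.
    pi = BoltzmannTorchPolicy(_ExampleMLP, input_shape=(4,), output_shape=(3,), beta=1.)

    state = torch.zeros(4)
    action, _ = pi.draw_action(state)  # index of the sampled discrete action

    # Larger beta means a colder distribution, i.e. nearly greedy selection.
    pi.set_beta(100.)

    return action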