import torch
import torch.nn as nn
from mushroom_rl.utils.features import uniform_grid
from mushroom_rl.utils.torch import to_float_tensor, to_int_tensor
class GaussianRBFTensor(nn.Module):
    """
    Pytorch module to implement a gaussian radial basis function.

    """
    def __init__(self, mu, scale, dim, use_cuda):
        """
        Constructor.

        Args:
            mu (np.ndarray): centers of the gaussian RBFs;
            scale (np.ndarray): scales for the RBFs; the squared distance
                from each center is divided by this value;
            dim (np.ndarray, None): indices of the input dimensions to be
                considered for the computation of the features. If None,
                the input is used as it is;
            use_cuda (bool): whether to use cuda for the computation or not.

        """
        # nn.Module must be initialized before the instance is used as a
        # module (e.g. called as ``module(x)``).
        super().__init__()

        self._mu = to_float_tensor(mu, use_cuda)
        self._scale = to_float_tensor(scale, use_cuda)
        self._dim = None if dim is None else to_int_tensor(dim, use_cuda)
        self._use_cuda = use_cuda

    def forward(self, x):
        """
        Compute the RBF features of the given input.

        Args:
            x (torch.Tensor): batch of inputs; the selection of the
                dimensions is done along axis 1, so a 2D tensor is expected
                (assumed layout: (n_samples, n_input_dimensions)).

        Returns:
            The tensor ``exp(-sum((x - mu)**2 / scale))``, where the sum
            runs over the last axis. For 2D centers the result has one row
            per sample and one column per center; a trailing axis of size
            one (single center) is squeezed away.

        """
        if self._use_cuda:
            x = x.cuda()

        if self._dim is not None:
            x = torch.index_select(x, 1, self._dim)

        # Add a "center" axis as a view; the subtraction below broadcasts
        # the centers over the batch axis, so no explicit copies are needed.
        x = x.unsqueeze(1).expand(-1, self._mu.shape[0], -1)
        delta = x - self._mu

        return torch.exp(-torch.sum(delta**2 / self._scale, -1)).squeeze(-1)

    @staticmethod
    def generate(n_centers, low, high, dimensions=None, use_cuda=False):
        """
        Factory method that builds the tensor representing a set of
        uniformly spaced Gaussian radial basis functions. Centers and scales
        are computed by ``uniform_grid`` (documented upstream as producing a
        25% overlap between neighboring functions).

        Args:
            n_centers (list): list of the number of radial basis functions to
                be used for each dimension;
            low (np.ndarray): lowest value for each dimension;
            high (np.ndarray): highest value for each dimension;
            dimensions (list, None): list of the dimensions of the input to
                be considered by the feature. The number of dimensions must
                match the number of elements in ``n_centers`` and ``low``;
            use_cuda (bool): whether to use cuda for the computation or not.

        Returns:
            A list containing a single ``GaussianRBFTensor``.

        """
        n_features = len(low)
        assert len(n_centers) == n_features
        assert len(low) == len(high)
        assert dimensions is None or n_features == len(dimensions)

        mu, scale = uniform_grid(n_centers, low, high)

        return [GaussianRBFTensor(mu, scale, dimensions, use_cuda)]