import torch
import torch.nn as nn
import numpy as np
from mushroom_rl.utils.features import uniform_grid
from mushroom_rl.utils.torch import TorchUtils
class GenericBasisTensor(nn.Module):
    """
    Abstract Pytorch module to implement a generic basis function.

    Subclasses must implement ``_basis_function``, which maps the offsets
    between the inputs and the centers to one feature per center.
    Subclasses must also implement ``_convert_to_scale``, which maps the
    basis widths computed by :meth:`generate` to the scales passed to
    ``_basis_function``.
    """

    def __init__(self, mu, scale, dim=None, normalized=False):
        """
        Constructor.

        Args:
            mu (np.ndarray): centers of the basis functions, one row per center;
            scale (np.ndarray): scales of the basis functions;
            dim (np.ndarray, None): list of dimensions to be considered for the computation of the features. If
                None, all dimensions are used to compute the features;
            normalized (bool, False): whether the features need to be normalized to sum to one or not.
        """
        # nn.Module must be initialized before any attribute is assigned on
        # the instance; Module.__setattr__ relies on the bookkeeping set up here.
        super().__init__()
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers, so Module.to() / state_dict() do not track them.
        self._mu = TorchUtils.to_float_tensor(mu)
        self._scale = TorchUtils.to_float_tensor(scale)
        self._dim = TorchUtils.to_int_tensor(dim) if dim is not None else None
        self._normalized = normalized

    def forward(self, x):
        """
        Compute the features of a batch of inputs.

        Args:
            x (torch.Tensor): batch of inputs, one sample per row (column 1 is
                the axis selected by ``dim``).

        Returns:
            The tensor of features with one row per sample and one column per
            center; a trailing axis of size one is squeezed away.
        """
        if self._dim is not None:
            x = torch.index_select(x, 1, self._dim)
        # (batch, 1, d) - (1, n_centers, d) broadcasts to (batch, n_centers, d):
        # the same offsets as repeating both tensors, without the copies.
        delta = x.unsqueeze(1) - self._mu.unsqueeze(0)
        phi = self._basis_function(delta, self._scale)
        if self._normalized:
            phi = self._normalize(phi)
        return phi.squeeze(-1)

    def _basis_function(self, delta, scale):
        """
        Compute one feature per center from the offsets ``delta`` and the
        ``scale`` tensor. Must be implemented by subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def _convert_to_scale(w):
        """
        Converts width of a basis function to scale.

        Args:
            w (np.ndarray): array of widths of basis function for every dimension.

        Returns:
            The array of scales for each basis function in any given dimension.
        """
        raise NotImplementedError

    @staticmethod
    def _normalize(raw_phi):
        """
        Normalize the features so that they sum to one along the last axis.

        Works for tensors of any rank. Entries that become NaN are replaced
        by zero, for example a row whose features are all zero (0 / 0).

        Args:
            raw_phi (torch.Tensor): unnormalized features.

        Returns:
            The normalized features, with the same shape as ``raw_phi``.
        """
        # keepdim=True makes the row sums broadcast against raw_phi for any rank.
        return torch.nan_to_num(raw_phi / torch.sum(raw_phi, -1, keepdim=True), 0.)

    @classmethod
    def is_cyclic(cls):
        """
        Method used to change the basis generation in case of cyclic features.

        Returns:
            Whether the space we consider is cyclic or not.
        """
        return False

    @classmethod
    def generate(cls, n_centers, low, high, dimensions=None, eta=0.25, normalized=False):
        """
        Factory method that generates the list of tensors representing a set
        of uniformly spaced basis functions with ``eta`` overlap.

        Args:
            n_centers (list): list of the number of basis functions to be used for each dimension;
            low (np.ndarray): lowest value for each dimension;
            high (np.ndarray): highest value for each dimension;
            dimensions (list, None): list of the dimensions of the input to be considered by the feature. The
                number of dimensions must match the number of elements in ``high`` and ``low``;
            eta (float, 0.25): percentage of overlap between the features;
            normalized (bool, False): whether the features need to be normalized to sum to one or not.

        Returns:
            The tensor list (a single-element list containing a ``cls`` instance).
        """
        n_features = len(low)
        assert len(n_centers) == n_features
        assert len(low) == len(high)
        assert dimensions is None or n_features == len(dimensions)

        mu, w = uniform_grid(n_centers, low, high, eta, cls.is_cyclic())
        scale = cls._convert_to_scale(w)

        return [cls(mu, scale, dimensions, normalized)]

    @property
    def size(self):
        """Number of basis functions (rows of the centers tensor)."""
        return self._mu.shape[0]
class GaussianRBFTensor(GenericBasisTensor):
    """
    Pytorch module implementing Gaussian radial basis functions.

    Each feature is ``exp(-sum_d delta_d ** 2 / scale_d)``, where ``delta``
    is the offset between the input and the center.
    """

    def _basis_function(self, delta, scale):
        weighted_sq_dist = (delta ** 2 / scale).sum(dim=-1)
        return (-weighted_sq_dist).exp()

    @staticmethod
    def _convert_to_scale(w):
        # Interpreting the width as three standard deviations (sigma = w / 3)
        # gives the Gaussian exponent denominator 2 * sigma ** 2.
        sigma = w / 3
        return 2 * sigma ** 2
class VonMisesBFTensor(GenericBasisTensor):
    """
    Pytorch module implementing von Mises basis functions on a cyclic space.

    Each feature is ``exp(sum_d (cos(2 * pi * delta_d) - 1) / scale_d)``.
    The cosine makes it periodic with period one in every offset ``delta_d``.
    Its maximum value, one, is reached when the input coincides with the
    center.
    """

    def _basis_function(self, delta, scale):
        # The "-1" is applied inside the per-axis sum instead of subtracting
        # torch.sum(1 / scale), which reduced over *every* element of scale.
        # Both terms are now reduced along the same axis, so per-center scales
        # of shape (n_centers, dim) are handled correctly. For a 1-D scale the
        # value is mathematically unchanged.
        return torch.exp(torch.sum((torch.cos(2 * np.pi * delta) - 1) / scale, -1))

    @classmethod
    def is_cyclic(cls):
        """The von Mises basis lives on a cyclic space."""
        return True

    @staticmethod
    def _convert_to_scale(w):
        """The widths are used directly as scales."""
        return w