Source code for mushroom_rl.features.tensors.basis_tensor

import torch
import torch.nn as nn
import numpy as np

from mushroom_rl.utils.features import uniform_grid
from mushroom_rl.utils.torch import TorchUtils


class GenericBasisTensor(nn.Module):
    """
    Abstract Pytorch module to implement a generic basis function.

    Every feature produced by this module is computed from the offsets
    ``x - mu`` between the (optionally sub-selected) input and each center.
    Subclasses must implement ``_basis_function``, which maps those offsets
    and the scales to one value per (sample, center) pair, and
    ``_convert_to_scale``, which converts basis-function widths into the
    scales expected by ``_basis_function``.

    """
    def __init__(self, mu, scale, dim=None, normalized=False):
        """
        Constructor.

        Args:
            mu (np.ndarray): centers of the basis functions, one row per
                center;
            scale (np.ndarray): scales of the basis functions;
            dim (np.ndarray, None): list of dimension to be considered for
                the computation of the features. If None, all dimension are
                used to compute the features;
            normalized (bool, False): whether the features need to be
                normalized to sum to one or not;

        """
        # nn.Module's bookkeeping must be initialized before any attribute is
        # assigned on the subclass.
        super().__init__()

        self._mu = TorchUtils.to_float_tensor(mu)
        self._scale = TorchUtils.to_float_tensor(scale)

        if dim is not None:
            self._dim = TorchUtils.to_int_tensor(dim)
        else:
            self._dim = None

        self._normalized = normalized

    def forward(self, x):
        """
        Compute the features of a batch of inputs.

        Args:
            x (torch.Tensor): 2-D input batch, indexed as
                ``x[sample, dimension]``.

        Returns:
            The features, one column per center. The trailing dimension is
            squeezed, so a module with a single center yields a 1-D tensor.

        """
        if self._dim is not None:
            x = torch.index_select(x, 1, self._dim)

        # Broadcast (batch, 1, d) against (n_centers, d) to obtain every
        # offset at once without materializing repeated copies.
        delta = x.unsqueeze(1) - self._mu
        phi = self._basis_function(delta, self._scale)

        if self._normalized:
            phi = self._normalize(phi)

        return phi.squeeze(-1)

    def _basis_function(self, delta, scale):
        """
        Evaluate the basis functions on the center offsets.

        Args:
            delta (torch.Tensor): offsets ``x - mu`` of shape
                (batch, n_centers, n_dims);
            scale (torch.Tensor): scales of the basis functions.

        Returns:
            One value per (sample, center) pair.

        """
        raise NotImplementedError

    @staticmethod
    def _convert_to_scale(w):
        """
        Converts width of a basis function to scale

        Args:
            w (np.ndarray): array of widths of basis function for every
                dimension

        Returns:
            The array of scales for each basis function in any given dimension

        """
        raise NotImplementedError

    @staticmethod
    def _normalize(raw_phi):
        """
        Normalize the features so that they sum to one along the last axis.

        Args:
            raw_phi (torch.Tensor): unnormalized features.

        Returns:
            The normalized features. Rows whose features are all zero (0/0)
            are mapped to zero instead of nan.

        """
        # keepdim keeps the denominator broadcastable for inputs of any rank.
        return torch.nan_to_num(raw_phi / raw_phi.sum(-1, keepdim=True),
                                nan=0.)

    @classmethod
    def is_cyclic(cls):
        """
        Method used to change the basis generation in case of cyclic features.

        Returns:
            Whether the space we consider is cyclic or not.

        """
        return False

    @classmethod
    def generate(cls, n_centers, low, high, dimensions=None, eta=0.25,
                 normalized=False):
        """
        Factory method that generates the list of dictionaries to build the
        tensors representing a set of uniformly spaced basis functions with
        `eta` overlap.

        Args:
            n_centers (list): list of the number of basis functions to be
                used for each dimension;
            low (np.ndarray): lowest value for each dimension;
            high (np.ndarray): highest value for each dimension;
            dimensions (list, None): list of the dimensions of the input to
                be considered by the feature. The number of dimensions must
                match the number of elements in ``high`` and ``low``;
            eta (float, 0.25): percentage of overlap between the features;
            normalized (bool, False): whether the features need to be
                normalized to sum to one or not.

        Returns:
            The tensor list.

        """
        n_features = len(low)
        assert len(n_centers) == n_features, \
            'n_centers must have one entry per dimension of low'
        assert len(low) == len(high), 'low and high must have the same length'
        assert dimensions is None or n_features == len(dimensions), \
            'dimensions must have one entry per dimension of low'

        # uniform_grid places the centers differently for cyclic spaces.
        mu, w = uniform_grid(n_centers, low, high, eta, cls.is_cyclic())

        scale = cls._convert_to_scale(w)

        tensor_list = [cls(mu, scale, dimensions, normalized)]

        return tensor_list

    @property
    def size(self):
        """Number of basis functions (rows of the centers tensor)."""
        return self._mu.shape[0]
class GaussianRBFTensor(GenericBasisTensor):
    """
    Gaussian radial basis functions.

    Each feature is ``exp(-sum_d delta_d ** 2 / scale_d)``, where ``delta``
    is the offset between the input and the feature's center.

    """
    def _basis_function(self, delta, scale):
        weighted_sq = torch.sum(delta.pow(2) / scale, dim=-1)
        return torch.exp(-weighted_sq)

    @staticmethod
    def _convert_to_scale(w):
        # sigma = w / 3 (the width spans three sigmas); the Gaussian exponent
        # denominator is then 2 * sigma ** 2.
        sigma = w / 3
        return 2 * sigma ** 2
class VonMisesBFTensor(GenericBasisTensor):
    """
    Von Mises basis functions, used for cyclic input spaces.

    Each feature is ``exp(sum_d (cos(2 * pi * delta_d) - 1) / scale_d)``,
    which equals one when the input coincides with the feature's center.

    """
    def _basis_function(self, delta, scale):
        # Subtracting 1 inside the per-dimension sum keeps the peak at exactly
        # one for any scale that broadcasts against delta (scalar,
        # per-dimension or per-center). The previous form,
        # sum(cos / scale) - sum(1 / scale), only cancelled correctly for a
        # 1-D per-dimension scale, for which both forms are identical.
        # NOTE(review): cos(2 * pi * delta) has period 1 in every dimension;
        # presumably inputs are expressed in those units - confirm.
        return torch.exp(
            torch.sum((torch.cos(2 * np.pi * delta) - 1) / scale, -1))

    @classmethod
    def is_cyclic(cls):
        """The Von Mises basis is generated on a cyclic space."""
        return True

    @staticmethod
    def _convert_to_scale(w):
        """Widths are used directly as scales."""
        return w