Source code for mushroom_rl.features.tensors.constant_tensor
import torch
import torch.nn as nn
from mushroom_rl.utils.torch import TorchUtils
class ConstantTensor(nn.Module):
    """
    Pytorch module to implement a constant function (always one).

    For a batch of ``N`` inputs the module outputs an ``(N, 1)`` tensor of
    ones, i.e. a single constant feature per sample.

    """
    def forward(self, x):
        """
        Compute the constant feature for a batch of inputs.

        Args:
            x: input batch; only its leading dimension (``x.shape[0]``) is
                read, the values themselves are ignored.

        Returns:
            A ``(x.shape[0], 1)`` tensor of ones with torch's default dtype,
            placed on the device returned by ``TorchUtils.get_device()``.

        """
        # Allocate directly on the target device instead of building the
        # tensor on the CPU and copying it over with ``.to(...)``: same
        # values/dtype/device, but no extra host allocation or transfer.
        return torch.ones(x.shape[0], 1, device=TorchUtils.get_device())

    @property
    def size(self):
        """
        Number of features produced by this module (always 1).

        """
        return 1