# Source code for mushroom_rl.utils.torch

import torch
import numpy as np


class TorchUtils(object):
    _default_device = 'cpu'

    @classmethod
    def set_default_device(cls, device):
        cls._default_device = device

    @classmethod
    def get_device(cls, device=None):
        return cls._default_device if device is None else device

    @classmethod
    def set_weights(cls, parameters, weights, device=None):
        """
        Function used to set the value of a set of torch parameters given a
        vector of values.

        Args:
            parameters (list): list of parameters to be considered;
            weights (numpy.ndarray): array of the new values for
                the parameters;
            device (str, None): device to use to store the tensor.

        """
        idx = 0
        for p in parameters:
            shape = p.data.shape

            c = 1
            for s in shape:
                c *= s

            w = weights[idx:idx + c].reshape(shape)

            w_tensor = torch.as_tensor(w, device=cls.get_device(device)).type(p.data.dtype)

            p.data = w_tensor
            idx += c

        # assert idx == weights.size # TODO check if you can put another guard here

    @staticmethod
    def get_weights(parameters):
        """
        Function used to get the value of a set of torch parameters as
        a single vector of values.

        Args:
            parameters (list): list of parameters to be considered.

        Returns:
            A numpy vector consisting of all the values of the vectors.

        """
        weights = list()

        for p in parameters:
            w = p.data.detach()
            weights.append(w.flatten())

        weights = torch.concatenate(weights)

        return weights

    @staticmethod
    def zero_grad(parameters):
        """
        Function used to set to zero the value of the gradient of a set
        of torch parameters.

        Args:
            parameters (list): list of parameters to be considered.

        """

        for p in parameters:
            if p.grad is not None:
               p.grad.detach_()
               p.grad.zero_()

    @staticmethod
    def get_gradient(params):
        """
        Function used to get the value of the gradient of a set of
        torch parameters.

        Args:
            parameters (list): list of parameters to be considered.

        """
        views = []
        for p in params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    @classmethod
    def to_float_tensor(cls, x, device=None):
        """
        Function used to convert a numpy array to a float torch tensor.

        Args:
            x (np.ndarray): numpy array to be converted as torch tensor;
            device (str, None): device to use to store the tensor.

        Returns:
            A float tensor build from the values contained in the input array.

        """
        return torch.as_tensor(x, device=cls.get_device(device), dtype=torch.float)

    @classmethod
    def to_int_tensor(cls, x, device=None):
        """
        Function used to convert a numpy array to a float torch tensor.

        Args:
            x (np.ndarray): numpy array to be converted as torch tensor;
            device (str, None): device to use to store the tensor.

        Returns:
            A float tensor build from the values contained in the input array.

        """
        return torch.as_tensor(x, device=cls.get_device(device), dtype=torch.int)

    @staticmethod
    def update_optimizer_parameters(optimizer, new_parameters):
        if len(optimizer.state) > 0:
            for p_old, p_new in zip(optimizer.param_groups[0]['params'], new_parameters):
                data = optimizer.state[p_old]
                del optimizer.state[p_old]
                optimizer.state[p_new] = data

        optimizer.param_groups[0]['params'] = new_parameters


class CategoricalWrapper(torch.distributions.Categorical):
    """
    Wrapper for the Torch Categorical distribution.

    Needed to convert a vector of mushroom discrete actions into an input
    with the proper shape for the original distribution implemented in torch.

    """
    def __init__(self, logits):
        """
        Constructor.

        Args:
            logits (torch.Tensor): the logits passed to the torch Categorical
                distribution.

        """
        super().__init__(logits=logits)

    def log_prob(self, value):
        """
        Compute the log probability of the given actions.

        Args:
            value (torch.Tensor): the actions to evaluate; all dimensions of
                size one are squeezed out before calling the torch
                implementation.

        Returns:
            The log probability computed by the torch Categorical
            distribution on the squeezed input.

        """
        return super().log_prob(value.squeeze())