Source code for mushroom_rl.utils.parameters

from mushroom_rl.core import Serializable
from mushroom_rl.utils.table import Table
import numpy as np


def to_parameter(x):
    """
    Wrap a value into a ``Parameter`` unless it already is one.

    Args:
        x: either a ``Parameter`` instance or the value used to build one.

    Returns:
        ``x`` itself when it is already a ``Parameter``, otherwise a new
        ``Parameter`` constructed from ``x``.

    """
    return x if isinstance(x, Parameter) else Parameter(x)


class Parameter(Serializable):
    """
    Manager for a parameter such as a learning rate. A table of update
    counters is kept alongside the value, so that a separate parameter can be
    tracked for each state or state-action tuple.

    """
    def __init__(self, value, min_value=None, max_value=None, size=(1,)):
        """
        Constructor.

        Args:
            value (float): initial value of the parameter;
            min_value (float, None): lower bound applied to the returned
                value; ``None`` means no lower bound;
            max_value (float, None): upper bound applied to the returned
                value; ``None`` means no upper bound;
            size (tuple, (1,)): shape of the table of update counters; it can
                be used to have a single parameter for each state or
                state-action tuple.

        """
        self._initial_value = value
        self._min_value = min_value
        self._max_value = max_value
        self._n_updates = Table(size)

        # Registered in the same order as the attributes are created above.
        save_attributes = dict(
            _initial_value='primitive',
            _min_value='primitive',
            _max_value='primitive',
            _n_updates='mushroom',
        )
        self._add_save_attr(**save_attributes)

    def __call__(self, *idx, **kwargs):
        """
        Update the parameter in the provided index, then return its value.

        Args:
            *idx (list): index of the parameter to update and return.

        Returns:
            The updated parameter in the provided index.

        """
        # When the table holds a single entry the provided index is dropped
        # and the table is addressed without any index.
        single_entry = self._n_updates.table.size == 1
        target = [] if single_entry else idx

        self.update(*target, **kwargs)

        return self.get_value(*target, **kwargs)

    def get_value(self, *idx, **kwargs):
        """
        Return the current value of the parameter in the provided index,
        without updating it.

        Args:
            *idx (list): index of the parameter to return.

        Returns:
            The value given by ``_compute``, clipped to the configured bounds
            when at least one of them is set.

        """
        value = self._compute(*idx, **kwargs)

        lower, upper = self._min_value, self._max_value
        if lower is not None or upper is not None:
            return np.clip(value, lower, upper)

        return value

    def _compute(self, *idx, **kwargs):
        """
        Compute the unclipped value of the parameter. This base
        implementation ignores the index and yields the initial value.

        Returns:
            The value of the parameter in the provided index.

        """
        return self._initial_value

    def update(self, *idx, **kwargs):
        """
        Increment the update counter of the parameter in the provided index.

        Args:
            *idx (list): index of the parameter whose update counter has to
                be incremented.

        """
        self._n_updates[idx] += 1

    @property
    def shape(self):
        """
        Returns:
            The shape of the table of parameters.

        """
        return self._n_updates.table.shape

    @property
    def initial_value(self):
        """
        Returns:
            The initial value of the parameters.

        """
        return self._initial_value
class LinearParameter(Parameter):
    """
    Parameter whose value changes linearly with the number of times it has
    been updated.

    """
    def __init__(self, value, threshold_value, n, size=(1,)):
        """
        Constructor.

        Args:
            value (float): initial value of the parameter;
            threshold_value (float): value used as the bound of the
                parameter: the maximum when the slope is non-negative, the
                minimum otherwise;
            n (int): divisor of ``threshold_value - value`` giving the change
                applied at each update;
            size (tuple, (1,)): shape of the table of update counters.

        """
        slope = (threshold_value - value) / n
        self._coeff = slope

        # The threshold bounds the value in the direction of change.
        increasing = slope >= 0
        lower = None if increasing else threshold_value
        upper = threshold_value if increasing else None

        super().__init__(value, lower, upper, size)

        self._add_save_attr(_coeff='primitive')

    def _compute(self, *idx, **kwargs):
        """
        Returns:
            The initial value plus the slope times the number of updates
            counted in the provided index.

        """
        n_updates = self._n_updates[idx]

        return self._coeff * n_updates + self._initial_value
class ExponentialParameter(Parameter):
    """
    Parameter whose value is the initial value divided by a power of the
    number of times it has been updated.

    """
    def __init__(self, value, exp=1., min_value=None, max_value=None,
                 size=(1,)):
        """
        Constructor.

        Args:
            value (float): initial value of the parameter;
            exp (float, 1.): exponent applied to the number of updates;
            min_value (float, None): lower bound applied to the returned
                value; ``None`` means no lower bound;
            max_value (float, None): upper bound applied to the returned
                value; ``None`` means no upper bound;
            size (tuple, (1,)): shape of the table of update counters.

        """
        self._exp = exp

        super().__init__(value, min_value=min_value, max_value=max_value,
                         size=size)

        self._add_save_attr(_exp='primitive')

    def _compute(self, *idx, **kwargs):
        """
        Returns:
            The initial value divided by the number of updates in the
            provided index raised to the exponent.

        """
        # The counter is floored at 1, so an entry that was never updated
        # yields the initial value instead of a division by zero.
        denominator = np.maximum(self._n_updates[idx], 1) ** self._exp

        return self._initial_value / denominator