Source code for mushroom_rl.algorithms.policy_search.black_box_optimization.black_box_optimization

import numpy as np

from mushroom_rl.algorithms.agent import Agent
from mushroom_rl.utils.dataset import compute_J


class BlackBoxOptimization(Agent):
    """
    Base class for black box optimization algorithms.
    These algorithms work on a distribution of policy parameters and often they
    do not rely on stochastic and differentiable policies.

    """
    def __init__(self, mdp_info, distribution, policy, features=None):
        """
        Constructor.

        Args:
            distribution (Distribution): the distribution of policy parameters;
            policy (ParametricPolicy): the policy to use.

        """
        self.distribution = distribution
        self._theta_list = list()

        self._add_save_attr(distribution='mushroom', _theta_list='pickle')

        super().__init__(mdp_info, policy, features)

    def episode_start(self):
        theta = self.distribution.sample()
        self._theta_list.append(theta)
        self.policy.set_weights(theta)

        super().episode_start()

    def fit(self, dataset):
        Jep = compute_J(dataset, self.mdp_info.gamma)

        Jep = np.array(Jep)
        theta = np.array(self._theta_list)

        self._update(Jep, theta)

        self._theta_list = list()

    def stop(self):
        self._theta_list = list()

    def _update(self, Jep, theta):
        """
        Function that implements the update routine of distribution parameters.
        Every black box algorithms should implement this function with the
        proper update.

        Args:
            Jep (np.ndarray): a vector containing the J of the considered
                trajectories;
            theta (np.ndarray): a matrix of policy parameters of the considered
                trajectories.

        """
        raise NotImplementedError('BlackBoxOptimization is an abstract class')