# Source code for mushroom_rl.algorithms.policy_search.policy_gradient.policy_gradient

import numpy as np

from mushroom_rl.algorithms.agent import Agent


class PolicyGradient(Agent):
    """
    Abstract class to implement a generic Policy Search algorithm using the
    gradient of the policy to update its parameters.
    "A Survey on Policy Search for Robotics". Deisenroth M. P., Neumann G.,
    Peters J.. 2013.

    Subclasses must implement ``_init_update``, ``_step_update``,
    ``_episode_end_update`` and ``_compute_gradient``; ``fit`` drives these
    hooks while parsing the dataset and then applies the gradient step.

    """
    def __init__(self, mdp_info, policy, learning_rate, features):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP; ``mdp_info.gamma`` is used
                as the discount factor when accumulating episode returns;
            policy: the policy to optimize; its parameters are read with
                ``get_weights`` and written back with ``set_weights``;
            learning_rate: callable returning the step size. It is called
                as ``learning_rate(grad)`` or ``learning_rate(grad, nat_grad)``
                depending on what ``_compute_gradient`` returns;
            features: optional features forwarded to the ``Agent``
                constructor; presumably exposed as ``self.phi``, which
                ``_parse`` applies to the state (TODO confirm in ``Agent``).

        """
        self.learning_rate = learning_rate
        # Running discount factor (gamma^t) and discounted return of the
        # episode currently being parsed; both are reset by ``fit``.
        self.df = 1.
        self.J_episode = 0.

        self._add_save_attr(
            learning_rate='pickle',
            df='numpy',
            J_episode='numpy'
        )

        super().__init__(mdp_info, policy, features)

    def fit(self, dataset):
        """
        Update the policy parameters using the given dataset.

        The dataset is parsed sample by sample: every step is forwarded to
        ``_step_update`` and its discounted reward is added to the current
        episode return. When a sample has the ``last`` flag set, the episode
        is closed via ``_episode_end_update`` and its return is appended to
        ``J``. Samples after the final ``last`` flag are passed to
        ``_step_update`` but their return is not appended to ``J``.

        Args:
            dataset (list): list of samples, each one parsed by ``_parse``.

        """
        J = list()
        self.df = 1.
        self.J_episode = 0.
        self._init_update()
        for sample in dataset:
            x, u, r, _, _, last = self._parse(sample)
            self._step_update(x, u, r)
            self.J_episode += self.df * r
            self.df *= self.mdp_info.gamma

            if last:
                self._episode_end_update()
                J.append(self.J_episode)
                # Reset the per-episode accumulators for the next episode.
                self.J_episode = 0.
                self.df = 1.
                self._init_update()

        self._update_parameters(J)

    def _update_parameters(self, J):
        """
        Update the parameters of the policy.

        Args:
             J (list): list of the cumulative discounted rewards for each
                episode in the dataset.

        Raises:
            ValueError: if ``_compute_gradient`` returns a sequence whose
                length is neither 1 nor 2.

        """
        res = self._compute_gradient(J)

        theta = self.policy.get_weights()

        if len(res) == 1:
            # Only a gradient is available: step along it.
            grad = res[0]
            delta = self.learning_rate(grad) * grad
        elif len(res) == 2:
            # (grad, nat_grad): the step size may depend on both, but the
            # step is taken along ``nat_grad`` (presumably a natural gradient).
            grad, nat_grad = res
            delta = self.learning_rate(grad, nat_grad) * nat_grad
        else:
            raise ValueError(
                '_compute_gradient must return (grad,) or (grad, nat_grad), '
                'got a sequence of length %d' % len(res))

        theta_new = theta + delta
        self.policy.set_weights(theta_new)

    def _init_update(self):
        """
        This function is called, when parsing the dataset, at the beginning
        of each episode. The implementation is dependent on the algorithm (e.g.
        REINFORCE resets some data structure).

        """
        raise NotImplementedError('PolicyGradient is an abstract class')

    def _step_update(self, x, u, r):
        """
        This function is called, when parsing the dataset, at each episode step.

        Args:
            x (np.ndarray): the state at the current step, already
                preprocessed by ``_parse``;
            u (np.ndarray): the action at the current step;
            r: the reward at the current step.

        """
        raise NotImplementedError('PolicyGradient is an abstract class')

    def _episode_end_update(self):
        """
        This function is called, when parsing the dataset, at the end of each
        episode (i.e. on a sample whose ``last`` flag is set). The
        implementation is dependent on the algorithm (e.g. REINFORCE updates
        some data structures).

        """
        raise NotImplementedError('PolicyGradient is an abstract class')

    def _compute_gradient(self, J):
        """
        Return the gradient computed by the algorithm.

        Args:
             J (list): list of the cumulative discounted rewards for each
                episode in the dataset.

        Returns:
            A tuple ``(grad,)`` or ``(grad, nat_grad)``; see
            ``_update_parameters`` for how each form is applied.

        """
        raise NotImplementedError('PolicyGradient is an abstract class')

    def _parse(self, sample):
        """
        Utility to parse the sample.

        Args:
             sample (list): the current episode step.

        Returns:
            A tuple containing state, action, reward, next state, absorbing and
            last flag. If provided, ``state`` is preprocessed with the features;
            ``next_state`` is returned unchanged.

        """
        state = sample[0]
        action = sample[1]
        reward = sample[2]
        next_state = sample[3]
        absorbing = sample[4]
        last = sample[5]

        if self.phi is not None:
            state = self.phi(state)

        return state, action, reward, next_state, absorbing, last