Source code for mushroom_rl.algorithms.actor_critic.deep_actor_critic.ppo

import numpy as np

import torch
import torch.nn.functional as F

from mushroom_rl.core import Agent
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import to_float_tensor, update_optimizer_parameters
from mushroom_rl.utils.minibatches import minibatch_generator
from mushroom_rl.utils.dataset import parse_dataset, compute_J
from mushroom_rl.utils.value_functions import compute_gae
from mushroom_rl.utils.parameters import to_parameter


class PPO(Agent):
    """
    Proximal Policy Optimization algorithm.
    "Proximal Policy Optimization Algorithms".
    Schulman J. et al. 2017.

    """
    def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
                 n_epochs_policy, batch_size, eps_ppo, lam, ent_coeff=0.0,
                 critic_fit_params=None):
        """
        Constructor.

        Args:
            policy (TorchPolicy): torch policy to be learned by the algorithm;
            actor_optimizer (dict): parameters to specify the actor optimizer
                algorithm;
            critic_params (dict): parameters of the critic approximator to
                build;
            n_epochs_policy ([int, Parameter]): number of policy updates for
                every dataset;
            batch_size ([int, Parameter]): size of minibatches for every
                optimization step;
            eps_ppo ([float, Parameter]): value for probability ratio clipping;
            lam ([float, Parameter]): lambda coefficient used by generalized
                advantage estimation;
            ent_coeff ([float, Parameter], 0.): coefficient for the entropy
                regularization term;
            critic_fit_params (dict, None): parameters of the fitting algorithm
                of the critic approximator.

        """
        self._critic_fit_params = dict(n_epochs=10) if critic_fit_params is None else critic_fit_params

        self._n_epochs_policy = to_parameter(n_epochs_policy)
        self._batch_size = to_parameter(batch_size)
        self._eps_ppo = to_parameter(eps_ppo)

        self._optimizer = actor_optimizer['class'](policy.parameters(),
                                                   **actor_optimizer['params'])

        self._lambda = to_parameter(lam)
        self._ent_coeff = to_parameter(ent_coeff)

        self._V = Regressor(TorchApproximator, **critic_params)

        self._iter = 1

        self._add_save_attr(
            _critic_fit_params='pickle',
            _n_epochs_policy='mushroom',
            _batch_size='mushroom',
            _eps_ppo='mushroom',
            _ent_coeff='mushroom',
            _optimizer='torch',
            _lambda='mushroom',
            _V='mushroom',
            _iter='primitive'
        )

        super().__init__(mdp_info, policy, None)
    def fit(self, dataset, **info):
        x, u, r, xn, absorbing, last = parse_dataset(dataset)
        x = x.astype(np.float32)
        u = u.astype(np.float32)
        r = r.astype(np.float32)
        xn = xn.astype(np.float32)

        obs = to_float_tensor(x, self.policy.use_cuda)
        act = to_float_tensor(u, self.policy.use_cuda)
        v_target, np_adv = compute_gae(self._V, x, xn, r, absorbing, last,
                                       self.mdp_info.gamma, self._lambda())
        np_adv = (np_adv - np.mean(np_adv)) / (np.std(np_adv) + 1e-8)
        adv = to_float_tensor(np_adv, self.policy.use_cuda)

        old_pol_dist = self.policy.distribution_t(obs)
        old_log_p = old_pol_dist.log_prob(act)[:, None].detach()

        self._V.fit(x, v_target, **self._critic_fit_params)

        self._update_policy(obs, act, adv, old_log_p)

        # Print fit information
        self._log_info(dataset, x, v_target, old_pol_dist)
        self._iter += 1
    def _update_policy(self, obs, act, adv, old_log_p):
        for epoch in range(self._n_epochs_policy()):
            for obs_i, act_i, adv_i, old_log_p_i in minibatch_generator(
                    self._batch_size(), obs, act, adv, old_log_p):
                self._optimizer.zero_grad()
                prob_ratio = torch.exp(
                    self.policy.log_prob_t(obs_i, act_i) - old_log_p_i
                )
                clipped_ratio = torch.clamp(prob_ratio, 1 - self._eps_ppo(),
                                            1 + self._eps_ppo())
                loss = -torch.mean(torch.min(prob_ratio * adv_i,
                                             clipped_ratio * adv_i))
                loss -= self._ent_coeff() * self.policy.entropy_t(obs_i)
                loss.backward()
                self._optimizer.step()

    def _log_info(self, dataset, x, v_target, old_pol_dist):
        if self._logger:
            logging_verr = []
            torch_v_targets = torch.tensor(v_target, dtype=torch.float)
            for idx in range(len(self._V)):
                v_pred = torch.tensor(self._V(x, idx=idx), dtype=torch.float)
                v_err = F.mse_loss(v_pred, torch_v_targets)
                logging_verr.append(v_err.item())

            logging_ent = self.policy.entropy(x)
            new_pol_dist = self.policy.distribution(x)
            logging_kl = torch.mean(torch.distributions.kl.kl_divergence(
                new_pol_dist, old_pol_dist))
            avg_rwd = np.mean(compute_J(dataset))
            msg = "Iteration {}:\n\t\t\t\trewards {} vf_loss {}\n\t\t\t\t" \
                  "entropy {} kl {}".format(self._iter, avg_rwd, logging_verr,
                                            logging_ent, logging_kl)

            self._logger.info(msg)
            self._logger.weak_line()
    def _post_load(self):
        if self._optimizer is not None:
            update_optimizer_parameters(self._optimizer,
                                        list(self.policy.parameters()))
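
# A minimal usage sketch: how a PPO agent built from this class might be set up
# and trained with MushroomRL's Core on a Gym control task. The network
# architecture, the 'Pendulum-v1' environment name and every hyperparameter
# below are illustrative assumptions, not values prescribed by this module.
if __name__ == '__main__':
    import torch.nn as nn
    import torch.optim as optim

    from mushroom_rl.core import Core
    from mushroom_rl.environments import Gym
    from mushroom_rl.policy import GaussianTorchPolicy

    class Network(nn.Module):
        # Small two-layer MLP following the (input_shape, output_shape,
        # n_features) convention used by MushroomRL Torch approximators.
        def __init__(self, input_shape, output_shape, n_features, **kwargs):
            super().__init__()
            self._h1 = nn.Linear(input_shape[-1], n_features)
            self._h2 = nn.Linear(n_features, n_features)
            self._out = nn.Linear(n_features, output_shape[0])

        def forward(self, state, **kwargs):
            features = F.relu(self._h1(torch.squeeze(state, 1).float()))
            features = F.relu(self._h2(features))
            return self._out(features)

    mdp = Gym('Pendulum-v1', horizon=200, gamma=0.99)

    # Gaussian policy whose mean is given by the network above.
    policy = GaussianTorchPolicy(Network,
                                 mdp.info.observation_space.shape,
                                 mdp.info.action_space.shape,
                                 std_0=1., n_features=32)

    # Critic approximator parameters, forwarded to Regressor/TorchApproximator.
    critic_params = dict(network=Network,
                         optimizer={'class': optim.Adam,
                                    'params': {'lr': 3e-4}},
                         loss=F.mse_loss,
                         n_features=32,
                         batch_size=64,
                         input_shape=mdp.info.observation_space.shape,
                         output_shape=(1,))

    agent = PPO(mdp.info, policy,
                actor_optimizer={'class': optim.Adam,
                                 'params': {'lr': 3e-4}},
                critic_params=critic_params,
                n_epochs_policy=4, batch_size=64,
                eps_ppo=0.2, lam=0.95)

    core = Core(agent, mdp)
    core.learn(n_steps=30000, n_steps_per_fit=3000)

    dataset = core.evaluate(n_episodes=10)
    print('Mean discounted return:',
          np.mean(compute_J(dataset, mdp.info.gamma)))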