# Source code for mushroom_rl.algorithms.actor_critic.deep_actor_critic.ppo

import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F

from mushroom_rl.algorithms.agent import Agent
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import to_float_tensor
from mushroom_rl.utils.minibatches import minibatch_generator
from mushroom_rl.utils.dataset import parse_dataset, compute_J
from mushroom_rl.utils.value_functions import compute_gae

class PPO(Agent):
    """
    Proximal Policy Optimization algorithm.
    "Proximal Policy Optimization Algorithms".
    Schulman J. et al. 2017.

    """
    def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
                 n_epochs_policy, batch_size, eps_ppo, lam, quiet=True,
                 critic_fit_params=None):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP, forwarded to ``Agent``; its
                ``gamma`` attribute is the discount factor passed to GAE;
            policy (TorchPolicy): torch policy to be learned by the algorithm;
            actor_optimizer (dict): parameters to specify the actor optimizer
                algorithm: ``actor_optimizer['class']`` is instantiated on
                ``policy.parameters()`` with the keyword arguments in
                ``actor_optimizer['params']``;
            critic_params (dict): parameters of the critic approximator to
                build, forwarded to ``Regressor(TorchApproximator, ...)``;
            n_epochs_policy (int): number of passes over the dataset (in
                minibatches) performed by the policy optimization at every
                call to ``fit``;
            batch_size (int): size of minibatches for every optimization step;
            eps_ppo (float): value for probability ratio clipping;
            lam (float): lambda coefficient used by generalized advantage
                estimation;
            quiet (bool, True): if false, the algorithm prints debug
                information at every call to ``fit``;
            critic_fit_params (dict, None): parameters of the fitting
                algorithm of the critic approximator; when None,
                ``dict(n_epochs=10)`` is used.

        """
        self._critic_fit_params = dict(n_epochs=10) if critic_fit_params is None else critic_fit_params

        self._n_epochs_policy = n_epochs_policy
        self._batch_size = batch_size
        self._eps_ppo = eps_ppo

        self._optimizer = actor_optimizer['class'](policy.parameters(), **actor_optimizer['params'])

        self._lambda = lam

        self._V = Regressor(TorchApproximator, **critic_params)

        self._quiet = quiet
        self._iter = 1

        self._add_save_attr(
            _critic_fit_params='pickle',
            _n_epochs_policy='primitive',
            _batch_size='primitive',
            _eps_ppo='primitive',
            _optimizer='torch',
            _lambda='primitive',
            _V='mushroom',
            _quiet='primitive',
            _iter='primitive'
        )

        super().__init__(mdp_info, policy, None)

    def fit(self, dataset):
        """
        Fit the critic and the policy on a batch of collected transitions.

        Advantages and value targets are computed with ``compute_gae`` using
        the critic *before* it is refit on this batch; then the critic is fit
        on the value targets and the policy is optimized with the clipped
        PPO surrogate.

        Args:
            dataset: the collected transitions, in the format accepted by
                ``parse_dataset``.

        """
        if not self._quiet:
            tqdm.write('Iteration ' + str(self._iter))

        x, u, r, xn, absorbing, last = parse_dataset(dataset)
        x = x.astype(np.float32)
        u = u.astype(np.float32)
        r = r.astype(np.float32)
        xn = xn.astype(np.float32)

        obs = to_float_tensor(x, self.policy.use_cuda)
        act = to_float_tensor(u, self.policy.use_cuda)

        # The critic is passed to GAE before being refit on this batch.
        v_target, np_adv = compute_gae(self._V, x, xn, r, absorbing, last,
                                       self.mdp_info.gamma, self._lambda)
        # Standardize advantages; the epsilon avoids division by zero when
        # all advantages are equal.
        np_adv = (np_adv - np.mean(np_adv)) / (np.std(np_adv) + 1e-8)
        adv = to_float_tensor(np_adv, self.policy.use_cuda)

        # Snapshot of the pre-update policy: its log-probabilities are the
        # denominator of the PPO ratio and its distribution is reused for
        # logging. No gradient flows through it, so don't keep an autograd
        # graph alive across the policy optimization steps.
        with torch.no_grad():
            old_pol_dist = self.policy.distribution_t(obs)
            old_log_p = old_pol_dist.log_prob(act)[:, None]

        self._V.fit(x, v_target, **self._critic_fit_params)

        self._update_policy(obs, act, adv, old_log_p)

        # Print fit information
        self._print_fit_info(dataset, x, v_target, old_pol_dist)
        self._iter += 1

    def _update_policy(self, obs, act, adv, old_log_p):
        """
        Optimize the clipped PPO surrogate objective with minibatch steps.

        Args:
            obs (torch.Tensor): observations of the dataset;
            act (torch.Tensor): actions of the dataset;
            adv (torch.Tensor): standardized advantage estimates;
            old_log_p (torch.Tensor): log-probabilities of ``act`` under the
                policy before the update (ratio denominator).

        """
        for _ in range(self._n_epochs_policy):
            for obs_i, act_i, adv_i, old_log_p_i in minibatch_generator(
                    self._batch_size, obs, act, adv, old_log_p):
                self._optimizer.zero_grad()
                # pi_new(a|s) / pi_old(a|s), computed in log space.
                prob_ratio = torch.exp(
                    self.policy.log_prob_t(obs_i, act_i) - old_log_p_i
                )
                clipped_ratio = torch.clamp(prob_ratio, 1 - self._eps_ppo,
                                            1 + self._eps_ppo)
                # Element-wise minimum of the unclipped and clipped
                # objectives, negated because the optimizer minimizes.
                loss = -torch.mean(torch.min(prob_ratio * adv_i,
                                             clipped_ratio * adv_i))
                loss.backward()
                self._optimizer.step()

    def _print_fit_info(self, dataset, x, v_target, old_pol_dist):
        """
        Print diagnostics of the last ``fit`` call, unless ``quiet``.

        Reports the mean of ``compute_J(dataset)``, the critic MSE against
        ``v_target`` for each model of the regressor, the policy entropy on
        ``x`` and the mean KL divergence KL(new || old) between the updated
        and the pre-update policy.

        Args:
            dataset: the dataset passed to ``fit``;
            x (np.ndarray): states of the dataset;
            v_target: value targets computed by GAE;
            old_pol_dist: policy distribution on ``x`` before the update.

        """
        if self._quiet:
            return

        logging_verr = []
        torch_v_targets = torch.tensor(v_target, dtype=torch.float)
        for idx in range(len(self._V)):
            v_pred = torch.tensor(self._V(x, idx=idx), dtype=torch.float)
            v_err = F.mse_loss(v_pred, torch_v_targets)
            logging_verr.append(v_err.item())

        logging_ent = self.policy.entropy(x)
        # Logging only: no graph needed; report a plain float like vf_loss.
        with torch.no_grad():
            new_pol_dist = self.policy.distribution(x)
            logging_kl = torch.mean(torch.distributions.kl.kl_divergence(
                new_pol_dist, old_pol_dist)).item()
        avg_rwd = np.mean(compute_J(dataset))
        tqdm.write("Iterations Results:\n\trewards {} vf_loss {}\n\tentropy {} kl {}".format(
            avg_rwd, logging_verr, logging_ent, logging_kl))
        tqdm.write(
            '--------------------------------------------------------------------------------------------------')