# Source code for mushroom_rl.algorithms.actor_critic.deep_actor_critic.trpo

import numpy as np
from tqdm import tqdm

from copy import deepcopy

import torch
import torch.nn.functional as F

from mushroom_rl.algorithms.agent import Agent
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.torch import get_gradient, zero_grad, to_float_tensor
from mushroom_rl.utils.dataset import parse_dataset, compute_J
from mushroom_rl.utils.value_functions import compute_gae

class TRPO(Agent):
    """
    Trust Region Policy optimization algorithm.
    "Trust Region Policy Optimization".
    Schulman J. et al.. 2015.

    """
    def __init__(self, mdp_info, policy, critic_params, ent_coeff=0.,
                 max_kl=.001, lam=1., n_epochs_line_search=10,
                 n_epochs_cg=10, cg_damping=1e-2, cg_residual_tol=1e-10,
                 quiet=True, critic_fit_params=None):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP, forwarded to the ``Agent``
                constructor; its ``gamma`` attribute is used as discount
                factor when computing the advantages;
            policy (TorchPolicy): torch policy to be learned by the
                algorithm;
            critic_params (dict): parameters of the critic approximator to
                build;
            ent_coeff (float, 0): coefficient for the entropy penalty;
            max_kl (float, .001): maximum kl allowed for every policy
                update;
            lam (float, 1.): lambda coefficient used by generalized
                advantage estimation;
            n_epochs_line_search (int, 10): maximum number of iterations
                of the line search algorithm;
            n_epochs_cg (int, 10): maximum number of iterations of the
                conjugate gradient algorithm;
            cg_damping (float, 1e-2): damping factor for the conjugate
                gradient algorithm;
            cg_residual_tol (float, 1e-10): conjugate gradient residual
                tolerance;
            quiet (bool, True): if False, the algorithm will print debug
                information at every call to ``fit``;
            critic_fit_params (dict, None): parameters of the fitting
                algorithm of the critic approximator; if None,
                ``dict(n_epochs=5)`` is used.

        """
        self._critic_fit_params = dict(n_epochs=5) if critic_fit_params is None else critic_fit_params

        self._n_epochs_line_search = n_epochs_line_search
        self._n_epochs_cg = n_epochs_cg
        self._cg_damping = cg_damping
        self._cg_residual_tol = cg_residual_tol

        self._max_kl = max_kl
        self._ent_coeff = ent_coeff

        self._lambda = lam

        self._V = Regressor(TorchApproximator, **critic_params)

        self._iter = 1
        self._quiet = quiet

        self._old_policy = None

        self._add_save_attr(
            _critic_fit_params='pickle',
            _n_epochs_line_search='primitive',
            _n_epochs_cg='primitive',
            _cg_damping='primitive',
            _cg_residual_tol='primitive',
            _max_kl='primitive',
            _ent_coeff='primitive',
            _lambda='primitive',
            _V='mushroom',
            _old_policy='mushroom',
            _iter='primitive',
            _quiet='primitive'
        )

        super().__init__(mdp_info, policy, None)

    def fit(self, dataset):
        """
        Perform one TRPO iteration: a trust-region policy update followed
        by a fit of the critic on the computed value targets.

        Args:
            dataset: the collected transitions, in the format accepted by
                ``parse_dataset``.

        """
        if not self._quiet:
            tqdm.write('Iteration ' + str(self._iter))

        state, action, reward, next_state, absorbing, last = parse_dataset(dataset)
        x = state.astype(np.float32)
        u = action.astype(np.float32)
        r = reward.astype(np.float32)
        xn = next_state.astype(np.float32)

        obs = to_float_tensor(x, self.policy.use_cuda)
        act = to_float_tensor(u, self.policy.use_cuda)
        v_target, np_adv = compute_gae(self._V, x, xn, r, absorbing, last,
                                       self.mdp_info.gamma, self._lambda)
        # Standardize the advantages; the epsilon avoids a division by zero
        # when all the advantages are equal.
        np_adv = (np_adv - np.mean(np_adv)) / (np.std(np_adv) + 1e-8)
        adv = to_float_tensor(np_adv, self.policy.use_cuda)

        # Policy update
        self._old_policy = deepcopy(self.policy)
        old_pol_dist = self._old_policy.distribution_t(obs)
        old_log_prob = self._old_policy.log_prob_t(obs, act).detach()

        zero_grad(self.policy.parameters())
        loss = self._compute_loss(obs, act, adv, old_log_prob)

        prev_loss = loss.item()

        # Compute Gradient
        loss.backward()
        g = get_gradient(self.policy.parameters())

        # Compute direction through conjugate gradient
        stepdir = self._conjugate_gradient(g, obs, old_pol_dist)

        # Line search
        self._line_search(obs, act, adv, old_log_prob, old_pol_dist,
                          prev_loss, stepdir)

        # VF update
        self._V.fit(x, v_target, **self._critic_fit_params)

        # Print fit information
        self._print_fit_info(dataset, x, v_target, old_pol_dist)
        self._iter += 1

    def _fisher_vector_product(self, p, obs, old_pol_dist):
        """
        Numpy front-end of ``_fisher_vector_product_t``: converts ``p`` to
        a torch tensor (moved to the GPU if the policy uses cuda) and
        returns the resulting torch tensor.

        """
        p_tensor = torch.from_numpy(p)
        if self.policy.use_cuda:
            p_tensor = p_tensor.cuda()

        return self._fisher_vector_product_t(p_tensor, obs, old_pol_dist)

    def _fisher_vector_product_t(self, p, obs, old_pol_dist):
        """
        Compute the damped product between the Hessian of the mean KL
        divergence (w.r.t. the policy parameters) and the flat vector
        ``p``, without building the Hessian explicitly.

        """
        kl = self._compute_kl(obs, old_pol_dist)
        # create_graph=True keeps the graph of the gradient so that it can
        # be differentiated a second time below.
        grads = torch.autograd.grad(kl, self.policy.parameters(),
                                    create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = torch.sum(flat_grad_kl * p)
        grads_v = torch.autograd.grad(kl_v, self.policy.parameters(),
                                      create_graph=False)
        flat_grad_grad_kl = torch.cat(
            [grad.contiguous().view(-1) for grad in grads_v]
        ).data

        return flat_grad_grad_kl + p * self._cg_damping

    def _conjugate_gradient(self, b, obs, old_pol_dist):
        """
        Approximately solve ``F x = b`` with the conjugate gradient method,
        where ``F v`` is given by ``_fisher_vector_product``. At most
        ``n_epochs_cg`` iterations are run; the loop stops earlier when the
        squared residual norm drops below ``cg_residual_tol``.

        Returns:
            The approximate solution ``x`` as a numpy array.

        """
        # Each array needs its own buffer: on CPU, ``.detach().cpu().numpy()``
        # returns views of the memory of ``b``, so without the copies the
        # in-place update of ``r`` would also overwrite ``p`` (and ``b``).
        p = b.detach().cpu().numpy().copy()
        r = b.detach().cpu().numpy().copy()
        x = np.zeros_like(p)
        r2 = r.dot(r)

        for _ in range(self._n_epochs_cg):
            z = self._fisher_vector_product(p, obs, old_pol_dist).detach().cpu().numpy()
            v = r2 / p.dot(z)
            x += v * p
            r -= v * z
            r2_new = r.dot(r)
            mu = r2_new / r2
            p = r + mu * p

            r2 = r2_new
            if r2 < self._cg_residual_tol:
                break

        return x

    def _line_search(self, obs, act, adv, old_log_prob, old_pol_dist,
                     prev_loss, stepdir):
        """
        Backtracking line search along ``stepdir``. The step is first
        scaled using ``max_kl``, then halved up to
        ``n_epochs_line_search`` times until the mean KL is at most
        ``1.5 * max_kl`` and the surrogate objective does not decrease.
        If no step is accepted, the previous policy weights are restored.

        """
        # Compute optimal step size
        direction = self._fisher_vector_product(stepdir, obs, old_pol_dist).detach().cpu().numpy()
        shs = .5 * stepdir.dot(direction)
        lm = np.sqrt(shs / self._max_kl)
        full_step = stepdir / lm
        stepsize = 1.

        # Save old policy parameters
        theta_old = self.policy.get_weights()

        # Perform Line search
        violation = True

        for _ in range(self._n_epochs_line_search):
            theta_new = theta_old + full_step * stepsize
            self.policy.set_weights(theta_new)

            new_loss = self._compute_loss(obs, act, adv, old_log_prob)
            kl = self._compute_kl(obs, old_pol_dist)
            improve = new_loss - prev_loss
            if kl <= self._max_kl * 1.5 and improve >= 0:
                violation = False
                break
            stepsize *= .5

        if violation:
            self.policy.set_weights(theta_old)

    def _compute_kl(self, obs, old_pol_dist):
        """
        Return the mean KL divergence between ``old_pol_dist`` and the
        distribution of the current policy on ``obs``.

        """
        new_pol_dist = self.policy.distribution_t(obs)
        return torch.mean(
            torch.distributions.kl.kl_divergence(old_pol_dist, new_pol_dist)
        )

    def _compute_loss(self, obs, act, adv, old_log_prob):
        """
        Return the surrogate objective: the mean of the probability ratio
        (current policy over old policy) times the advantage, plus the
        entropy of the current policy weighted by ``ent_coeff``.

        """
        ratio = torch.exp(self.policy.log_prob_t(obs, act) - old_log_prob)
        J = torch.mean(ratio * adv)

        return J + self._ent_coeff * self.policy.entropy_t(obs)

    def _print_fit_info(self, dataset, x, v_target, old_pol_dist):
        """
        Print, unless ``quiet`` is set, the average of ``compute_J`` over
        the dataset, the critic MSE for each model of the regressor, the
        policy entropy and the mean KL from the old policy.

        """
        if not self._quiet:
            logging_verr = []
            torch_v_targets = torch.tensor(v_target, dtype=torch.float)
            for idx in range(len(self._V)):
                v_pred = torch.tensor(self._V(x, idx=idx), dtype=torch.float)
                v_err = F.mse_loss(v_pred, torch_v_targets)
                logging_verr.append(v_err.item())

            logging_ent = self.policy.entropy(x)
            new_pol_dist = self.policy.distribution(x)
            logging_kl = torch.mean(
                torch.distributions.kl.kl_divergence(old_pol_dist, new_pol_dist)
            )
            avg_rwd = np.mean(compute_J(dataset))
            tqdm.write("Iterations Results:\n\trewards {} vf_loss {}\n\tentropy {} kl {}".format(
                avg_rwd, logging_verr, logging_ent, logging_kl))
            tqdm.write(
                '--------------------------------------------------------------------------------------------------')