from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
from mushroom_rl.algorithms.value.dqn import AbstractDQN
from mushroom_rl.algorithms.value.dqn.categorical_dqn import categorical_loss
from mushroom_rl.algorithms.value.dqn.noisy_dqn import NoisyNetwork
from mushroom_rl.approximators.parametric.torch_approximator import *
from mushroom_rl.utils.replay_memory import PrioritizedReplayMemory
# Machine epsilon of float32, kept as a small tolerance for floating-point
# range end points in this module.
eps = torch.finfo(torch.float32).eps
class RainbowNetwork(nn.Module):
    """
    Dueling, noisy, categorical network used by the Rainbow agent.

    A feature extractor feeds one noisy linear "value" head and one noisy
    linear "advantage" head per action. Every head outputs ``n_atoms`` logits;
    value and mean-centred advantage logits are summed and soft-maxed over the
    atoms, giving one categorical distribution per action over a fixed support
    of ``n_atoms`` evenly spaced returns in ``[v_min, v_max]``.

    """
    def __init__(self, input_shape, output_shape, features_network, n_atoms,
                 v_min, v_max, n_features, use_cuda, sigma_coeff, **kwargs):
        """
        Constructor.

        Args:
            input_shape: shape forwarded to ``features_network``;
            output_shape (tuple): its first entry is used as the number of
                actions (one advantage head is created per action);
            features_network (callable): invoked as
                ``features_network(input_shape, (n_features,),
                n_features=n_features, **kwargs)``; the result is used as the
                feature extractor;
            n_atoms (int): number of atoms of the categorical support;
            v_min (float): smallest value of the support;
            v_max (float): largest value of the support;
            n_features (int): input size of every noisy head;
            use_cuda (bool): if True the support tensor is moved to CUDA; the
                flag is also forwarded to the noisy layers;
            sigma_coeff (float): forwarded to the noisy layers;
            **kwargs: extra arguments forwarded to ``features_network``.

        """
        super().__init__()

        self._n_output = output_shape[0]
        self._phi = features_network(input_shape, (n_features,),
                                     n_features=n_features, **kwargs)
        self._n_atoms = n_atoms
        self._v_min = v_min
        self._v_max = v_max

        # linspace always returns exactly n_atoms points including both end
        # points, whereas a float-step arange depends on rounding.
        self._a_values = torch.linspace(self._v_min, self._v_max,
                                        self._n_atoms)
        if use_cuda:
            self._a_values = self._a_values.cuda()

        self._pv = NoisyNetwork.NoisyLinear(n_features, n_atoms, use_cuda,
                                            sigma_coeff)
        self._pa = nn.ModuleList(
            [NoisyNetwork.NoisyLinear(n_features, n_atoms, use_cuda,
                                      sigma_coeff)
             for _ in range(self._n_output)]
        )

    def forward(self, state, action=None, get_distribution=False):
        """
        Compute Q-values or return distributions.

        Args:
            state: input passed to the feature extractor;
            action (torch.Tensor, None): optional action indices used to
                select one entry per sample along dimension 1
                (assumed to have shape (batch, 1) -- TODO confirm with caller);
            get_distribution (bool, False): if True return the categorical
                distributions instead of their expected values.

        Returns:
            Without ``action``: a (batch, n_actions) tensor of Q-values, or a
            (batch, n_actions, n_atoms) tensor of probabilities when
            ``get_distribution`` is True. With ``action``: the entries selected
            by ``action``, passed through ``torch.squeeze``.

        """
        features = self._phi(state)

        # Value logits get a singleton action axis so they broadcast against
        # the stacked per-action advantage logits.
        a_pv = self._pv(features).unsqueeze(1)
        a_pa = torch.stack([head(features) for head in self._pa], dim=1)
        logits = a_pv + a_pa - a_pa.mean(1, keepdim=True)
        softmax = F.softmax(logits, dim=-1)

        if get_distribution:
            if action is None:
                return softmax

            index = action.long().unsqueeze(2).repeat(1, 1, self._n_atoms)

            return torch.squeeze(softmax.gather(1, index))

        # Expected return of every distribution, computed for the whole batch
        # at once; the result lives on the same device as the network output.
        q = softmax @ self._a_values

        if action is None:
            return q

        # gather needs int64 indices, as in the distribution branch above.
        return torch.squeeze(q.gather(1, action.long()))
class Rainbow(AbstractDQN):
    """
    Rainbow algorithm.
    "Rainbow: Combining Improvements in Deep Reinforcement Learning".
    Hessel M. et al.. 2018.

    """
    def __init__(self, mdp_info, policy, approximator_params, n_atoms, v_min,
                 v_max, n_steps_return, alpha_coeff, beta, sigma_coeff=.5,
                 **params):
        """
        Constructor.

        Args:
            mdp_info: forwarded unchanged to the parent constructor;
            policy: forwarded unchanged to the parent constructor;
            approximator_params (dict): parameters of the approximator; its
                'network' entry is used as the feature network wrapped by
                ``RainbowNetwork``. The dictionary is deep-copied, the
                caller's object is not modified;
            n_atoms (int): number of atoms;
            v_min (float): minimum value of value-function;
            v_max (float): maximum value of value-function;
            n_steps_return (int): the number of steps to consider to compute
                the n-return;
            alpha_coeff (float): prioritization exponent for prioritized
                experience replay;
            beta (Parameter): importance sampling coefficient for prioritized
                experience replay;
            sigma_coeff (float, .5): sigma0 coefficient for noise
                initialization in noisy layers;
            **params: forwarded to the parent constructor; must contain
                'initial_replay_size' and 'max_replay_size', which are also
                used to build the prioritized replay memory.

        """
        features_network = approximator_params['network']
        params['approximator_params'] = deepcopy(approximator_params)
        params['approximator_params']['network'] = RainbowNetwork
        params['approximator_params']['features_network'] = features_network
        params['approximator_params']['n_atoms'] = n_atoms
        params['approximator_params']['v_min'] = v_min
        params['approximator_params']['v_max'] = v_max
        params['approximator_params']['sigma_coeff'] = sigma_coeff
        params['approximator_params']['loss'] = categorical_loss

        self._n_atoms = n_atoms
        self._v_min = v_min
        self._v_max = v_max
        self._delta = (v_max - v_min) / (n_atoms - 1)
        # linspace always returns exactly n_atoms points including both end
        # points, whereas a float-step arange depends on rounding.
        self._a_values = np.linspace(v_min, v_max, n_atoms)
        self._n_steps_return = n_steps_return
        self._sigma_coeff = sigma_coeff

        params['replay_memory'] = PrioritizedReplayMemory(
            params['initial_replay_size'], params['max_replay_size'],
            alpha=alpha_coeff, beta=beta
        )

        self._add_save_attr(
            _n_atoms='primitive',
            _v_min='primitive',
            _v_max='primitive',
            _delta='primitive',
            _a_values='numpy',
            _n_steps_return='primitive',
            _sigma_coeff='primitive'
        )

        super().__init__(mdp_info, policy, TorchApproximator, **params)

    def fit(self, dataset, **info):
        """
        Store ``dataset`` in the prioritized replay memory and, once the
        memory is initialized, perform one categorical update step.

        Args:
            dataset: the samples to add to the replay memory;
            **info: accepted for interface compatibility, not used here.

        """
        # New samples enter the memory with the current maximum priority.
        self._replay_memory.add(
            dataset,
            np.ones(len(dataset)) * self._replay_memory.max_priority,
            n_steps_return=self._n_steps_return, gamma=self.mdp_info.gamma
        )

        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _, idxs, is_weight = \
                self._replay_memory.get(self._batch_size())

            if self._clip_reward:
                reward = np.clip(reward, -1, 1)

            # Greedy next action from the online network, evaluated with the
            # target network.
            q_next = self.approximator.predict(next_state,
                                               **self._predict_params)
            a_max = np.argmax(q_next, axis=1)
            gamma = self.mdp_info.gamma ** self._n_steps_return * (
                1 - absorbing)
            p_next = self.target_approximator.predict(
                next_state, a_max, get_distribution=True,
                **self._predict_params
            )

            # Shift and scale the support, then clip it back onto
            # [v_min, v_max].
            gamma_z = gamma.reshape(-1, 1) * self._a_values[np.newaxis, :]
            bell_a = (reward.reshape(-1, 1) + gamma_z).clip(self._v_min,
                                                            self._v_max)

            # Fractional atom index of every shifted atom and its two
            # neighbouring integer atoms.
            b = (bell_a - self._v_min) / self._delta
            lower = np.floor(b).astype(int)
            upper = np.ceil(b).astype(int)

            # When b falls exactly on an atom, lower == upper and both
            # interpolation weights would be zero; move one neighbour so the
            # probability mass is preserved.
            lower[(upper > 0) & (lower == upper)] -= 1
            upper[(lower < self._n_atoms - 1) & (lower == upper)] += 1

            # Size the projected target from the sampled batch itself.
            m = np.zeros((len(b), self._n_atoms))
            rows = np.arange(len(m))
            for i in range(self._n_atoms):
                m[rows, lower[:, i]] += p_next[:, i] * (upper[:, i] - b[:, i])
                m[rows, upper[:, i]] += p_next[:, i] * (b[:, i] - lower[:, i])

            # Cross-entropy between projected target and current prediction,
            # used as the new priority of the sampled transitions.
            p_current = self.approximator.predict(
                state, action, get_distribution=True, **self._predict_params
            )
            kl = -np.sum(m * np.log(p_current.clip(1e-5)), 1)
            self._replay_memory.update(kl, idxs)

            self.approximator.fit(state, action, m, weights=is_weight,
                                  get_distribution=True, **self._fit_params)

            self._n_updates += 1

            if self._n_updates % self._target_update_frequency == 0:
                self._update_target()