from copy import deepcopy

import numpy as np

from mushroom_rl.algorithms.actor_critic.deep_actor_critic import DeepAC
from mushroom_rl.policy import Policy
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.utils.replay_memory import ReplayMemory
from mushroom_rl.utils.parameters import Parameter, to_parameter
class DDPG(DeepAC):
    """
    Deep Deterministic Policy Gradient algorithm.
    "Continuous Control with Deep Reinforcement Learning".
    Lillicrap T. P. et al., 2016.

    """
    def __init__(self, mdp_info, policy_class, policy_params, actor_params,
                 actor_optimizer, critic_params, batch_size,
                 initial_replay_size, max_replay_size, tau, policy_delay=1,
                 critic_fit_params=None, actor_predict_params=None,
                 critic_predict_params=None):
"""
Constructor.
Args:
policy_class (Policy): class of the policy;
policy_params (dict): parameters of the policy to build;
actor_params (dict): parameters of the actor approximator to
build;
actor_optimizer (dict): parameters to specify the actor optimizer
algorithm;
critic_params (dict): parameters of the critic approximator to
build;
batch_size ([int, Parameter]): the number of samples in a batch;
initial_replay_size (int): the number of samples to collect before
starting the learning;
max_replay_size (int): the maximum number of samples in the replay
memory;
tau ((float, Parameter)): value of coefficient for soft updates;
policy_delay ([int, Parameter], 1): the number of updates of the critic after
which an actor update is implemented;
critic_fit_params (dict, None): parameters of the fitting algorithm
of the critic approximator;
actor_predict_params (dict, None): parameters for the prediction with the
actor approximator;
critic_predict_params (dict, None): parameters for the prediction with the
critic approximator.
"""
        self._critic_fit_params = dict() if critic_fit_params is None else critic_fit_params
        self._actor_predict_params = dict() if actor_predict_params is None else actor_predict_params
        self._critic_predict_params = dict() if critic_predict_params is None else critic_predict_params

        self._batch_size = to_parameter(batch_size)
        self._tau = to_parameter(tau)
        self._policy_delay = to_parameter(policy_delay)
        self._fit_count = 0
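
        # Transitions are stored in a replay buffer; fitting only starts
        # once initial_replay_size samples have been collected.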
        self._replay_memory = ReplayMemory(initial_replay_size, max_replay_size)
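
        # The target networks start as exact copies of the online networks
        # (_init_target below) and then track them through soft updates.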
        target_critic_params = deepcopy(critic_params)
        self._critic_approximator = Regressor(TorchApproximator,
                                              **critic_params)
        self._target_critic_approximator = Regressor(TorchApproximator,
                                                     **target_critic_params)

        target_actor_params = deepcopy(actor_params)
        self._actor_approximator = Regressor(TorchApproximator,
                                             **actor_params)
        self._target_actor_approximator = Regressor(TorchApproximator,
                                                    **target_actor_params)

        self._init_target(self._critic_approximator,
                          self._target_critic_approximator)
        self._init_target(self._actor_approximator,
                          self._target_actor_approximator)
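
        # The exploration policy wraps the online actor; its network
        # parameters are the ones optimized by actor_optimizer.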
        policy = policy_class(self._actor_approximator, **policy_params)

        policy_parameters = self._actor_approximator.model.network.parameters()
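
        # Note: _actor_approximator is intentionally not registered here,
        # since _post_load recovers it from the serialized policy.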
        self._add_save_attr(
            _critic_fit_params='pickle',
            _critic_predict_params='pickle',
            _actor_predict_params='pickle',
            _batch_size='mushroom',
            _tau='mushroom',
            _policy_delay='mushroom',
            _fit_count='primitive',
            _replay_memory='mushroom',
            _critic_approximator='mushroom',
            _target_critic_approximator='mushroom',
            _target_actor_approximator='mushroom'
        )
        super().__init__(mdp_info, policy, actor_optimizer, policy_parameters)
    def fit(self, dataset, **info):
        self._replay_memory.add(dataset)
        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _ = \
                self._replay_memory.get(self._batch_size())

            # One-step TD target: r + gamma * Q'(s', pi'(s')), masked at
            # absorbing states inside _next_q.
            q_next = self._next_q(next_state, absorbing)
            q = reward + self.mdp_info.gamma * q_next

            self._critic_approximator.fit(state, action, q,
                                          **self._critic_fit_params)

            # Delayed policy update: the actor is optimized once every
            # policy_delay critic fits.
            if self._fit_count % self._policy_delay() == 0:
                loss = self._loss(state)
                self._optimize_actor_parameters(loss)

            # Soft (Polyak) update of both target networks with coefficient tau.
            self._update_target(self._critic_approximator,
                                self._target_critic_approximator)
            self._update_target(self._actor_approximator,
                                self._target_actor_approximator)

            self._fit_count += 1
    def _loss(self, state):
        # Deterministic policy gradient objective: maximize the critic's
        # value of the actor's action, i.e. minimize -Q(s, pi(s)).
        action = self._actor_approximator(state, output_tensor=True,
                                          **self._actor_predict_params)
        q = self._critic_approximator(state, action, output_tensor=True,
                                      **self._critic_predict_params)

        return -q.mean()
    def _next_q(self, next_state, absorbing):
        """
        Args:
            next_state (np.ndarray): the states where next action has to be
                evaluated;
            absorbing (np.ndarray): the absorbing flag for the states in
                ``next_state``.

        Returns:
            Action-values returned by the critic for ``next_state`` and the
            action returned by the actor.

        """
        a = self._target_actor_approximator.predict(next_state, **self._actor_predict_params)

        q = self._target_critic_approximator.predict(next_state, a, **self._critic_predict_params)
        q *= 1 - absorbing

        return q
    def _post_load(self):
        # The actor approximator is restored from the deserialized policy,
        # then re-bound to the optimizer.
        self._actor_approximator = self.policy._approximator
        self._update_optimizer_parameters(self._actor_approximator.model.network.parameters())
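

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Everything below is
# an assumption: the toy networks, the Pendulum task and all hyperparameters
# are illustrative placeholders in the usual MushroomRL example layout, not
# tuned or canonical values.
if __name__ == '__main__':
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    from mushroom_rl.core import Core
    from mushroom_rl.environments import Gym
    from mushroom_rl.policy import OrnsteinUhlenbeckPolicy

    class ActorNetwork(nn.Module):
        # Hypothetical actor: one hidden layer mapping states to actions.
        def __init__(self, input_shape, output_shape, n_features, **kwargs):
            super().__init__()
            self._h1 = nn.Linear(input_shape[-1], n_features)
            self._out = nn.Linear(n_features, output_shape[0])

        def forward(self, state):
            return self._out(torch.relu(self._h1(state.float())))

    class CriticNetwork(nn.Module):
        # Hypothetical critic: Q(s, a) on the concatenated state-action pair.
        def __init__(self, input_shape, output_shape, n_features, **kwargs):
            super().__init__()
            self._h1 = nn.Linear(input_shape[-1], n_features)
            self._out = nn.Linear(n_features, output_shape[0])

        def forward(self, state, action):
            state_action = torch.cat((state.float(), action.float()), dim=1)
            return torch.squeeze(self._out(torch.relu(self._h1(state_action))))

    mdp = Gym('Pendulum-v1', horizon=200, gamma=.99)

    # The policy receives the actor approximator as first argument inside
    # DDPG.__init__; only the remaining parameters are passed here.
    policy_params = dict(sigma=np.ones(1) * .2, theta=.15, dt=1e-2)

    actor_params = dict(network=ActorNetwork, n_features=80,
                        input_shape=mdp.info.observation_space.shape,
                        output_shape=mdp.info.action_space.shape)
    actor_optimizer = {'class': optim.Adam, 'params': {'lr': 1e-4}}

    # The critic consumes state and action, so its input shape is their sum.
    critic_input_shape = (mdp.info.observation_space.shape[0] +
                          mdp.info.action_space.shape[0],)
    critic_params = dict(network=CriticNetwork, n_features=80,
                         optimizer={'class': optim.Adam,
                                    'params': {'lr': 1e-3}},
                         loss=F.mse_loss,
                         input_shape=critic_input_shape,
                         output_shape=(1,))

    agent = DDPG(mdp.info, OrnsteinUhlenbeckPolicy, policy_params,
                 actor_params, actor_optimizer, critic_params,
                 batch_size=64, initial_replay_size=500,
                 max_replay_size=5000, tau=.001)

    core = Core(agent, mdp)
    core.learn(n_episodes=100, n_steps_per_fit=1)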