from mushroom_rl.algorithms.value.td import TD
from mushroom_rl.utils.eligibility_trace import EligibilityTrace
from mushroom_rl.utils.table import Table
class SARSALambda(TD):
    """
    The SARSA(lambda) algorithm for finite MDPs.

    The action-value function is stored in a ``Table`` and updated through
    an eligibility trace that is scaled by ``gamma * lambda`` after every
    step and reset at the start of every episode.

    """
    def __init__(self, mdp_info, policy, learning_rate, lambda_coeff,
                 trace='replacing'):
        """
        Constructor.

        Args:
            mdp_info: information about the MDP; ``mdp_info.size`` is used
                to build the Q-table and ``mdp_info.gamma`` is used as the
                discount factor in the update;
            policy: forwarded unchanged to the ``TD`` constructor;
            learning_rate: forwarded unchanged to the ``TD`` constructor;
            lambda_coeff (float): eligibility trace coefficient;
            trace (str, 'replacing'): type of eligibility trace to use.

        """
        Q = Table(mdp_info.size)
        self._lambda = lambda_coeff

        # The trace has the same shape as the Q-table, so the two tables can
        # be combined element-wise in ``_update``.
        self.e = EligibilityTrace(Q.shape, trace)

        # Register the attributes introduced by this class for saving.
        self._add_save_attr(
            _lambda='primitive',
            e='pickle'
        )

        super().__init__(mdp_info, policy, Q, learning_rate)

    def _update(self, state, action, reward, next_state, absorbing):
        """
        Apply one SARSA(lambda) update to the Q-table.

        Args:
            state: state in which ``action`` was taken;
            action: action taken in ``state``;
            reward: reward observed for the transition;
            next_state: state reached after the transition;
            absorbing: whether ``next_state`` is absorbing; when true, the
                value of the next state-action pair is taken to be zero.

        """
        q_current = self.Q[state, action]

        # The next action is drawn here and stored on the agent.
        # NOTE(review): presumably the TD base class returns this stored
        # action at the next step so that the update stays on-policy --
        # confirm against TD.draw_action.
        self.next_action = self.draw_action(next_state)
        q_next = self.Q[next_state, self.next_action] if not absorbing else 0.

        # Temporal-difference error of the observed transition.
        delta = reward + self.mdp_info.gamma * q_next - q_current

        # Mark the visited pair in the trace, then move every entry of the
        # Q-table along its trace value, scaled by ``self.alpha(state,
        # action)`` (presumably the learning rate -- confirm against TD).
        self.e.update(state, action)
        self.Q.table += self.alpha(state, action) * delta * self.e.table

        # Decay the whole trace for the next step.
        self.e.table *= self.mdp_info.gamma * self._lambda

    def episode_start(self):
        """
        Reset the eligibility trace, then run the parent class'
        ``episode_start``.

        """
        self.e.reset()

        super().episode_start()