Source code for mushroom_rl.utils.eligibility_trace

from mushroom_rl.utils.table import Table


def EligibilityTrace(shape, name='replacing'):
    """
    Build an eligibility trace table of the requested kind.

    Args:
        shape (list): shape of the trace table, forwarded unchanged to the
            constructor of the selected trace class;
        name (str, 'replacing'): which trace to build, either
            ``'replacing'`` or ``'accumulating'``.

    Returns:
        A ``ReplacingTrace`` or an ``AccumulatingTrace`` built with ``shape``.

    Raises:
        ValueError: if ``name`` is neither ``'replacing'`` nor
            ``'accumulating'``.

    """
    if name == 'replacing':
        return ReplacingTrace(shape)

    if name == 'accumulating':
        return AccumulatingTrace(shape)

    # Any other identifier is rejected rather than silently defaulted.
    raise ValueError('Unknown type of trace.')
class ReplacingTrace(Table):
    """
    Replacing trace: visiting a state-action pair overwrites its entry with
    one, whatever value the entry held before.

    """
    def reset(self):
        """
        Set every entry of the trace table to zero by slice assignment on
        the existing table object.

        """
        self.table[:] = 0.

    def update(self, state, action):
        """
        Record a visit to the pair (``state``, ``action``).

        Args:
            state: value used as the first index into the trace table;
            action: value used as the second index into the trace table.

        """
        # ``table[a, b]`` and ``table[(a, b)]`` are the same subscription.
        visited = (state, action)
        self.table[visited] = 1.
class AccumulatingTrace(Table):
    """
    Accumulating trace: every visit to a state-action pair adds one to the
    value already stored for it.

    """
    def reset(self):
        """
        Set every entry of the trace table to zero by slice assignment on
        the existing table object.

        """
        self.table[:] = 0.

    def update(self, state, action):
        """
        Record a visit to the pair (``state``, ``action``).

        Args:
            state: value used as the first index into the trace table;
            action: value used as the second index into the trace table.

        """
        # ``table[a, b]`` and ``table[(a, b)]`` are the same subscription.
        visited = (state, action)
        self.table[visited] += 1.