Source code for mushroom_rl.utils.replay_memory

import numpy as np
from mushroom_rl.core import Serializable


class ReplayMemory(Serializable):
    """
    This class implements functions to manage a replay memory as the one used
    in "Human-Level Control Through Deep Reinforcement Learning" by
    Mnih V. et al.

    """
    def __init__(self, initial_size, max_size):
        """
        Constructor.

        Args:
            initial_size (int): initial number of elements in the replay
                memory;
            max_size (int): maximum number of elements that the replay memory
                can contain.

        """
        self._initial_size = initial_size
        self._max_size = max_size

        self.reset()

        self._add_save_attr(
            _initial_size='pickle',
            _max_size='pickle',
            _idx='pickle!',
            _full='pickle!',
            _states='pickle!',
            _actions='pickle!',
            _rewards='pickle!',
            _next_states='pickle!',
            _absorbing='pickle!',
            _last='pickle!'
        )
    def add(self, dataset):
        """
        Add elements to the replay memory.

        Args:
            dataset (list): list of elements to add to the replay memory.

        """
        for i in range(len(dataset)):
            self._states[self._idx] = dataset[i][0]
            self._actions[self._idx] = dataset[i][1]
            self._rewards[self._idx] = dataset[i][2]
            self._next_states[self._idx] = dataset[i][3]
            self._absorbing[self._idx] = dataset[i][4]
            self._last[self._idx] = dataset[i][5]

            self._idx += 1
            if self._idx == self._max_size:
                self._full = True
                self._idx = 0
    def get(self, n_samples):
        """
        Returns the provided number of samples from the replay memory.

        Args:
            n_samples (int): the number of samples to return.

        Returns:
            The requested number of samples.

        """
        s = list()
        a = list()
        r = list()
        ss = list()
        ab = list()
        last = list()
        for i in np.random.randint(self.size, size=n_samples):
            s.append(np.array(self._states[i]))
            a.append(self._actions[i])
            r.append(self._rewards[i])
            ss.append(np.array(self._next_states[i]))
            ab.append(self._absorbing[i])
            last.append(self._last[i])

        return np.array(s), np.array(a), np.array(r), np.array(ss),\
            np.array(ab), np.array(last)
    def reset(self):
        """
        Reset the replay memory.

        """
        self._idx = 0
        self._full = False
        self._states = [None for _ in range(self._max_size)]
        self._actions = [None for _ in range(self._max_size)]
        self._rewards = [None for _ in range(self._max_size)]
        self._next_states = [None for _ in range(self._max_size)]
        self._absorbing = [None for _ in range(self._max_size)]
        self._last = [None for _ in range(self._max_size)]
    @property
    def initialized(self):
        """
        Returns:
            Whether the replay memory has reached the number of elements that
            allows it to be used.

        """
        return self.size > self._initial_size

    @property
    def size(self):
        """
        Returns:
            The number of elements contained in the replay memory.

        """
        return self._idx if not self._full else self._max_size
    def _post_load(self):
        if self._full is None:
            self.reset()
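
# A minimal usage sketch of ReplayMemory, added for illustration; it is not
# part of the library source. It assumes the 6-tuple transition format
# (state, action, reward, next_state, absorbing, last) expected by add(), and
# all numeric values below are made up. The helper name
# _example_replay_memory is hypothetical.
def _example_replay_memory():
    memory = ReplayMemory(initial_size=2, max_size=100)

    # Store a few dummy transitions.
    dummy_dataset = [
        (np.zeros(3), np.array([0]), 1., np.ones(3), False, False),
        (np.ones(3), np.array([1]), 0., np.zeros(3), False, True),
        (np.zeros(3), np.array([1]), .5, np.ones(3), True, True)
    ]
    memory.add(dummy_dataset)

    # Sample a minibatch only once enough transitions have been stored.
    if memory.initialized:
        s, a, r, ss, ab, last = memory.get(2)
        print(s.shape, a.shape, r.shape)
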
class SumTree(object):
    """
    This class implements a sum tree data structure.
    This is used, for instance, by ``PrioritizedReplayMemory``.

    """
    def __init__(self, max_size):
        """
        Constructor.

        Args:
            max_size (int): maximum size of the tree.

        """
        self._max_size = max_size
        self._tree = np.zeros(2 * max_size - 1)
        self._data = [None for _ in range(max_size)]
        self._idx = 0
        self._full = False
    def add(self, dataset, priority):
        """
        Add elements to the tree.

        Args:
            dataset (list): list of elements to add to the tree;
            priority (np.ndarray): priority of each sample in the dataset.

        """
        for d, p in zip(dataset, priority):
            idx = self._idx + self._max_size - 1

            self._data[self._idx] = d
            self.update([idx], [p])

            self._idx += 1
            if self._idx == self._max_size:
                self._idx = 0
                self._full = True
    def get(self, s):
        """
        Returns the sample whose cumulative priority interval contains the
        provided value.

        Args:
            s (float): the cumulative priority value of the sample to return.

        Returns:
            The tree index, the priority and the data of the requested sample.

        """
        idx = self._retrieve(s, 0)
        data_idx = idx - self._max_size + 1

        return idx, self._tree[idx], self._data[data_idx]
    def update(self, idx, priorities):
        """
        Update the priorities of the samples at the provided indexes in the
        dataset.

        Args:
            idx (np.ndarray): indexes of the transitions in the dataset;
            priorities (np.ndarray): priorities of the transitions.

        """
        for i, p in zip(idx, priorities):
            delta = p - self._tree[i]

            self._tree[i] = p
            self._propagate(delta, i)
    def _propagate(self, delta, idx):
        # Add the priority change to every ancestor of the updated leaf, so
        # that each internal node keeps the sum of the priorities in its
        # subtree.
        parent_idx = (idx - 1) // 2

        self._tree[parent_idx] += delta

        if parent_idx != 0:
            self._propagate(delta, parent_idx)

    def _retrieve(self, s, idx):
        # Descend the tree: go left if the cumulative value s falls within the
        # left subtree, otherwise subtract the left sum and go right, until a
        # leaf is reached. Ties between equal children are broken at random.
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self._tree):
            return idx

        if self._tree[left] == self._tree[right]:
            return self._retrieve(s, np.random.choice([left, right]))

        if s <= self._tree[left]:
            return self._retrieve(s, left)
        else:
            return self._retrieve(s - self._tree[left], right)

    @property
    def size(self):
        """
        Returns:
            The current size of the tree.

        """
        return self._idx if not self._full else self._max_size

    @property
    def max_p(self):
        """
        Returns:
            The maximum priority among the ones in the tree.

        """
        return self._tree[-self._max_size:].max()

    @property
    def total_p(self):
        """
        Returns:
            The sum of the priorities in the tree, i.e. the value of the root
            node.

        """
        return self._tree[0]
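
# A small usage sketch of SumTree, added for illustration; it is not part of
# the library source. It shows how leaf priorities accumulate into the root
# and how get() maps a cumulative priority value to a stored sample. The data
# items and priorities are arbitrary, and the helper name _example_sum_tree is
# hypothetical.
def _example_sum_tree():
    tree = SumTree(max_size=4)
    tree.add(['a', 'b', 'c', 'd'], np.array([1., 2., 3., 4.]))

    # The root holds the sum of all priorities: 1 + 2 + 3 + 4 = 10.
    assert tree.total_p == 10.

    # Cumulative intervals are 'a': (0, 1], 'b': (1, 3], 'c': (3, 6],
    # 'd': (6, 10], so a value of 4.5 retrieves 'c' with priority 3.
    idx, priority, data = tree.get(4.5)
    print(idx, priority, data)
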
class PrioritizedReplayMemory(Serializable):
    """
    This class implements functions to manage a prioritized replay memory as
    the one used in "Prioritized Experience Replay" by Schaul et al., 2015.

    """
    def __init__(self, initial_size, max_size, alpha, beta, epsilon=.01):
        """
        Constructor.

        Args:
            initial_size (int): initial number of elements in the replay
                memory;
            max_size (int): maximum number of elements that the replay memory
                can contain;
            alpha (float): prioritization coefficient;
            beta (Parameter): importance sampling coefficient; it must be
                callable, returning the current value of beta when invoked;
            epsilon (float, .01): small value to avoid zero probabilities.

        """
        self._initial_size = initial_size
        self._max_size = max_size
        self._alpha = alpha
        self._beta = beta
        self._epsilon = epsilon

        self._tree = SumTree(max_size)

        self._add_save_attr(
            _initial_size='pickle',
            _max_size='pickle',
            _alpha='pickle',
            _beta='pickle',
            _epsilon='pickle',
            _tree='pickle!'
        )
    def add(self, dataset, p):
        """
        Add elements to the replay memory.

        Args:
            dataset (list): list of elements to add to the replay memory;
            p (np.ndarray): priority of each sample in the dataset.

        """
        self._tree.add(dataset, p)
    def get(self, n_samples):
        """
        Returns the provided number of samples from the replay memory.

        Args:
            n_samples (int): the number of samples to return.

        Returns:
            The requested number of samples, together with the tree indexes
            and the importance sampling weights of the sampled transitions.

        """
        states = [None for _ in range(n_samples)]
        actions = [None for _ in range(n_samples)]
        rewards = [None for _ in range(n_samples)]
        next_states = [None for _ in range(n_samples)]
        absorbing = [None for _ in range(n_samples)]
        last = [None for _ in range(n_samples)]

        idxs = np.zeros(n_samples, dtype=int)
        priorities = np.zeros(n_samples)

        # Stratified sampling: split the total priority into n_samples
        # segments and draw one value uniformly from each segment.
        total_p = self._tree.total_p
        segment = total_p / n_samples

        a = np.arange(n_samples) * segment
        b = np.arange(1, n_samples + 1) * segment
        samples = np.random.uniform(a, b)
        for i, s in enumerate(samples):
            idx, p, data = self._tree.get(s)

            idxs[i] = idx
            priorities[i] = p
            states[i], actions[i], rewards[i], next_states[i], absorbing[i],\
                last[i] = data
            states[i] = np.array(states[i])
            next_states[i] = np.array(next_states[i])

        sampling_probabilities = priorities / self._tree.total_p
        is_weight = (self._tree.size * sampling_probabilities) ** -self._beta()
        is_weight /= is_weight.max()

        return np.array(states), np.array(actions), np.array(rewards),\
            np.array(next_states), np.array(absorbing), np.array(last),\
            idxs, is_weight
    def update(self, error, idx):
        """
        Update the priorities of the samples at the provided indexes in the
        dataset.

        Args:
            error (np.ndarray): errors to consider to compute the priorities;
            idx (np.ndarray): indexes of the transitions in the dataset.

        """
        p = self._get_priority(error)
        self._tree.update(idx, p)
    def _get_priority(self, error):
        # Proportional prioritization: p = (|error| + epsilon) ** alpha.
        return (np.abs(error) + self._epsilon) ** self._alpha

    @property
    def initialized(self):
        """
        Returns:
            Whether the replay memory has reached the number of elements that
            allows it to be used.

        """
        return self._tree.size > self._initial_size

    @property
    def max_priority(self):
        """
        Returns:
            The maximum value of priority inside the replay memory.

        """
        return self._tree.max_p if self.initialized else 1.
    def _post_load(self):
        if self._tree is None:
            self._tree = SumTree(self._max_size)
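
# A minimal usage sketch of PrioritizedReplayMemory, added for illustration;
# it is not part of the library source. It shows the add / get / update cycle:
# new transitions enter with the current maximum priority, a minibatch is
# drawn together with tree indexes and importance sampling weights, and the
# priorities are then refreshed with fresh TD errors. Since get() calls
# self._beta(), beta must be callable; a plain lambda is used here, whereas a
# Parameter object would normally be passed. All numeric values are made up
# and the helper name _example_prioritized_replay_memory is hypothetical.
def _example_prioritized_replay_memory():
    memory = PrioritizedReplayMemory(initial_size=2, max_size=100,
                                     alpha=.6, beta=lambda: .4)

    dataset = [
        (np.zeros(3), np.array([0]), 1., np.ones(3), False, False),
        (np.ones(3), np.array([1]), 0., np.zeros(3), False, True),
        (np.zeros(3), np.array([1]), .5, np.ones(3), True, True)
    ]
    # New samples get the current maximum priority so that each of them is
    # likely to be sampled at least once.
    memory.add(dataset, p=memory.max_priority * np.ones(len(dataset)))

    if memory.initialized:
        s, a, r, ss, ab, last, idxs, is_weight = memory.get(2)

        # After computing the TD errors of the sampled transitions, refresh
        # their priorities in the tree.
        td_errors = np.array([.3, .1])
        memory.update(td_errors, idxs)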