Source code for mushroom_rl.utils.minibatches

import numpy as np


def minibatch_number(size, batch_size):
    """
    Function to retrieve the number of batches, given a batch size.

    Args:
        size (int): size of the dataset;
        batch_size (int): size of the batches.

    Returns:
        The number of minibatches in the dataset, i.e. ``size / batch_size``
        rounded up.

    """
    # Exact integer ceiling division: going through float true-division
    # (``np.ceil(size / batch_size)``) loses precision once ``size`` exceeds
    # 2**53 and can undercount by one.
    return int(-(-size // batch_size))
def minibatch_generator(batch_size, *dataset):
    """
    Generator that splits the full dataset into shuffled minibatches.

    A single random permutation of the rows is drawn from NumPy's global
    random state and applied to every array in ``dataset``, so the i-th row
    of each array stays aligned inside every minibatch.

    Args:
        batch_size (int): the maximum size of each minibatch; must be
            positive;
        dataset: the arrays to be split. The number of rows is taken from
            ``len(dataset[0])`` and every array is indexed with an integer
            index array, so they are expected to support NumPy-style fancy
            indexing.

    Yields:
        A list with one entry per array in ``dataset``, each holding the rows
        of the current minibatch. Every minibatch has ``batch_size`` rows,
        except possibly the last one, which holds the remainder.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised when iteration
            starts, since this is a generator).

    """
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got {}'.format(batch_size))

    size = len(dataset[0])
    # Same as np.arange(size) followed by np.random.shuffle, so a given global
    # random state produces the same ordering.
    indexes = np.random.permutation(size)

    for batch_start in range(0, size, batch_size):
        # Slicing past the end is clipped, so the last batch simply holds the
        # remaining rows.
        batch_indexes = indexes[batch_start:batch_start + batch_size]
        yield [data[batch_indexes] for data in dataset]