Skip to content

Instantly share code, notes, and snippets.

@mbsariyildiz
Last active May 8, 2018 14:20
Show Gist options
  • Save mbsariyildiz/6c4130abdbdc640cbe8e90ec70cedbbf to your computer and use it in GitHub Desktop.
Save mbsariyildiz/6c4130abdbdc640cbe8e90ec70cedbbf to your computer and use it in GitHub Desktop.
Iterator for a list of arrays/matrices whose first dimensions match
class Iterator(object):
    """
    Mini-batch iterator for a list of tensors whose first dimensions match.

    Every call to `next_batch` returns the same rows (selected by a shared
    index array) from each tensor, so the tensors stay aligned sample by
    sample. Once an epoch is exhausted the sample order is regenerated and
    fetching restarts from the beginning.

    Args:
        tensors: list of array-like objects. Each must expose `.shape[0]`
            and support indexing with an integer index array.
        batch_size: number of samples per batch; must be positive and not
            larger than the number of samples.
        allow_smaller: if True, the last batch of an epoch may hold fewer
            than `batch_size` samples. If False, the epoch ends as soon as
            fewer than `batch_size` samples remain, and the leftover samples
            are skipped for that epoch.
        shuffle: if True, samples are visited in a fresh random permutation
            every epoch; otherwise in their natural order.

    Raises:
        AssertionError: if the first dimensions differ, there are no samples,
            `batch_size` is not positive, or `batch_size` exceeds the number
            of samples.

    Note:
        `allow_smaller` and `shuffle` are read each time they are needed, so
        changing them on a live instance takes effect on the next call.
    """

    def __init__(self, tensors, batch_size, allow_smaller=True, shuffle=True):
        self.tensors = tensors
        self.batch_size = batch_size
        self.allow_smaller = allow_smaller
        self.shuffle = shuffle
        self._s_ix = 0  # index of the sample that starts the next batch
        self._order = None  # order in which samples are fetched in an epoch

        # The number of elements along the first axis must agree everywhere.
        n_elems = [tensor.shape[0] for tensor in tensors]
        assert np.all(np.equal(n_elems, n_elems[0])), \
            "all tensors must have the same first dimension"

        self.n_samples = n_elems[0]
        assert self.n_samples > 0, "tensors must contain at least one sample"
        # A non-positive batch size would never advance through the data.
        assert self.batch_size > 0, "batch_size must be positive"
        assert self.n_samples >= self.batch_size, \
            "batch_size must not exceed the number of samples"

        self.reset_batch_order()

    def _check_new_epoch(self):
        """Return True when the current epoch has no further batch to serve."""
        if self.allow_smaller:
            # Every sample has already been fetched.
            return self._s_ix >= self.n_samples
        # Fewer samples remain than a full batch needs.
        return self.n_samples - self._s_ix < self.batch_size

    def _new_order(self):
        """Return the index order in which samples are visited for one epoch."""
        if self.shuffle:
            return np.random.permutation(self.n_samples)
        return np.arange(self.n_samples)

    def reset_batch_order(self):
        """Start a new epoch: rewind the cursor and regenerate the order."""
        self._s_ix = 0
        self._order = self._new_order()

    def next_batch(self, return_flag=True):
        """
        Fetch the next batch from every tensor.

        Args:
            return_flag: if True, append the new-epoch flag to the result.

        Returns:
            A list with one batch per tensor, in the order of `self.tensors`.
            When `return_flag` is True the list has one extra trailing
            element: a bool that is True if this batch was the last one of
            the epoch (the order has then already been reset).
        """
        inds = self._order[self._s_ix:self._s_ix + self.batch_size]
        batches = [tensor[inds] for tensor in self.tensors]
        self._s_ix += self.batch_size

        new_epoch = self._check_new_epoch()
        if new_epoch:
            self.reset_batch_order()

        return batches + [new_epoch] if return_flag else batches
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment