Skip to content

Instantly share code, notes, and snippets.

@nicolasdespres
Created March 8, 2017 08:13
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nicolasdespres/bbc62cb43f1ffe9b81d971cb957650c2 to your computer and use it in GitHub Desktop.
Save nicolasdespres/bbc62cb43f1ffe9b81d971cb957650c2 to your computer and use it in GitHub Desktop.
Iterate over an iterable in batches.
class batch_iter(Iterator):
    """Iterate over *iterable* in batches of ``batch_size`` items.

    Each ``next()`` call returns the container produced by
    ``take_upto(it, batch_size)`` on a single iterator over *iterable*.
    A final batch holding fewer than ``batch_size`` items is returned only
    when *allow_smaller_final_batch* is true; otherwise iteration stops and
    that short batch is discarded.

    Args:
        iterable: Source of items. It is iterated once via ``iter()``; it
            must also support ``len()`` for ``len(self)`` to work.
        batch_size: Number of items per batch; must be a positive ``int``.
        allow_smaller_final_batch: Whether to return a trailing, incomplete
            batch instead of stopping.

    Raises:
        TypeError: If *batch_size* is not an ``int``.
        ValueError: If *batch_size* is not positive.
    """

    def __init__(self, iterable, batch_size=1,
                 allow_smaller_final_batch=False):
        if not isinstance(batch_size, int):
            raise TypeError("batch_size must be int, not {}"
                            .format(type(batch_size).__name__))
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        # Keep the original object so __len__ can query its size; batches
        # are pulled from one iterator created over it here.
        self._iterable = iterable
        self._it = iter(iterable)
        self._batch_size = batch_size
        self._allow_smaller_final_batch = allow_smaller_final_batch
        self.reset()

    def __repr__(self):
        return ("{}(batch_size={!r}, allow_smaller_final_batch={!r})"
                .format(type(self).__name__, self._batch_size,
                        self._allow_smaller_final_batch))

    @property
    def batch_size(self):
        """Number of items requested per batch."""
        return self._batch_size

    @property
    def allow_smaller_final_batch(self):
        """Whether a trailing, incomplete batch is returned."""
        return self._allow_smaller_final_batch

    def __len__(self):
        """Return the total number of batches this iterator yields.

        The count is computed from ``len(iterable)``, independently of how
        many batches have already been consumed. Works only if the source
        *iterable* has a ``len()``.
        """
        q, r = divmod(len(self._iterable), self._batch_size)
        # A remainder forms one extra (short) batch only when allowed.
        if r > 0 and self._allow_smaller_final_batch:
            q += 1
        return q

    def reset(self):
        """Reset the step counter to -1 (no batch returned yet).

        Only the counter is reset; the underlying iterator is NOT rewound,
        so iteration continues from where it left off.
        """
        self._step = -1

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` when done."""
        batch = take_upto(self._it, self._batch_size)
        ntaken = len(batch)
        # Presumably take_upto returns fewer than batch_size items only once
        # the source runs out: an empty batch ends iteration, and a short one
        # is the tail, dropped unless explicitly allowed.
        if ntaken == 0 or (ntaken < self._batch_size
                           and not self._allow_smaller_final_batch):
            raise StopIteration
        self._step += 1
        return batch

    @property
    def step(self):
        """Return the index of the most recently returned batch.

        Starts at -1 before the first batch, so after each ``next()`` it
        equals the index ``enumerate`` would pair with that batch.
        """
        return self._step
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment