Skip to content

Instantly share code, notes, and snippets.

@mccutchen
Last active August 29, 2015 14:07
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mccutchen/7100057aa91d167cc048 to your computer and use it in GitHub Desktop.
Save mccutchen/7100057aa91d167cc048 to your computer and use it in GitHub Desktop.
flexible batching of sequences in Python
def gen_batches(xs, size):
    """Yield successive batches of ``size`` items from the iterable ``xs``.

    Every batch is a fresh list. All batches hold exactly ``size`` items
    except possibly the last one, which holds whatever is left over. Nothing
    is yielded for an empty input.

    A batch is yielded as soon as it is full, without first pulling another
    item from ``xs``. This keeps slow or blocking sources responsive and
    leaves a shared iterator positioned right after the batch.

    Args:
        xs: Any iterable, including iterators of unknown or infinite length.
        size: Number of items per batch; must be positive.

    Yields:
        Lists of at most ``size`` consecutive items from ``xs``.

    Raises:
        AssertionError: If ``size`` is not positive. Because this is a
            generator, the check runs when it is first advanced.

    >>> list(gen_batches(range(9), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    >>> list(gen_batches(range(11), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]]

    Also works with iterables that don't have a known size:

    >>> list(gen_batches(iter(range(9)), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    >>> import itertools
    >>> xs = itertools.cycle('abcd')
    >>> list(itertools.islice(gen_batches(xs, 3), 3))
    [['a', 'b', 'c'], ['d', 'a', 'b'], ['c', 'd', 'a']]

    Only the items that make up a batch are consumed from the source:

    >>> source = iter(range(10))
    >>> next(gen_batches(source, 3))
    [0, 1, 2]
    >>> next(source)
    3
    """
    assert size > 0
    batch = []
    for item in xs:
        batch.append(item)
        # Emit immediately once full instead of waiting for the next item,
        # so a complete batch is never held back by a slow source.
        if len(batch) == size:
            yield batch
            batch = []
    # Whatever is left is the final, shorter batch.
    if batch:
        yield batch
def gen_overlapping_batches(xs, size, overlap=0.0):
    """Yield batches of ``size`` items from ``xs`` that may share items.

    Every batch is a list of at most ``size`` items; only the last batch can
    be shorter. If an overlap fraction is given, each batch starts with that
    fraction of items (``int(size * overlap)``, rounded down) repeated from
    the end of the previous batch.

    Every batch contains at least one item that was not in the previous
    batch. An overlap that would repeat the whole batch (such as ``1.0``) is
    therefore reduced to ``size - 1`` shared items, and no trailing batch
    made only of repeated items is emitted.

    A batch is yielded as soon as it is full, without first pulling another
    item from ``xs``.

    Args:
        xs: Any iterable, including iterators of unknown or infinite length.
        size: Number of items per batch; must be positive.
        overlap: Fraction of each batch shared with its neighbour, between
            0 and 1 inclusive. Defaults to no overlap.

    Yields:
        Lists of at most ``size`` items from ``xs``.

    Raises:
        AssertionError: If ``size`` is not positive or ``overlap`` is outside
            ``[0, 1]``. Because this is a generator, the checks run when it
            is first advanced.

    For example, with no overlap:

    >>> list(gen_overlapping_batches(range(10), 4))
    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

    And with 25% overlap:

    >>> list(gen_overlapping_batches(range(10), 4, 0.25))
    [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]

    Full overlap degrades to a window that slides one item at a time:

    >>> list(gen_overlapping_batches(range(5), 2, 1.0))
    [[0, 1], [1, 2], [2, 3], [3, 4]]

    This should even work well with infinitely long generators:

    >>> import itertools
    >>> xs = itertools.cycle('abcd')
    >>> list(itertools.islice(gen_overlapping_batches(xs, 4, 0.25), 3))
    [['a', 'b', 'c', 'd'], ['d', 'a', 'b', 'c'], ['c', 'd', 'a', 'b']]
    """
    assert size > 0
    assert 0 <= overlap <= 1
    # Cap the carry-over below the batch size: if a whole batch were carried
    # over, the next batch would start out full and could never take in a
    # new item without exceeding ``size``.
    shared = min(int(size * overlap), size - 1)
    batch = []
    has_new = False  # True once the pending batch holds an unseen item.
    for item in xs:
        batch.append(item)
        has_new = True
        if len(batch) == size:
            # Slice the carry-over before yielding so that a consumer which
            # mutates the yielded list cannot corrupt the next batch.
            carry = batch[-shared:] if shared else []
            yield batch
            batch = carry
            has_new = False
    # Skip the leftover when it is nothing but items already yielded.
    if has_new:
        yield batch
# Executing this module as a script runs every doctest embedded in its
# docstrings; importing it does nothing.
if __name__ == "__main__":
    from doctest import testmod

    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment