Skip to content

Instantly share code, notes, and snippets.

@rkhullar
Created April 4, 2023 14:32
Show Gist options
  • Save rkhullar/725b655a99d6ac3f90ced9129b9bb003 to your computer and use it in GitHub Desktop.
python iter batch v2
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar
T = TypeVar('T')


def iter_batch(stream: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Split *stream* into consecutive lists of at most *batch_size* items.

    Every batch except possibly the last has exactly ``batch_size`` items.
    The last batch holds the remainder. It is omitted when the stream length
    is an exact multiple of ``batch_size``. Items are pulled lazily, one batch
    at a time, so *stream* may be an unbounded iterator.

    Args:
        stream: Any iterable of items (list, string, generator, ...).
        batch_size: Maximum number of items per batch; must be positive.

    Returns:
        An iterator over newly built lists. Mutating a yielded batch does not
        affect later batches.

    Raises:
        ValueError: If ``batch_size`` is not positive. This is raised at call
            time, not deferred to the first ``next()``.
    """
    # Use an explicit exception rather than ``assert``: asserts are stripped
    # under ``python -O``. A zero or negative size previously slipped through
    # there and failed later with ZeroDivisionError or IndexError.
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size!r}')
    return _iter_batch(iter(stream), batch_size)


def _iter_batch(iterator: Iterator[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive batches from *iterator*; ``batch_size`` is pre-validated."""
    while True:
        # islice pulls at most batch_size items and leaves the rest unconsumed.
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch
def test_1():
    """Smoke test: batch five letters into groups of three.

    Prints each batch so running the file as a script shows the output. It
    also asserts the expected split, so that when a test runner collects this
    function it actually verifies something instead of always passing.
    """
    letters = list('abcde')
    batches = list(iter_batch(letters, batch_size=3))
    for batch in batches:
        print(batch)
    assert batches == [['a', 'b', 'c'], ['d', 'e']]
if __name__ == '__main__':
    # Executing this file directly runs the demo; importing it does not.
    test_1()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment