Skip to content

Instantly share code, notes, and snippets.

@altescy
Last active October 10, 2023 06:14
Show Gist options
  • Save altescy/8338a0874b6ce5b78bb65bdc025ceb5f to your computer and use it in GitHub Desktop.
Save altescy/8338a0874b6ce5b78bb65bdc025ceb5f to your computer and use it in GitHub Desktop.
Utility functions for iteration in Python.
import collections
import itertools
import math
from collections import abc
from typing import (Any, Callable, Generic, Iterable, Iterator, List, Optional,
TypeVar)
T = TypeVar("T")


class SizedIterator(Generic[T]):
    """
    A wrapper for an iterator that knows its size.

    The wrapper is itself an iterator: ``__iter__`` returns the wrapper, so
    ``len()`` remains available on whatever ``iter()`` hands back. Items are
    pulled one at a time from the wrapped iterator.

    Args:
        iterator: The iterator to wrap.
        size: The value reported by ``len()``. It is stored as-is and is not
            checked against the number of items the iterator actually yields.
    """

    def __init__(self, iterator: Iterator[T], size: int):
        self.iterator = iterator
        self.size = size

    def __iter__(self) -> Iterator[T]:
        # The iterator protocol requires __iter__ to return the iterator itself;
        # handing back the wrapped iterator would drop __len__ for callers that
        # call iter() before asking for the length.
        return self

    def __next__(self) -> T:
        return next(self.iterator)

    def __len__(self) -> int:
        return self.size
def batched(
    iterable: Iterable[T], batch_size: int, drop_last: bool = False
) -> Iterator[List[T]]:
    """
    Batch an iterable into lists of the given size.

    Items are grouped in order. Every batch except possibly the last has
    exactly ``batch_size`` items. If ``iterable`` is ``Sized``, the result
    also supports ``len()``, which reports the number of batches it will yield.

    Args:
        iterable: The iterable.
        batch_size: The size of each batch. Must be at least 1.
        drop_last: Whether to drop the last batch if it is smaller than the given size.

    Returns:
        An iterator over batches. Each batch is a new list.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Validate eagerly, at call time. Otherwise the unsized path would silently
    # yield one giant batch, and the sized path would divide by zero.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    def iterator() -> Iterator[List[T]]:
        batch: List[T] = []
        for item in iterable:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                # Start a fresh list so the batch just handed out is never mutated.
                batch = []
        if batch and not drop_last:
            yield batch

    if isinstance(iterable, abc.Sized):
        length = len(iterable)
        # Integer ceiling division stays exact for any length. math.ceil(a / b)
        # goes through float and can be off by one for very large lengths.
        num_batches = length // batch_size if drop_last else -(-length // batch_size)
        return SizedIterator(iterator(), num_batches)
    return iterator()
def batched_iterator(iterable: Iterable[T], batch_size: int) -> Iterator[Iterator[T]]:
    """
    Batch an iterable into iterators of the given size.

    Each batch is a lazy iterator over the next ``batch_size`` items of the
    shared source. Before the next batch is produced, any items the consumer
    left unread in the previous batch are drained. This keeps every batch
    aligned to a multiple of ``batch_size``. If ``iterable`` is ``Sized``, the
    result also supports ``len()``, which reports the number of batches.

    Args:
        iterable: The iterable.
        batch_size: The size of each batch. Must be at least 1.

    Returns:
        An iterator over batches.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Validate eagerly. Otherwise islice(..., -1) fails lazily on first use,
    # and the sized path divides by zero.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    def iterator() -> Iterator[Iterator[T]]:
        source = iter(iterable)
        while True:
            # Only the probe for the next batch's first item may end iteration.
            try:
                head = next(source)
            except StopIteration:
                return
            subiterator = itertools.chain(
                [head], itertools.islice(source, batch_size - 1)
            )
            yield subiterator
            # Exhaust whatever the consumer did not read, so the next batch
            # starts at the right item. A zero-length deque discards items.
            collections.deque(subiterator, maxlen=0)

    if isinstance(iterable, abc.Sized):
        # Integer ceiling division; exact for any length, unlike float-based ceil.
        num_batches = -(-len(iterable) // batch_size)
        return SizedIterator(iterator(), num_batches)
    return iterator()
def iter_with_callback(
    iterable: Iterable[T],
    callback: Callable[[T], Any],
) -> Iterator[T]:
    """
    Yield each item of an iterable, invoking a callback on it afterwards.

    An item's callback runs only when the consumer resumes iteration after
    receiving that item. If the consumer stops early, the callback is never
    run for the last item it received. When ``iterable`` is ``Sized``, the
    result supports ``len()`` as well.

    Args:
        iterable: The iterable.
        callback: Called with each item once the consumer moves past it. Its
            return value is ignored.

    Returns:
        An iterator over the items of ``iterable``.
    """

    def _generate() -> Iterator[T]:
        for element in iterable:
            yield element
            # Runs on the consumer's next request, i.e. after it handled `element`.
            callback(element)

    if not isinstance(iterable, abc.Sized):
        return _generate()
    return SizedIterator(_generate(), len(iterable))
def consume(iterator: Iterator, n: Optional[int] = None) -> None:
    """
    Advance ``iterator`` by ``n`` items, or exhaust it when ``n`` is None.

    This uses the same technique as the ``consume`` recipe in the itertools
    documentation. Skipped items are discarded. Asking for more items than
    remain simply exhausts the iterator.

    Args:
        iterator: The iterator to advance in place.
        n: How many items to skip. If None, consume everything that is left.
    """
    if n is not None:
        # islice(it, n, n) yields nothing, but must pull n items to reach its start.
        next(itertools.islice(iterator, n, n), None)
        return
    # A zero-length deque drains the iterator without keeping any items.
    collections.deque(iterator, maxlen=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment