Skip to content

Instantly share code, notes, and snippets.

@dwf
Created February 24, 2015 22:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dwf/fd13ca098b1cb45e9011 to your computer and use it in GitHub Desktop.
Save dwf/fd13ca098b1cb45e9011 to your computer and use it in GitHub Desktop.
Multithreaded buffering NPY reader.
try:
from queue import Queue, Empty
except ImportError:
from Queue import Queue, Empty
import threading
import numpy
from numpy.lib.format import read_magic, read_array_header_1_0
class BufferedChunkedNPY(object):
    """A class that manages reading chunks from an NPY file on disk.

    Parameters
    ----------
    fp : file-like
        An open file handle, opened for binary reading. It is passed to
        ``numpy.fromfile``, so it must be a handle that function accepts.
    chunk_size : int
        The number of items along the first dimension to read as
        each chunk. Must be at least 1.

    Raises
    ------
    ValueError
        If `chunk_size` is smaller than 1.
    NotImplementedError
        If the file is not NPY format version 1.0, or the array is
        stored in Fortran order.

    Notes
    -----
    This dispatches a background thread after each subsequent read
    to pre-fetch the next chunk.

    """

    def _get_header(self):
        """Parse the NPY header, leaving the file positioned at the data.

        Returns
        -------
        header : dict
            Has the keys ``'shape'``, ``'fortran_order'`` and ``'descr'``
            (the array's dtype).

        """
        self._fp.seek(0)
        major, minor = read_magic(self._fp)
        if (major, minor) != (1, 0):
            raise NotImplementedError(
                'only NPY format version 1.0 is supported')
        # read_array_header_1_0 returns a (shape, fortran_order, dtype)
        # tuple rather than a mapping.
        shape, fortran_order, dtype = read_array_header_1_0(self._fp)
        if fortran_order:
            raise NotImplementedError(
                'Fortran-ordered arrays are not supported')
        return {'shape': shape, 'fortran_order': fortran_order,
                'descr': dtype}

    def __init__(self, fp, chunk_size):
        if chunk_size < 1:
            raise ValueError('chunk_size must be at least 1')
        # The handle has to be stored before the header can be parsed.
        self._fp = fp
        header = self._get_header()
        # Not intended to be modified after object creation.
        self._chunk_size = chunk_size
        self._dtype = numpy.dtype(header['descr'])
        self._shape = header['shape']
        # Number of scalar items in one entry along the first dimension.
        self._row_items = int(numpy.prod(self._shape[1:]))
        self._chunk_items = chunk_size * self._row_items
        self._total = int(numpy.prod(self._shape))
        # _get_header() leaves the file positioned just past the header,
        # i.e. at the first byte of array data.
        self._data_offset = fp.tell()
        # Mutable state.
        self._items_remaining = self._total
        self._thread = None
        self._queue = Queue()

    def _start_fetch(self):
        """Start a background thread that reads the next chunk."""
        self._thread = threading.Thread(target=self.fetch)
        self._thread.start()

    def fetch(self):
        """Read the next chunk from disk and put it on the internal queue.

        Does nothing once every item in the file has been read.

        """
        items = min(self._items_remaining, self._chunk_items)
        if items == 0:
            return
        # The last chunk may hold fewer than chunk_size rows, so derive the
        # leading dimension from the number of items actually being read.
        chunk_shape = (items // self._row_items,) + self._shape[1:]
        # The numpy.fromfile call will release the GIL, allowing other threads
        # to run. It's therefore probably a good idea to set chunk_size to be
        # large enough that this takes approximately the same time as it does
        # to process the number of batches you get from the chunk.
        self._queue.put(
            numpy.fromfile(self._fp, self._dtype, items).reshape(chunk_shape))
        self._items_remaining -= items

    def next_chunk(self):
        """Return the next chunk and start pre-fetching the one after it.

        Returns
        -------
        chunk : ndarray
            Up to `chunk_size` entries along the first dimension.

        Raises
        ------
        Empty
            When there are no more chunks to read.

        """
        if self._thread is None:
            self._start_fetch()
        self._thread.join()
        try:
            next_chunk = self._queue.get(block=False)
        except Empty:
            assert self._items_remaining == 0
            raise
        # After retrieving the chunk, immediately begin reading the next one
        # in the background so it is ready by the time the caller asks.
        self._start_fetch()
        return next_chunk

    def reset(self):
        """Rewind to the first chunk, discarding any pre-fetched data."""
        # Let an in-flight read finish so it cannot race with the seek below.
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        # Drop whatever was pre-fetched from the old position.
        while True:
            try:
                self._queue.get(block=False)
            except Empty:
                break
        self._items_remaining = self._total
        # Seek to the start of the array data, not the start of the header.
        self._fp.seek(self._data_offset)
class BatchesFromChunksIterator(object):
    """Iterator that returns batches from a BufferedChunkedNPY reader.

    Parameters
    ----------
    chunked_reader : object
        Any object with a ``next_chunk()`` method that returns an array
        and raises ``Empty`` once exhausted.
    batch_size : int
        Maximum number of entries along the first dimension per batch.
        Must be at least 1.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1.

    Notes
    -----
    Batches never span two chunks, so the last batch cut from a chunk may
    hold fewer than `batch_size` entries.

    """

    def __init__(self, chunked_reader, batch_size):
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1')
        self._reader = chunked_reader
        self._current_chunk = None
        self._batch_size = batch_size
        self._offset = 0

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next batch, fetching a new chunk when needed.

        Raises
        ------
        StopIteration
            When the underlying reader has no more chunks.

        """
        # Compare against None explicitly: the truth value of a
        # multi-element ndarray is ambiguous and raises ValueError.
        while self._current_chunk is None:
            try:
                chunk = self._reader.next_chunk()
            except Empty:
                raise StopIteration
            # Skip zero-length chunks instead of yielding empty batches.
            if chunk.shape[0] > 0:
                self._current_chunk = chunk
                self._offset = 0
        start = self._offset
        batch = self._current_chunk[start:start + self._batch_size]
        self._offset += self._batch_size
        if self._offset >= self._current_chunk.shape[0]:
            # Chunk used up; the next call fetches a fresh one.
            self._current_chunk = None
        return batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment