Created
February 24, 2015 22:03
-
-
Save dwf/fd13ca098b1cb45e9011 to your computer and use it in GitHub Desktop.
Multithreaded buffering NPY reader.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
try: | |
from queue import Queue, Empty | |
except ImportError: | |
from Queue import Queue, Empty | |
import threading | |
import numpy | |
from numpy.lib.format import read_magic, read_array_header_1_0 | |
class BufferedChunkedNPY(object):
    """A class that manages reading chunks from an NPY file on disk.

    Parameters
    ----------
    fp : file-like
        An open file handle to an NPY file (format version 1.0,
        C-contiguous order).
    chunk_size : int
        The number of items along the first dimension to read as
        each chunk.

    Notes
    -----
    This dispatches a background thread after each subsequent read
    to pre-fetch the next chunk.
    """
    def _get_header(self):
        """Parse and validate the NPY header.

        Leaves ``self._fp`` positioned at the first byte of the data.

        Returns
        -------
        tuple
            ``(shape, fortran_order, dtype)`` exactly as returned by
            ``numpy.lib.format.read_array_header_1_0``.

        Raises
        ------
        NotImplementedError
            If the file is not NPY format version 1.0, or the array
            is stored in Fortran order.
        """
        self._fp.seek(0)
        major, minor = read_magic(self._fp)
        if (major, minor) != (1, 0):
            raise NotImplementedError("only NPY format version 1.0 is supported")
        # read_array_header_1_0 returns a (shape, fortran_order, dtype)
        # tuple, NOT a dict -- unpack it accordingly.
        shape, fortran_order, dtype = read_array_header_1_0(self._fp)
        if fortran_order:
            raise NotImplementedError("Fortran-ordered arrays are not supported")
        return shape, fortran_order, dtype

    def __init__(self, fp, chunk_size):
        # The file handle must be assigned BEFORE _get_header() is
        # called, because _get_header() reads from self._fp.
        self._fp = fp
        shape, _, dtype = self._get_header()
        # Not intended to be modified after object creation.
        self._chunk_size = chunk_size
        self._dtype = dtype
        self._shape = shape
        # Scalar elements per "row" (one step along the first axis).
        # numpy.prod(()) == 1 handles the 1-D case.
        self._row_items = int(numpy.prod(self._shape[1:]))
        self._chunk_items = chunk_size * self._row_items
        self._total = int(numpy.prod(self._shape))
        # _get_header() left the file pointer at the first data byte;
        # remember it so reset() can seek past the header.
        self._data_start = self._fp.tell()
        # Mutable state.
        self._items_remaining = self._total
        self._thread = None
        self._queue = Queue()

    def fetch(self):
        """Read the next chunk from disk and enqueue it.

        Reads at most ``chunk_size`` rows; the final chunk may be
        shorter. Does nothing when the file is exhausted.
        """
        items = min(self._items_remaining, self._chunk_items)
        if items == 0:
            return
        # First dimension is the number of ROWS in this (possibly
        # partial, final) chunk -- divide by items-per-row, not by
        # items-per-chunk, or the last chunk reshapes to 0 rows.
        chunk_shape = (items // self._row_items,) + self._shape[1:]
        # The numpy.fromfile call will release the GIL, allowing other
        # threads to run. It's therefore probably a good idea to set
        # chunk_size to be large enough that this takes approximately
        # the same time as it does to process the number of batches you
        # get from the chunk.
        self._queue.put(
            numpy.fromfile(self._fp, self._dtype, items).reshape(chunk_shape))
        self._items_remaining -= items

    def next_chunk(self):
        """Return the next chunk, pre-fetching the following one.

        Raises
        ------
        Empty
            When all chunks have been consumed.
        """
        if self._thread is None:
            self._thread = threading.Thread(target=self.fetch)
            self._thread.start()
        self._thread.join()
        try:
            next_chunk = self._queue.get(block=False)
        except Empty:
            assert self._items_remaining == 0
            raise
        # Kick off a background pre-fetch of the following chunk;
        # Thread's first positional argument is `group`, so `target`
        # must be passed by keyword.
        self._thread = threading.Thread(target=self.fetch)
        self._thread.start()
        return next_chunk

    def reset(self):
        """Rewind the reader so iteration starts from the first chunk."""
        # Wait for any in-flight pre-fetch, then discard whatever it
        # (or an earlier fetch) buffered.
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        try:
            while True:
                self._queue.get(block=False)
        except Empty:
            pass
        self._items_remaining = self._total
        # Seek to the start of the DATA, not byte 0 -- seeking to 0
        # would make the next fetch read header bytes as array data.
        self._fp.seek(self._data_start)
class BatchesFromChunksIterator(object):
    """Iterator that returns batches from a BufferedChunkedNPY reader.

    Parameters
    ----------
    chunked_reader : BufferedChunkedNPY
        Object with a ``next_chunk()`` method that returns successive
        array chunks and raises ``Empty`` when exhausted.
    batch_size : int
        Number of items (along the first axis) per batch. The last
        batch of a chunk may be shorter when the chunk's length is
        not a multiple of ``batch_size``.
    """
    def __init__(self, chunked_reader, batch_size):
        self._reader = chunked_reader
        self._current_chunk = None
        # Initialize eagerly so the offset always exists, rather than
        # relying on next() to create it before first use.
        self._offset = 0
        self._batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 protocol method; delegates to the Python 2 name.
        return self.next()

    def next(self):
        # Compare against None explicitly: `not chunk` applies numpy
        # array truthiness, which raises ValueError for arrays with
        # more than one element.
        if self._current_chunk is None:
            try:
                self._current_chunk = self._reader.next_chunk()
            except Empty:
                # The reader signals exhaustion by raising Empty.
                raise StopIteration
            self._offset = 0
        batch = self._current_chunk[self._offset:self._offset +
                                    self._batch_size]
        self._offset += self._batch_size
        if self._offset >= self._current_chunk.shape[0]:
            # Chunk exhausted; fetch a fresh one on the next call.
            self._current_chunk = None
        return batch
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment