
@leopd
Last active July 30, 2020 19:42
Rough idea of how to write a block-oriented prefetching dataset wrapper for PyTorch.
import functools

from torch.utils.data import Dataset


class BlockCachingDatasetWrapper(Dataset):
    """Wraps a PyTorch dataset with an LRU cache
    that fetches an entire block of records at once.
    """

    def __init__(self, base_dataset: Dataset, block_size: int = 16):
        self._dataset = base_dataset
        self._blocksize = block_size

    def __len__(self) -> int:
        return len(self._dataset)

    @functools.lru_cache(1000000)
    def _cached_read(self, n: int) -> "Record":
        return self._dataset[n]

    def __getitem__(self, n: int) -> "Record":
        # Touch every record in the block containing n so that
        # subsequent reads within the block hit the cache.
        block_start = (n // self._blocksize) * self._blocksize
        for i in range(block_start, block_start + self._blocksize):
            if i < len(self):  # don't run past the end
                _ = self._cached_read(i)
        return self._cached_read(n)
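
A minimal usage sketch, not part of the gist: the TensorDataset base dataset and the DataLoader settings below are illustrative assumptions, standing in for a dataset whose individual reads are slow enough that block-level caching helps.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical slow base dataset; any map-style Dataset works here.
base = TensorDataset(torch.arange(100).float().unsqueeze(1))

# Wrap it so each cache miss pulls in a whole block of 16 records.
wrapped = BlockCachingDatasetWrapper(base, block_size=16)

# Sequential access benefits most: the first read in a block fills the
# cache, and the remaining 15 reads in that block are served from it.
loader = DataLoader(wrapped, batch_size=8, shuffle=False)
for batch in loader:
    pass  # training / inference step goes here

Note that the cache lives in the process that calls __getitem__, so with DataLoader workers (num_workers > 0) each worker keeps its own independent cache.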