Skip to content

Instantly share code, notes, and snippets.

@leopd
Last active June 8, 2019 15:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leopd/4adf59135049641916d41efe50f0af16 to your computer and use it in GitHub Desktop.
Save leopd/4adf59135049641916d41efe50f0af16 to your computer and use it in GitHub Desktop.
PyTorch Dataset class to access line-delimited text files too big to hold in memory.
import locale
import subprocess
from functools import lru_cache
from typing import Optional

from torch.utils.data import Dataset
class FileReaderDataset(Dataset):
    """Exposes a line-delimited text file as a PyTorch Dataset.

    Maintains an LRU cache of lines it has read, while supporting random access
    into files too large to hold in memory. Memory requirement still scales by
    O(N), but only for one integer byte offset per line scanned so far. Once a
    line's offset is known, fetching it is a single seek plus one line read.

    The file is read in binary mode so that the recorded offsets are real byte
    positions that can be passed to ``seek``; each line is decoded with
    ``encoding`` on the way out. A line ends at a line-feed byte; a carriage
    return directly before it is folded into the returned newline. The
    encoding must therefore represent a line feed as the single byte 0x0A
    (UTF-8, ASCII, Latin-1, ...; not UTF-16/32). Returned lines keep their
    trailing newline.

    An instance owns one open file handle and one mutable read cursor, so it
    must not be used from several threads at once. Call ``close()`` (or use the
    instance as a context manager) to release the handle.
    """

    def __init__(self, filename: str, line_cache_size: int = 1048576,
                 encoding: Optional[str] = None):
        """Open ``filename`` and prepare an empty line index.

        Args:
            filename: Path of the line-delimited text file.
            line_cache_size: Maximum number of decoded lines kept in the LRU cache.
            encoding: Text encoding of the file. ``None`` selects the locale's
                preferred encoding, the same default ``open()`` uses in text mode.

        Raises:
            OSError: If the file cannot be opened.
        """
        super().__init__()
        self._filename = filename
        self._encoding = encoding or locale.getpreferredencoding(False)
        self._filehandle = open(filename, "rb")
        self._pos = 0       # byte offset of the read cursor
        self._linenum = 0   # index of the line that starts at the read cursor
        self._lineseeks = [0]  # byte offset of the start of every line we've read to.
        self._cached_getitem = lru_cache(maxsize=line_cache_size)(self._getitem)
        self._file_len = None  # number of lines, computed lazily

    def close(self) -> None:
        """Close the underlying file handle. Safe to call more than once."""
        handle = getattr(self, "_filehandle", None)
        if handle is not None:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def _readnextline(self) -> str:
        """Read the line at the cursor, advance the cursor, and return the text.

        Raises:
            IndexError: If the cursor is already at end of file.
        """
        raw = self._filehandle.readline()
        if not raw:
            # The cursor always sits on a line start, so hitting EOF here means
            # exactly `_linenum` lines exist; remember that for __len__.
            if self._file_len is None:
                self._file_len = self._linenum
            raise IndexError(
                f"line {self._linenum} is past the end of {self._filename!r}"
            )
        # Update the cursor bookkeeping before decoding so that a decode error
        # cannot leave the index out of step with the file position.
        self._pos += len(raw)
        self._linenum += 1
        if len(self._lineseeks) == self._linenum:
            # First time we have seen where the following line starts.
            self._lineseeks.append(self._pos)
        text = raw.decode(self._encoding)
        if text.endswith("\r\n"):
            # Match the newline translation text-mode reading would have done.
            text = text[:-2] + "\n"
        return text

    def _seektoline(self, linenum: int) -> None:
        """Move the cursor to the start of an already-indexed line."""
        pos = self._lineseeks[linenum]
        self._filehandle.seek(pos)
        self._pos = pos
        self._linenum = linenum

    def __getitem__(self, n: int) -> str:
        """Return line ``n`` (with its trailing newline), using the LRU cache.

        Negative indices count from the end of the file, which requires the
        total line count and may therefore trigger a full scan via ``len()``.

        Raises:
            IndexError: If ``n`` is outside the file.
        """
        if n < 0:
            n += len(self)
            if n < 0:
                raise IndexError(f"line index out of range for {self._filename!r}")
        return self._cached_getitem(n)

    def _getitem(self, n: int) -> str:
        """Uncached version of __getitem__ for a non-negative index."""
        if n != self._linenum:
            last_known = len(self._lineseeks) - 1
            if n <= last_known:
                # Offset already indexed: jump straight there.
                self._seektoline(n)
            else:
                # Start scanning from the furthest line we have an offset for,
                # rather than from wherever the cursor happens to be.
                if self._linenum < last_known:
                    self._seektoline(last_known)
                # NOTE: This isn't caching the lines we scan through, but that logic is a bit tricky.
                while self._linenum < n:
                    self._readnextline()
        return self._readnextline()

    def __len__(self) -> int:
        """Return the number of lines in the file.

        The count is computed once by scanning the file in binary chunks on a
        separate handle (so the read cursor is not disturbed) and then cached.
        A final line without a trailing newline is counted.
        """
        if self._file_len is None:
            newline_count = 0
            last_byte = b""
            with open(self._filename, "rb") as handle:
                for chunk in iter(lambda: handle.read(1 << 20), b""):
                    newline_count += chunk.count(b"\n")
                    last_byte = chunk[-1:]
            if last_byte and last_byte != b"\n":
                # Unterminated final line: still a line we can return.
                newline_count += 1
            self._file_len = newline_count
        return self._file_len
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment