Skip to content

Instantly share code, notes, and snippets.

@skeeet
Forked from MFreidank/infinite_dataloader.py
Created February 20, 2019 10:41
Show Gist options
  • Save skeeet/5eda9b294442346dc5f15c8764ee6a50 to your computer and use it in GitHub Desktop.
Save skeeet/5eda9b294442346dc5f15c8764ee6a50 to your computer and use it in GitHub Desktop.
A PyTorch DataLoader that yields an unbounded (infinite) stream of minibatches from the dataset, restarting from the beginning each time the dataset is exhausted.
from torch.utils.data import DataLoader
class InfiniteDataLoader(DataLoader):
    """A ``DataLoader`` that yields minibatches indefinitely.

    Accepts exactly the same arguments as ``torch.utils.data.DataLoader``.
    When one pass over the data is exhausted, a fresh underlying iterator is
    created transparently and iteration continues. Sampler behaviour such as
    reshuffling applies again on each pass, because each pass uses a new
    ``DataLoader`` iterator.

    The loader is its own iterator: ``iter(loader)`` returns ``loader`` and
    does not reset the current position. ``len(loader)`` is inherited from
    ``DataLoader`` unchanged and reports the number of batches in ONE pass,
    not the (unbounded) number of batches this loader yields.

    If the dataset yields no batches at all, ``next()`` raises
    ``StopIteration``, so a ``for`` loop over an empty dataset simply ends.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Iterator over the current pass. It is created lazily on the first
        # ``next()`` so that constructing the loader does not start an epoch.
        # With ``num_workers > 0``, an eager start would spawn worker
        # processes before any batch is requested.
        self.dataset_iterator = None

    def __iter__(self):
        """Return ``self``; iteration continues from the current position."""
        return self

    def __next__(self):
        """Return the next minibatch, starting a new pass when one is exhausted.

        Raises:
            StopIteration: only if a freshly created pass yields no batches,
                i.e. the dataset is empty (or ``drop_last`` drops everything).
        """
        if self.dataset_iterator is None:
            self.dataset_iterator = super().__iter__()
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            # Current pass exhausted: start a fresh one and take its first batch.
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment