Skip to content

Instantly share code, notes, and snippets.

@lzhbrian
Forked from MFreidank/infinite_dataloader.py
Created November 13, 2020 04:06
Show Gist options
  • Save lzhbrian/fa42b9a1636af27c620794b201d2d2b6 to your computer and use it in GitHub Desktop.
Save lzhbrian/fa42b9a1636af27c620794b201d2d2b6 to your computer and use it in GitHub Desktop.
A PyTorch DataLoader that yields an unbounded (infinite) stream of minibatches from the dataset, restarting automatically at the end of each pass.
from torch.utils.data import DataLoader
class InfiniteDataLoader(DataLoader):
    """A DataLoader that yields an unbounded stream of minibatches.

    Whenever the underlying DataLoader iterator is exhausted (the end of
    one pass over the dataset), a fresh iterator is obtained from
    ``DataLoader.__iter__`` and iteration continues. As a result,
    ``next(loader)`` keeps returning batches indefinitely. With
    ``shuffle=True``, each new pass asks the sampler for a new order.

    The object is its own iterator: ``iter(loader)`` returns ``loader``
    itself. For a non-empty dataset, a ``for`` loop over it never ends on
    its own and must be stopped with ``break``. Likewise, ``list(loader)``
    never returns. ``__len__`` is inherited unchanged from ``DataLoader``,
    so it still reports the number of batches in a single pass.

    If a freshly started pass yields no batches, ``StopIteration``
    propagates so that iteration ends instead of looping forever. This
    happens, e.g., with an empty dataset, or with ``drop_last=True`` and
    fewer samples than ``batch_size``.

    All constructor arguments are forwarded to ``DataLoader`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The per-pass iterator is created lazily on the first ``next()``
        # call rather than here. Creating it starts worker processes when
        # ``num_workers > 0`` (and typically draws a seed from the RNG),
        # which a plain DataLoader only does once iteration begins.
        self.dataset_iterator = None

    def __iter__(self):
        """Return ``self``; batches are produced by ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch, starting a new pass when one runs out.

        Raises:
            StopIteration: only when a brand-new pass yields no batches.
        """
        if self.dataset_iterator is None:
            self.dataset_iterator = super().__iter__()
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            pass
        # Current pass exhausted: start a fresh one. This is done outside
        # the ``except`` block so that a StopIteration from an empty fresh
        # pass is not chained onto the one that ended the previous pass.
        self.dataset_iterator = super().__iter__()
        return next(self.dataset_iterator)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment