A pytorch DataLoader that generates an unbounded/infinite number of minibatches from the dataset.
from torch.utils.data import DataLoader


class InfiniteDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize an iterator over the dataset.
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            # Dataset exhausted, use a new fresh iterator.
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch
I have similar code, but once the generator is exhausted it keeps creating new generators.
Code in the training loop:
try:  # to avoid crashes due to exhausted generator
    # Samples the batch
    clouds, tokens = next(generator)
except StopIteration:
    # restart the generator if the previous generator is exhausted.
    # generator = iter(trainloader)
    generator = dyn_DataLoader.make_generator()
    clouds, tokens = next(generator)
    print("****** New Generator ******")
Relevant methods of dyn_DataLoader:

from torch.utils.data import DataLoader


class DynamicDataLoader:
    """A dataloader that adapts to the GPU memory."""

    def __init__(self, dataset, gpu_batch_size=2, target_batch_size=1024, adapt=True, **kwargs):
        """Initializes the dataloader with a small batch size.

        Args:
            dataset: torch.utils.data.Dataset.
            gpu_batch_size: int, a guess of a batch size that fits in GPU memory.
            target_batch_size: int, the batch size you would like to have. There is no limit on this
                number; just keep in mind that if it is too large, the number of spoof steps may also
                be large and may slow down training.
            adapt: bool, whether batch size adaptation should be performed.
            kwargs: arguments for torch.utils.data.DataLoader().
        """
        self.dataset = dataset
        self.gpu_batch_size = int(gpu_batch_size)
        self.target_batch_size = target_batch_size
        self.kwargs = kwargs
        self.adapt = adapt
        self.init_dataloader()
        self.generator = self.make_generator()
        self.repeat_spoof = self.spoofing_repeats()

    def init_dataloader(self):
        """Initializes the pytorch dataloader."""
        self._dataloader = DataLoader(dataset=self.dataset,
                                      batch_size=self.gpu_batch_size,
                                      **self.kwargs)

    def increase_gpu_batch(self):
        """Increases the gpu batch size."""
        self.gpu_batch_size *= 2

    def decrease_gpu_batch(self):
        """Decreases the gpu batch size."""
        self.gpu_batch_size = int(self.gpu_batch_size / 2)

    def spoofing_repeats(self):
        """Returns the number of repeats necessary for batch spoofing."""
        return max(1, int(self.target_batch_size / self.gpu_batch_size))

    def make_generator(self):
        """Returns a generator (iterable) over the dataloader."""
        return iter(self._dataloader)
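For context, this is roughly the surrounding loop where repeat_spoof is used for batch spoofing (gradient accumulation over several small GPU batches). The toy dataset, model, optimizer, and criterion below are placeholders for illustration, not the actual training code:

import torch
from torch import nn
from torch.utils.data import TensorDataset

# Placeholder data and model, only to make the sketch self-contained.
dataset = TensorDataset(torch.randn(500, 8), torch.randn(500, 1))
dyn_DataLoader = DynamicDataLoader(dataset, gpu_batch_size=2, target_batch_size=64)
generator = dyn_DataLoader.make_generator()

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

for step in range(100):
    optimizer.zero_grad()
    # Accumulate gradients over repeat_spoof small GPU batches so the
    # effective batch size approximates target_batch_size.
    for _ in range(dyn_DataLoader.repeat_spoof):
        try:
            clouds, tokens = next(generator)
        except StopIteration:
            generator = dyn_DataLoader.make_generator()
            clouds, tokens = next(generator)
        loss = criterion(model(clouds), tokens) / dyn_DataLoader.repeat_spoof
        loss.backward()
    optimizer.step()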
Any ideas of what could be wrong?
This is a subclass of torch.DataLoader that does not stop producing minibatches once the dataset is consumed, but instead acts as a (potentially unbounded) generator of minibatches. It supports all arguments of torch.DataLoader. It is powerful e.g. in combination with itertools.islice; the sketch below extracts 100 minibatches from the dataset, even if the dataset has fewer than batch_size * num_minibatches = 3000 datapoints.
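A minimal sketch of that islice pattern (the toy dataset, batch_size=30, and the training-step placeholder are illustrative assumptions):

from itertools import islice

import torch
from torch.utils.data import TensorDataset

# Toy dataset with only 1000 datapoints -- fewer than
# batch_size * num_minibatches = 30 * 100 = 3000.
dataset = TensorDataset(torch.arange(1000))
dataloader = InfiniteDataLoader(dataset, batch_size=30, shuffle=True)

# islice stops after exactly 100 minibatches, even though the underlying
# dataset is exhausted (and silently restarted) along the way.
for (batch,) in islice(dataloader, 100):
    ...  # training step on `batch`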