@MFreidank
Last active October 16, 2023 13:36
A pytorch DataLoader that generates an unbounded/infinite number of minibatches from the dataset.
from torch.utils.data import DataLoader


class InfiniteDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize an iterator over the dataset.
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            # Dataset exhausted; start a fresh iterator.
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch
MFreidank commented Apr 26, 2018

This is a subclass of torch.utils.data.DataLoader that does not stop producing minibatches once the dataset
is consumed; instead, it acts as a (potentially unbounded) generator of minibatches.

It is powerful, e.g., in combination with itertools.islice:

from itertools import islice

num_minibatches = 100
data_loader = InfiniteDataLoader(dataset, batch_size=30, shuffle=True)
for x_batch, y_batch in islice(data_loader, num_minibatches):
    ...  # Do something with the given minibatch.

This extracts 100 minibatches from the dataset, even if the dataset has fewer than batch_size * num_minibatches = 3000 datapoints.
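
For instance, a minimal sketch of the wrap-around behaviour (the small TensorDataset here is illustrative, not part of the gist):

import torch
from torch.utils.data import TensorDataset
from itertools import islice

# Hypothetical toy dataset with only 50 datapoints.
dataset = TensorDataset(torch.randn(50, 3), torch.randint(0, 2, (50,)))

loader = InfiniteDataLoader(dataset, batch_size=30, shuffle=True)
print(len(list(islice(loader, 100))))  # 100, although one pass yields only 2 batches.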

It supports all arguments of torch.utils.data.DataLoader.

VictorZuanazzi commented Nov 26, 2019

I have similar code, but once the generator is exhausted it keeps creating new generators.

Code in the training loop:

try:  # Avoid crashes due to an exhausted generator.
    # Sample the batch.
    clouds, tokens = next(generator)
except StopIteration:
    # Restart the generator once the previous one is exhausted.
    # generator = iter(trainloader)
    generator = dyn_DataLoader.make_generator()
    clouds, tokens = next(generator)
    print("****** New Generator ******")

Relevant methods of dyn_DataLoader (an instance of DynamicDataLoader):

from torch.utils.data import DataLoader


class DynamicDataLoader:
    """A dataloader that adapts its batch size to the available GPU memory."""

    def __init__(self, dataset, gpu_batch_size=2, target_batch_size=1024, adapt=True, **kwargs):
        """Initializes the dataloader with a small batch size.

        Args:
            dataset: torch.utils.data.Dataset.
            gpu_batch_size: int, a guess of the batch size that fits in GPU memory.
            target_batch_size: int, the batch size you would like to have. There is no limit on
                this number; just keep in mind that if it is too large, the number of spoofed
                steps may grow large enough to slow down training.
            adapt: bool, whether batch size adaptation should be performed.
            kwargs: arguments for torch.utils.data.DataLoader().
        """
        self.dataset = dataset
        self.gpu_batch_size = int(gpu_batch_size)
        self.target_batch_size = target_batch_size
        self.kwargs = kwargs
        self.adapt = adapt

        self.init_dataloader()
        self.generator = self.make_generator()
        self.repeat_spoof = self.spoofing_repeats()

    def init_dataloader(self):
        """Initializes the pytorch dataloader."""
        self._dataloader = DataLoader(dataset=self.dataset,
                                      batch_size=self.gpu_batch_size,
                                      **self.kwargs)

    def increase_gpu_batch(self):
        """Doubles the GPU batch size."""
        self.gpu_batch_size *= 2

    def decrease_gpu_batch(self):
        """Halves the GPU batch size."""
        self.gpu_batch_size = int(self.gpu_batch_size / 2)

    def spoofing_repeats(self):
        """Returns the number of repeats necessary for batch spoofing."""
        return max(1, int(self.target_batch_size / self.gpu_batch_size))

    def make_generator(self):
        """Returns a fresh iterator over the dataloader."""
        return iter(self._dataloader)
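
For context, the "batch spoofing" mentioned in the docstring amounts to gradient accumulation. A minimal sketch of how repeat_spoof could drive it (this is not from the thread; model, optimizer, and criterion are hypothetical placeholders):

dyn_loader = DynamicDataLoader(dataset, gpu_batch_size=2, target_batch_size=1024)
generator = dyn_loader.make_generator()

optimizer.zero_grad()
for _ in range(dyn_loader.repeat_spoof):  # 1024 / 2 = 512 accumulation steps.
    try:
        clouds, tokens = next(generator)
    except StopIteration:  # Restart on exhaustion, as in the loop above.
        generator = dyn_loader.make_generator()
        clouds, tokens = next(generator)
    loss = criterion(model(clouds), tokens)
    # Scale the loss so the accumulated gradient matches one large-batch update.
    (loss / dyn_loader.repeat_spoof).backward()
optimizer.step()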

Any ideas of what could be wrong?
