Batched DataLoader in PyTorch
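A DataLoader normally fetches one sample per __getitem__ call and collates the samples afterwards. The pattern below (discussed in pytorch/pytorch#26957) instead passes a BatchSampler as the plain sampler argument with batch_size=None, so __getitem__ receives the whole list of batch indices at once and can read them in a single vectorized slice.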
import numpy as np
import torch.utils.data as data


class SimpleDataset(data.Dataset):
    def __init__(self):
        super().__init__()
        self.data = np.random.randn(1000, 5)

    def __getitem__(self, idxs):
        # `idxs` is a list of indices (one whole batch), not a single int,
        # because the BatchSampler below is passed as the plain `sampler`.
        return idxs, self.data[idxs, :]

    def __len__(self):
        return self.data.shape[0]


# Trick from https://github.com/pytorch/pytorch/issues/26957: with
# batch_size=None automatic batching is disabled, so each item yielded by
# the sampler (here, a full list of batch indices) goes straight to
# __getitem__ in one call.
if __name__ == "__main__":
    drop_last = False
    batch_size = 234

    dataset = SimpleDataset()
    loader = data.DataLoader(
        dataset,
        batch_size=None,  # disable automatic batching
        sampler=data.BatchSampler(
            data.RandomSampler(dataset), batch_size, drop_last=drop_last
        ),
    )

    idxs_accum = []
    total = 0
    for batch_idxs, batch_values in loader:
        idxs_accum += batch_idxs
        total += batch_values.numpy().sum()

    # Tests: every element is visited exactly once, unless drop_last trims
    # the final, incomplete batch (1000 rows / 234 leaves a remainder of 64).
    if not drop_last:
        assert np.isclose(dataset.data.sum(), total)
    if drop_last:
        assert len(idxs_accum) == (len(dataset) // batch_size) * batch_size
        assert len(set(idxs_accum)) == (len(dataset) // batch_size) * batch_size
    else:
        assert len(idxs_accum) == len(dataset)
        assert len(set(idxs_accum)) == len(dataset)
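For deterministic, in-order iteration (e.g. at evaluation time) the same pattern works with a SequentialSampler. A minimal sketch, assuming the same dataset, batch_size, and imports as above:

# Deterministic variant: SequentialSampler yields indices 0..len(dataset)-1
# in order, so each batch is a contiguous slice of rows.
eval_loader = data.DataLoader(
    dataset,
    batch_size=None,  # automatic batching stays disabled
    sampler=data.BatchSampler(
        data.SequentialSampler(dataset), batch_size, drop_last=False
    ),
)
first_idxs, first_values = next(iter(eval_loader))
assert first_idxs == list(range(batch_size))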