Skip to content

Instantly share code, notes, and snippets.

@ottonemo
Created March 11, 2020 12:14
Show Gist options
  • Save ottonemo/cb64d0d259a1ab60b38d62a86554bdf7 to your computer and use it in GitHub Desktop.
Save ottonemo/cb64d0d259a1ab60b38d62a86554bdf7 to your computer and use it in GitHub Desktop.
import torch


class IDS(torch.utils.data.IterableDataset):
    """Iterable dataset whose every item is an already-stacked batch.

    Each item is a pair ``(x, y)`` where ``x`` has shape ``(5, 2, 2)`` and
    ``y`` has shape ``(5,)``.  Because items are whole batches, wrap this
    dataset in a ``DataLoader`` with ``batch_size=None``.  The default
    ``batch_size=1`` would collate each item into a batch of one and
    prepend an extra dimension of size 1.
    """

    def genxy(self):
        """Yield 10 ``(x, y)`` batches of 5 samples each.

        Sample ``i`` of every batch is a 2x2 tensor of the default float
        dtype filled with ``i``; its label is ``torch.tensor(i)`` (int64).
        """
        for _ in range(10):
            xs, ys = [], []
            for sample_idx in range(5):
                xs.append(torch.zeros(2, 2) + sample_idx)
                ys.append(torch.tensor(sample_idx))
            yield torch.stack(xs), torch.stack(ys)

    def __iter__(self):
        # genxy() returns a generator, which is already its own iterator,
        # so no iter() wrapper is needed.
        return self.genxy()
# IDS already yields complete batches.  With the DataLoader's default
# batch_size=1, automatic batching is on: every yielded (x, y) pair is treated
# as ONE sample and collated into a batch of size 1.  That is where the extra
# leading dimension came from:
#     torch.Size([1, 5, 2, 2]) torch.Size([1, 5])
# batch_size=None disables automatic batching, so each pair is passed through
# as-is (default_convert leaves tensors unchanged).
idl = torch.utils.data.DataLoader(IDS(), batch_size=None)
for x, y in idl:
    print(x.shape, y.shape)
# print output, one identical line per batch (10 lines in total):
# torch.Size([5, 2, 2]) torch.Size([5])
class IDS(torch.utils.data.IterableDataset):
    """Iterable dataset whose items are batches that are not yet stacked.

    Every item is a pair ``(xs, ys)`` of plain Python lists, each holding
    five tensors.  Stacking them into batch tensors is left to the
    DataLoader's ``collate_fn``.
    """

    def genxy(self):
        """Yield 10 ``(xs, ys)`` pairs of tensor lists.

        ``xs[i]`` is a 2x2 tensor filled with ``i`` and ``ys[i]`` is
        ``torch.tensor(i)``, for ``i`` in ``0..4``.
        """
        for _ in range(10):
            positions = range(5)
            features = [torch.zeros(2, 2) + pos for pos in positions]
            labels = [torch.tensor(pos) for pos in positions]
            yield features, labels

    def __iter__(self):
        # A generator object is its own iterator; hand it out directly.
        return self.genxy()
def stack_sample(sample):
    """Stack one ``(xs, ys)`` item of tensor lists into ``(x, y)`` tensors."""
    xs, ys = sample
    return torch.stack(xs), torch.stack(ys)


# Each IDS item is already a whole batch, just not stacked yet.  Passing
# batch_size=None turns automatic batching off, so collate_fn is called on
# each item directly.  This avoids relying on the implicit default
# batch_size=1 and unwrapping a 1-element batch with [0], which would
# silently drop every item but the first if the batch size were ever raised.
# A named module-level function, unlike a lambda, is also picklable, which
# worker processes need under the "spawn" start method.
idl = torch.utils.data.DataLoader(IDS(), batch_size=None, collate_fn=stack_sample)
for x, y in idl:
    print(x.shape, y.shape)
# works, but needs a collate_fn because IDS yields plain lists; print output,
# one identical line per batch (10 lines in total):
# torch.Size([5, 2, 2]) torch.Size([5])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment