Created March 11, 2020 12:14
Save ottonemo/cb64d0d259a1ab60b38d62a86554bdf7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class IDS(torch.utils.data.IterableDataset):
    """Iterable dataset whose items are already complete batches.

    Each of the ten items is an ``(inputs, labels)`` pair:

    - ``inputs`` is a ``(5, 2, 2)`` float tensor whose i-th ``2x2`` slice
      is filled with ``i``.
    - ``labels`` is the ``(5,)`` integer tensor ``[0, 1, 2, 3, 4]``.
    """

    def genxy(self):
        """Yield ten ``(inputs, labels)`` batches, building fresh tensors each time."""
        for _ in range(10):
            labels = torch.arange(5)
            # Broadcast every label across a 2x2 plane, so inputs[i] == i everywhere.
            inputs = torch.zeros(5, 2, 2) + labels.view(5, 1, 1)
            yield inputs, labels

    def __iter__(self):
        # A generator object is already its own iterator; no iter() wrapper needed.
        return self.genxy()
# DataLoader defaults to batch_size=1, which enables *automatic batching*.
# It pulls one item from the dataset and collates it into a batch of size 1,
# which prepends a unit dimension (giving [1, 5, 2, 2] / [1, 5]).
# IDS already yields complete batches, so batch_size=None turns automatic
# batching off: each item is only passed through default_convert and is not
# re-batched.
idl = torch.utils.data.DataLoader(IDS(), batch_size=None)
for x, y in idl:
    print(x.shape, y.shape)
# no extra batch dimension, print output:
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class IDS(torch.utils.data.IterableDataset):
    """Iterable dataset whose items are un-stacked batches (plain Python lists).

    Each of the ten items is an ``(xs, ys)`` pair of five-element lists:

    - ``xs[i]`` is a ``2x2`` float tensor filled with ``i``.
    - ``ys[i]`` is the 0-d integer tensor ``i``.

    Stacking the lists into tensors is left to the DataLoader's ``collate_fn``.
    """

    def genxy(self):
        """Yield ten ``(xs, ys)`` list pairs, building fresh tensors each time."""
        for _ in range(10):
            positions = range(5)
            xs = [torch.zeros(2, 2) + pos for pos in positions]
            ys = [torch.tensor(pos) for pos in positions]
            yield xs, ys

    def __iter__(self):
        # The generator returned by genxy() is itself an iterator.
        return self.genxy()
# IDS yields (list_of_x, list_of_y) pairs, so the tensors still need to be
# stacked. With batch_size=None, automatic batching is off and collate_fn is
# called on each individual sample. The lambda therefore receives the
# (xs, ys) pair directly.
# Relying on the default batch_size=1 and indexing batch[0] would also
# produce these shapes. However, it would silently drop every sample but the
# first if the batch size were ever raised.
idl = torch.utils.data.DataLoader(
    IDS(),
    batch_size=None,
    collate_fn=lambda sample: (torch.stack(sample[0]), torch.stack(sample[1])),
)
for x, y in idl:
    print(x.shape, y.shape)
# works fine but needs collate_fn (to stack the lists), print output:
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
# torch.Size([5, 2, 2]) torch.Size([5])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.