@wassname
Created May 8, 2018 07:02
NumpyDataset for PyTorch (like TensorDataset)
import torch.utils.data


class NumpyDataset(torch.utils.data.Dataset):
    """Dataset wrapping NumPy arrays.

    Each sample is retrieved by indexing every array along its first dimension.

    Arguments:
        *arrays (numpy.ndarray): arrays that have the same size in the first dimension.
    """

    def __init__(self, *arrays):
        # Every array must hold the same number of samples (first dimension).
        assert all(arrays[0].shape[0] == array.shape[0] for array in arrays)
        self.arrays = arrays

    def __getitem__(self, index):
        # Index every array along its first dimension and return the samples as a tuple.
        return tuple(array[index] for array in self.arrays)

    def __len__(self):
        # Number of samples, taken from the first array.
        return self.arrays[0].shape[0]
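Below is a short usage sketch. The toy arrays `X` and `y` are hypothetical data invented for illustration, and the class is copied so the snippet runs on its own. It wraps the dataset in a standard `torch.utils.data.DataLoader` and compares it with `TensorDataset`, which needs the arrays converted to tensors first.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class NumpyDataset(Dataset):
    """Copy of the class above: indexes each array along its first dimension."""

    def __init__(self, *arrays):
        assert all(arrays[0].shape[0] == a.shape[0] for a in arrays)
        self.arrays = arrays

    def __getitem__(self, index):
        return tuple(a[index] for a in self.arrays)

    def __len__(self):
        return self.arrays[0].shape[0]


# Hypothetical toy data: 10 samples with 3 features each, plus integer labels.
X = np.arange(30, dtype=np.float32).reshape(10, 3)
y = np.arange(10, dtype=np.int64)

ds = NumpyDataset(X, y)
# The default collate function turns each batch of NumPy samples into tensors.
loader = DataLoader(ds, batch_size=4, shuffle=False)
batches = list(loader)

# For comparison: TensorDataset requires the arrays to be tensors up front.
tds = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))

first_x, first_y = batches[0]
print(len(ds), len(batches), tuple(first_x.shape), first_y.tolist())
```

With this approach the data stays as NumPy arrays until batching, and the DataLoader converts each batch to tensors. `TensorDataset`, by contrast, converts everything before it is constructed.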