Created May 8, 2018 07:02
Save wassname/315a50aa5c221cc4b30249568def10f3 to your computer and use it in GitHub Desktop.
NumpyDataset for PyTorch (like TensorDataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data | |
class NumpyDataset(torch.utils.data.Dataset):
    """Dataset wrapping arrays, analogous to ``torch.utils.data.TensorDataset``.

    Each sample is retrieved by indexing every array along its first
    dimension and is returned as a tuple with one element per array.

    Arguments:
        *arrays: array-likes (e.g. ``numpy.ndarray``) that have the same size
            of the first dimension. Lazy arrays whose indexed items expose a
            ``compute()`` method (such as dask arrays) are also accepted; their
            items are materialized on access.

    Raises:
        ValueError: if no arrays are given, or if their first dimensions differ.
    """

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("NumpyDataset requires at least one array")
        sizes = [array.shape[0] for array in arrays]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if any(size != sizes[0] for size in sizes):
            raise ValueError(
                "all arrays must have the same size of the first dimension, "
                f"got {sizes}"
            )
        self.arrays = arrays

    def __getitem__(self, index):
        """Return a tuple with ``array[index]`` taken from every wrapped array."""
        return tuple(self._materialize(array[index]) for array in self.arrays)

    def __len__(self):
        """Return the shared size of the first dimension."""
        return self.arrays[0].shape[0]

    @staticmethod
    def _materialize(item):
        """Evaluate lazy items (those with a callable ``compute``); pass others through."""
        # numpy arrays/scalars have no `.compute()`; only lazy (e.g. dask) items do.
        compute = getattr(item, "compute", None)
        return compute() if callable(compute) else item
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.