Last active
July 24, 2018 01:55
-
-
Save wassname/ca07651dea84eed22e38808431b99837 to your computer and use it in GitHub Desktop.
torch SequentialRandomSampler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import random

from torch.utils.data.sampler import Sampler
class SequentialRandomSampler(Sampler):
    """Samples elements sequentially, starting from a random location.

    Useful when you want to sequentially sample a random contiguous
    (wrapping) ordering of the dataset: indices run from a random start
    point to the end, then wrap around from 0 back to the start.

    Usage:
        loader = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            sampler=SequentialRandomSampler(dataset_train)
        )
        for data in iter(loader):
            print(data.sum())  # Different each time you run it

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        """Yield all indices once, starting at a random offset and wrapping."""
        data_len = len(self.data_source)
        # Guard: random.randrange(0) raises ValueError, so handle the
        # empty-dataset case explicitly.
        if data_len == 0:
            return iter([])
        # stdlib random instead of the original (unimported) np.random;
        # start is uniform over [0, data_len).
        start = random.randrange(data_len)
        # Wrap around so every index is visited exactly once.
        # (Original buggily referenced the global `dataset_train` here.)
        indices = itertools.chain(range(start, data_len), range(0, start))
        return iter(indices)

    def __len__(self):
        return len(self.data_source)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment