@wassname
Last active July 24, 2018 01:55
torch SequentialRandomSampler
from torch.utils.data.sampler import Sampler
import itertools
import numpy as np


class SequentialRandomSampler(Sampler):
    """Samples elements sequentially, starting from a random location.

    For when you want to sequentially sample a random subset.

    Usage:
        loader = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            sampler=SequentialRandomSampler(dataset_train)
        )
        for data in iter(loader):
            print(data.sum())  # Different each time you run it

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        data_len = len(self.data_source)
        # Pick a random starting index, then wrap around past the end of the
        # dataset so every element is still visited exactly once per epoch.
        start = np.random.randint(0, data_len)
        indices = itertools.chain(range(start, data_len), range(0, start))
        return iter(indices)

    def __len__(self):
        return len(self.data_source)
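
As a quick sanity check, here is a minimal sketch of the usage described in the docstring. The toy TensorDataset and batch size are illustrative assumptions, not part of the gist; it just shows that the sampler yields every element once per pass, starting from a random offset and wrapping around.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 scalar samples 0..9
dataset_train = TensorDataset(torch.arange(10).float())

loader = DataLoader(
    dataset_train,
    batch_size=4,
    sampler=SequentialRandomSampler(dataset_train),
)

for (batch,) in loader:
    # e.g. tensor([7., 8., 9., 0.]), tensor([1., 2., 3., 4.]), tensor([5., 6.])
    print(batch)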