Skip to content

Instantly share code, notes, and snippets.

@codezakh
Created June 26, 2020 19:24
Show Gist options
  • Save codezakh/0a6ad34ff8baf7c33a8d591467a2cab0 to your computer and use it in GitHub Desktop.
Save codezakh/0a6ad34ff8baf7c33a8d591467a2cab0 to your computer and use it in GitHub Desktop.
class CustomBatchSampler(Sampler):
    """Yield batches of sequential indices that never span two classes.

    Indices are visited in the order produced by ``SequentialSampler``. A batch
    is emitted as soon as it holds ``batch_size`` indices, or earlier when the
    label of the next item differs from the label of the current one. The label
    of an item is taken to be the last element of ``dataset[idx]``.

    Args:
        batch_size: Maximum number of indices per batch.
        dataset: Indexable, sized dataset whose items end with their label and
            which raises ``IndexError`` when indexed past its last item.
    """

    def __init__(self, batch_size, dataset):
        self.sampler = SequentialSampler(dataset)
        self.batch_size = batch_size
        self.drop_last = False
        self.dataset = dataset
        # (configuration key, batch count) remembered by __len__; None until
        # the first len() call.
        self._len_cache = None

    def __iter__(self):
        """Generate lists of indices, one list per batch.

        The final, possibly short, batch is emitted unless ``drop_last`` is
        true. Short batches produced by a class change are always emitted.
        """
        batch = []
        # Label fetched as "next" on the previous step, so that each dataset
        # item is loaded once per pass instead of twice.
        lookahead_idx = None
        lookahead_class = None
        for idx in self.sampler:
            if idx == lookahead_idx:
                current_class = lookahead_class
            else:
                current_class = self.dataset[idx][-1]
            try:
                next_class = self.dataset[idx + 1][-1]
            except IndexError:
                # Past the end of the dataset: there is no class change.
                next_class = current_class
            lookahead_idx, lookahead_class = idx + 1, next_class

            batch.append(idx)
            if len(batch) == self.batch_size or next_class != current_class:
                yield batch
                batch = []

        # Without this the indices left over after the last full batch would
        # be silently lost.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the exact number of batches ``__iter__`` produces.

        Batches are also cut at class changes, so the count cannot be derived
        from the dataset size alone. It is obtained with one pass over the
        dataset labels and cached until ``batch_size``, ``drop_last`` or the
        sampler length changes. The first call therefore costs a full pass.
        """
        key = (self.batch_size, self.drop_last, len(self.sampler))
        cached = getattr(self, "_len_cache", None)
        if cached is None or cached[0] != key:
            cached = (key, sum(1 for _ in self))
            self._len_cache = cached
        return cached[1]
# Use like this:
# Usage example; the `...` is a placeholder for the real ImageFolder arguments.
dataset = ImageFolder(...)
# The sampler reads labels from the dataset itself, so it receives the same
# dataset object as the loader.
sampler = CustomBatchSampler(batch_size=16, dataset=dataset)
# Passed as `batch_sampler`, so each list the sampler yields is used as one
# whole batch of indices.
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment