@andres-fr
Last active March 14, 2023 23:14
PyTorch dataset sampler that retrieves a subset of the indices, with the possibility of deterministic and balanced behaviour.
from collections import defaultdict
import random

import torch


class SubsetSampler(torch.utils.data.Sampler):
"""
Like ``torch.utils.data.SubsetRandomsampler``, but without the random,
and with the possibility of balanced sampling. Samples a subset of the
given dataset from a given list of indices, without replacement.
Usage example::
sampler = SubsetSampler.get_balanced(mnist_dataset, size=200)
dl = DataLoader(train_ds, batch_size=1, sampler=sampler)
_, labels = zip(*dl)
testdict = defaultdict(int)
for lbl in labels:
testdict[lbl.item()] += 1
print(len(labels)) # should be 200
print(testdict) # should be {0:20, 1:20, 2:20, ...}
"""
    def __init__(self, *indices):
        """
        :param indices: Integer indices for the sampling.
        """
        self.indices = indices

    def __len__(self):
        """
        :returns: The number of indices held by this sampler.
        """
        return len(self.indices)

    def __iter__(self):
        """
        Yields the stored indices, in the order they were given.
        """
        for idx in self.indices:
            yield idx
    @classmethod
    def get_balanced(cls, dataset, size=100, shuffle=False):
        """
        Given a ``dataset`` that yields ``dataset[idx] = (data, label)``,
        where labels are hashable, this method returns a ``SubsetSampler``
        with ``size`` indexes, balanced among classes. This requires that
        ``size`` is divisible by the number of classes, and that each class
        has at least ``size / num_classes`` elements. Indexes are retrieved
        in ascending order. If ``shuffle`` is false, the lowest indexes for
        each class will be gathered, making the result deterministic.
        """
        assert len(dataset) >= size, "Size can't be larger than dataset!"
        # group all data indexes by label, and optionally shuffle them
        histogram = defaultdict(list)
        for idx, (_, lbl) in enumerate(dataset):
            histogram[lbl].append(idx)
        if shuffle:
            for v in histogram.values():
                random.shuffle(v)
        # sanity check:
        class_sizes = {k: len(v) for k, v in histogram.items()}
        num_classes = len(histogram)
        entries_per_class, rest = divmod(size, num_classes)
        assert rest == 0, "Please choose a size divisible by num_classes!"
        assert all(v >= entries_per_class for v in class_sizes.values()), \
            f"Not all classes have enough elements! {class_sizes}"
        # now we can gather the balanced indexes into the sampler
        idxs = sorted(sum((v[:entries_per_class]
                           for v in histogram.values()), []))
        sampler = cls(*idxs)
        return sampler
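
The snippet below is a minimal sanity-check sketch, not part of the gist itself: it assumes the ``SubsetSampler`` class above is defined in scope, and it uses a made-up toy dataset (``toy_ds``) of ``(tensor, int_label)`` pairs so the labels are hashable, as the class docstring requires.

from collections import Counter

import torch
from torch.utils.data import DataLoader

# hypothetical toy dataset: 5 classes with 30 samples each
toy_ds = [(torch.randn(4), lbl) for lbl in range(5) for _ in range(30)]

# deterministic, balanced subset of 50 indices (10 per class)
sampler = SubsetSampler.get_balanced(toy_ds, size=50)
dl = DataLoader(toy_ds, batch_size=1, sampler=sampler)
labels = [lbl.item() for _, lbl in dl]
print(len(labels))      # 50
print(Counter(labels))  # Counter({0: 10, 1: 10, 2: 10, 3: 10, 4: 10})

# shuffled variant: still balanced, but per-class indices are picked at random
sampler = SubsetSampler.get_balanced(toy_ds, size=50, shuffle=True)
print(sorted(sampler.indices)[:5])  # differs between runs

Because ``get_balanced`` sorts the gathered indices, the shuffled variant only changes *which* indices per class are selected; iteration order remains ascending either way.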