Last active
March 14, 2023 23:14
-
-
Save andres-fr/7a2a7f622b446353db21e7dc96504193 to your computer and use it in GitHub Desktop.
PyTorch dataset sampler that retrieves a subset of the indices, with the possibility of deterministic and balanced behaviour.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict
import random

import torch
class SubsetSampler(torch.utils.data.Sampler):
    """
    Like ``torch.utils.data.SubsetRandomSampler``, but without the random,
    and with the possibility of balanced sampling. Samples a subset of the
    given dataset from a given list of indices, without replacement.
    Usage example::

      sampler = SubsetSampler.get_balanced(mnist_dataset, size=200)
      dl = DataLoader(train_ds, batch_size=1, sampler=sampler)
      _, labels = zip(*dl)
      testdict = defaultdict(int)
      for lbl in labels:
          testdict[lbl.item()] += 1
      print(len(labels))  # should be 200
      print(testdict)  # should be {0:20, 1:20, 2:20, ...}
    """

    def __init__(self, *indices):
        """
        :param indices: Integer indices for the sampling, yielded by
          ``__iter__`` in exactly the order given here.
        """
        self.indices = indices

    def __len__(self):
        """
        :returns: Number of indices this sampler yields.
        """
        return len(self.indices)

    def __iter__(self):
        """
        :returns: An iterator over the stored indices, in the given order.
        """
        for idx in self.indices:
            yield idx

    @classmethod
    def get_balanced(cls, dataset, size=100, random=False):
        """
        Given a ``dataset`` that yields ``dataset[idx] = (data, label)``,
        where labels are hashable, this method returns a ``SubsetSampler``
        with ``size`` indexes, such that they are balanced among classes.
        This requires that ``size`` is divisible by the number of classes,
        and that each class has at least ``size / num_classes`` elements.
        Indexes are retrieved in ascending order. If ``random`` is false,
        the lowest indexes for each class will be gathered, making it
        deterministic.

        :param dataset: Iterable of ``(data, label)`` pairs with hashable
          labels (the whole dataset is iterated once to build the groups).
        :param int size: Total number of indices to sample, must be a
          multiple of the number of distinct labels.
        :param bool random: If true, pick the per-class indices at random
          (via the global ``random`` module state) instead of the lowest.
        :returns: A ``SubsetSampler`` over the chosen, ascending indices.
        """
        assert len(dataset) >= size, "Size can't be larger than dataset!"
        # group all data indexes by label, and optionally shuffle them
        histogram = defaultdict(list)
        for idx, (_, lbl) in enumerate(dataset):
            histogram[lbl].append(idx)
        if random:
            # NOTE: the ``random`` parameter shadows the stdlib module of
            # the same name, so re-import it under an alias here.
            import random as random_module
            for v in histogram.values():
                random_module.shuffle(v)
        # sanity check:
        class_sizes = {k: len(v) for k, v in histogram.items()}
        num_classes = len(histogram)
        entries_per_class, rest = divmod(size, num_classes)
        assert rest == 0, "Please choose a size divisible by num_classes!"
        assert all(v >= entries_per_class for v in class_sizes.values()), \
            f"Not all classes have enough elements! {class_sizes}"
        # gather the first ``entries_per_class`` indices of every class,
        # flattened and sorted ascending (avoids quadratic sum-of-lists)
        idxs = sorted(idx for v in histogram.values()
                      for idx in v[:entries_per_class])
        sampler = cls(*idxs)
        return sampler
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment