@srikarplus
Created December 1, 2018 08:17
Stratified Sampling in Pytorch
def make_weights_for_balanced_classes(images, nclasses):
    # Count samples per class; each item is a (path, class_index) tuple
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # Weight each class inversely proportional to its frequency
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    # Assign every sample the weight of its class
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight
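
To make the weighting concrete, here is a self-contained sanity check. The function is copied from above; the file names are made up and the list simply mimics the `(path, class_index)` pairs that `ImageFolder.imgs` holds:

```python
def make_weights_for_balanced_classes(images, nclasses):
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight

# Hypothetical dataset: class 0 has 3 samples, class 1 has 1, so N = 4.
# Class-0 samples get weight 4/3, the class-1 sample gets weight 4,
# and each class's total weight is 4 -- the classes are balanced in expectation.
imgs = [("a.jpg", 0), ("b.jpg", 0), ("c.jpg", 0), ("d.jpg", 1)]
weights = make_weights_for_balanced_classes(imgs, 2)
```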
# And then use it as follows:
import torch
from torchvision import datasets

dataset_train = datasets.ImageFolder(traindir)

# For an unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# Note: the sampler option is mutually exclusive with shuffle,
# so shuffle must be left False (the default) when a sampler is passed.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size,
                                           sampler=sampler, num_workers=args.workers,
                                           pin_memory=True)
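
The effect of the weighted sampler can be sketched without torch at all: drawing with replacement according to these per-sample weights is what `WeightedRandomSampler` does, and the standard library's `random.choices` emulates it. The toy labels below (a hypothetical 9:1 imbalance) are made up for illustration:

```python
import random
from collections import Counter

random.seed(0)
# Imbalanced toy labels: 900 samples of class 0, 100 of class 1
labels = [0] * 900 + [1] * 100
# Per-sample weights as computed above: N / count[class]
weights = [1000 / 900 if y == 0 else 1000 / 100 for y in labels]

# Weighted draw with replacement, like WeightedRandomSampler over one "epoch"
draws = random.choices(labels, weights=weights, k=10000)
counts = Counter(draws)
# Each class is drawn roughly 5000 times despite the 9:1 imbalance
```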
@keshik6

keshik6 commented Jul 2, 2019

Hi,
Thanks for the code sample. But the sampler option is mutually exclusive with the shuffle option, so you need to set shuffle=False when using a sampler.

@simonmoesorensen

simonmoesorensen commented Jul 15, 2022

For those who need an implementation for large datasets:
https://gist.github.com/simonmoesorensen/ac590e8e25ac8b1c322519d2d8c73676
