Creates a PyTorch sampler that samples all classes evenly. Uses vectorization and PyTorch DataLoaders to compute the per-sample weights efficiently.
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets
from tqdm import tqdm

def make_weights_for_balanced_classes(images, nclasses, batch_size, num_workers=0):
    """Compute a per-sample weight inversely proportional to its class frequency.

    Adapted from https://gist.github.com/srikarplus/15d7263ae2c82e82fe194fc94321f34e
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Count how many samples each class contains, one batch at a time.
    count = torch.zeros(nclasses, device=device)
    loader = DataLoader(images, batch_size=batch_size, num_workers=num_workers)
    for _, label in tqdm(loader, desc="Counting classes"):
        label = label.to(device=device)
        idx, counts = label.unique(return_counts=True)
        count[idx] += counts

    # A class's weight is the inverse of its frequency: total samples / samples in that class.
    N = count.sum()
    weight_per_class = N / count

    # Assign every sample the weight of its class. The loader is not shuffled,
    # so batch i covers the samples starting at index i * batch_size.
    weight = torch.zeros(len(images), device=device)
    for i, (_, label) in tqdm(enumerate(loader), desc="Applying weights", total=len(loader)):
        label = label.to(device=device)
        idx = torch.arange(0, label.shape[0], dtype=torch.long, device=device) + i * batch_size
        weight[idx] = weight_per_class[label]
    return weight

# Train set (traindir and args are assumed to be defined elsewhere, e.g. via argparse)
dataset_train = datasets.ImageFolder(traindir)

# For an unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(
    dataset_train.imgs, len(dataset_train.classes), args.batch_size, num_workers=args.workers
)
train_sampler = WeightedRandomSampler(weights, len(weights))
train_loader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False,
                          sampler=train_sampler, num_workers=args.workers, pin_memory=True)
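
Optional sanity check (not part of the original gist): a minimal sketch that draws one epoch of indices from the sampler defined above and counts how often each class appears. With the inverse-frequency weights, the counts should come out roughly even across classes.

# Sanity check: class counts over one epoch of sampled indices should be roughly uniform.
from collections import Counter

sampled_labels = [dataset_train.targets[i] for i in iter(train_sampler)]
print(Counter(sampled_labels))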