Simon Moe Sørensen simonmoesorensen

simonmoesorensen / pytorch_stratified_sampling.py
Last active October 5, 2023 10:04
Creates a PyTorch sampler that draws from every class evenly. Uses vectorization and PyTorch DataLoaders to compute the per-sample weights efficiently.
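As a rough illustration of the idea behind class-balanced sampling, the sketch below gives each sample a weight inversely proportional to its class frequency, then draws indices with replacement. It uses only the standard library and is not the gist's implementation. The helper name `balanced_sample_weights` and the toy `labels` list are hypothetical. In PyTorch, a weight list like this is what one would typically pass to `torch.utils.data.WeightedRandomSampler`.

```python
from collections import Counter
import random


def balanced_sample_weights(labels):
    """Return one weight per sample, inversely proportional to its class frequency.

    Hypothetical helper for illustration; not part of the gist.
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Hypothetical toy labels: class 0 is four times as common as class 1.
labels = [0, 0, 0, 0, 1]
weights = balanced_sample_weights(labels)

# Draw indices with replacement. Each class carries the same total weight,
# so both classes should appear roughly equally often in the draws.
rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=weights, k=1000)
drawn_counts = Counter(labels[i] for i in drawn)
```

Because the weights within each class sum to the same total, the minority class is oversampled until it matches the majority class in expectation.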
import torch
from torch.utils.data import DataLoader, sampler
from torchvision import datasets
def make_weights_for_balanced_classes(images, nclasses, batch_size):
    """
    Adapted from https://gist.github.com/srikarplus/15d7263ae2c82e82fe194fc94321f34e
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")