Skip to content

Instantly share code, notes, and snippets.

@ilkarman
Created October 15, 2019 10:11
Show Gist options
  • Save ilkarman/fe45c586ef10c4cb8e52b3a78b6ac854 to your computer and use it in GitHub Desktop.
Save ilkarman/fe45c586ef10c4cb8e52b3a78b6ac854 to your computer and use it in GitHub Desktop.
import torch
import torchvision
import torch.utils.data
import random
import numpy as np
from torch.utils.data import TensorDataset
# https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields dataset indices with every class equally represented.

    Sample indices are grouped by label.  Every class with fewer members than
    the largest class is oversampled (with replacement, via ``random.choice``)
    until all classes hold ``balanced_max`` indices.  Iteration then walks the
    classes round-robin -- one index from each class in turn -- so a run of
    consecutive indices whose length is a multiple of the number of classes
    contains each class equally often.

    Args:
        dataset: Sized collection being sampled; only ``len(dataset)`` is used.
        labels: Per-sample labels, indexable by position.  Elements may be
            objects exposing ``.item()`` (e.g. 0-d tensors) or plain hashable
            values.

    Raises:
        ValueError: If ``labels`` is ``None`` and ``dataset`` is non-empty.
    """

    def __init__(self, dataset, labels=None):
        self.labels = labels
        # Maps label -> list of dataset indices carrying that label.
        self.dataset = dict()

        # Save all the indices for all the classes
        for idx in range(len(dataset)):
            label = self._get_label(dataset, idx)
            self.dataset.setdefault(label, []).append(idx)

        # Size of the largest class; 0 when the dataset is empty.
        self.balanced_max = max(map(len, self.dataset.values()), default=0)

        # Oversample the classes with fewer elements than the max
        for members in self.dataset.values():
            for _ in range(self.balanced_max - len(members)):
                members.append(random.choice(members))

        self.keys = list(self.dataset.keys())
        # Retained for backward compatibility only: iteration no longer stores
        # its cursor on the instance (see __iter__).
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        """Yield indices round-robin over the classes.

        All cursor state is local to the generator, so every call starts a
        fresh pass even if an earlier iterator was abandoned midway, and an
        empty dataset simply yields nothing.
        """
        for position in range(self.balanced_max):
            for key in self.keys:
                yield self.dataset[key][position]

    def _get_label(self, dataset, idx, labels=None):
        """Return the hashable label of sample ``idx``.

        ``dataset`` and ``labels`` are accepted for signature compatibility
        but are not consulted; the labels given to the constructor are used.

        Raises:
            ValueError: If no labels were passed to the constructor.
        """
        if self.labels is None:
            # Returning None here would silently put every sample in a single
            # class and disable balancing altogether.
            raise ValueError(
                "You should pass the labels to the constructor as second argument"
            )
        label = self.labels[idx]
        # Tensor / NumPy scalars are unwrapped to Python values so that they
        # hash by value; plain Python labels are used as they are.
        return label.item() if hasattr(label, "item") else label

    def __len__(self):
        """Return the number of indices produced by one full pass."""
        return self.balanced_max * len(self.keys)
# Toy data with a heavy class skew: 100 random 2-d points, of which the
# first 98 are labelled 1 and the last 2 are labelled 0.
X = torch.Tensor(np.random.rand(100, 2))
y = torch.Tensor(np.concatenate((np.ones(98), np.zeros(2))))

# Hand the sampler to the DataLoader so it controls the order in which rows
# of the TensorDataset are drawn.
train_loader = torch.utils.data.DataLoader(
    TensorDataset(X, y),
    sampler=BalancedBatchSampler(X, y),
    batch_size=20,
)

# Look at the first batch only; its labels should alternate between classes:
# tensor([1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.])
for data, labels in train_loader:
    print(labels)
    break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment