Skip to content

Instantly share code, notes, and snippets.

Created October 3, 2017 11:50
Show Gist options
  • Save thomwolf/3d1b008336e7ec41a00ce723703ac843 to your computer and use it in GitHub Desktop.
Save thomwolf/3d1b008336e7ec41a00ce723703ac843 to your computer and use it in GitHub Desktop.
A PyTorch BatchSampler that enables large epochs on small datasets and balanced sampling from unbalanced datasets
class DeepMojiBatchSampler(object):
    """A batch sampler that enables larger epochs on small datasets and
    has upsampling functionality.

    All sample indices for an epoch are drawn once, in ``__init__``; every
    call to ``__iter__`` hands out that same fixed set of indices.

    # Arguments:
        y_in: Labels of the dataset. Either a tensor (anything exposing a
            ``.numpy()`` method) or an array-like accepted by ``np.asarray``.
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems (labels 0 and 1).
        seed: Random number generator seed. A private
            ``np.random.RandomState`` is used, so the global NumPy RNG state
            is left untouched.

    # Raises:
        ValueError: if ``upsample`` is set and one of the two classes has no
            examples in ``y_in``.

    # __iter__ output:
        iterator of batches (numpy index arrays) of indices in the dataset.
        The last batch is smaller than ``batch_size`` when ``epoch_size`` is
        not a multiple of it, consistent with ``__len__``.
    """

    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.upsample = upsample

        # Private RNG: makes `seed` effective without mutating global state.
        rng = np.random.RandomState(seed)
        # Convert once; tensors expose .numpy(), plain arrays/lists do not.
        labels = y_in.numpy() if hasattr(y_in, "numpy") else np.asarray(y_in)

        if not upsample:
            # Sample observations uniformly (with replacement) from the dataset.
            self.sample_ind = rng.choice(len(labels), epoch_size, replace=True)
        else:
            # Should only be used on binary class problems
            assert len(labels.shape) == 1, "upsample requires 1-D binary labels"
            assert epoch_size % 2 == 0, "upsample requires an even epoch_size"
            neg = np.where(labels == 0)[0]
            pos = np.where(labels == 1)[0]
            if len(neg) == 0 or len(pos) == 0:
                raise ValueError(
                    "upsample=True needs at least one example of each class "
                    "(labels 0 and 1)"
                )
            samples_pr_class = epoch_size // 2

            # Randomly sample observations in a balanced way
            sample_neg = rng.choice(neg, samples_pr_class, replace=True)
            sample_pos = rng.choice(pos, samples_pr_class, replace=True)
            concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

            # Shuffle to avoid labels being in specific order
            # (all negative then positive)
            self.sample_ind = concat_ind[rng.permutation(len(concat_ind))]

            # Sanity check that the drawn epoch is actually balanced.
            label_dist = np.mean(labels[self.sample_ind])
            assert 0.45 < label_dist < 0.55, "upsampled labels are not balanced"

    def __iter__(self):
        # Hand off the pre-drawn indices in consecutive batch_size chunks.
        # The final chunk keeps any remainder so iteration agrees with __len__.
        for start in range(0, self.epoch_size, self.batch_size):
            yield self.sample_ind[start:start + self.batch_size]

    def __len__(self):
        # Take care of the last (maybe incomplete) batch
        return (self.epoch_size + self.batch_size - 1) // self.batch_size
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment