Skip to content

Instantly share code, notes, and snippets.

@ozancaglayan
Created October 31, 2017 17:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ozancaglayan/8ff67259f73e18be078177f1ae5a7805 to your computer and use it in GitHub Desktop.
Save ozancaglayan/8ff67259f73e18be078177f1ae5a7805 to your computer and use it in GitHub Desktop.
WeightedBatchSampler for PyTorch
class WeightedBatchSampler(Sampler):
    """Batch sampler that draws sample indices from a weighted distribution.

    Each batch holds ``batch_size`` distinct indices in ``[0, n_elems)``
    drawn with ``np.random.choice`` according to the probability vector
    ``self.p``. Indices are unique within a batch, but the same index may
    appear again in a later batch.

    The iterator returned by ``iter(sampler)`` is a generator that accepts
    feedback: ``gen.send(scores)`` adds ``scores`` to the probabilities of
    the batch that was just yielded, renormalizes, and returns the *next*
    batch. Because ``send()`` advances the generator, callers that use it
    must consume its return value instead of relying on a plain ``for``
    loop, otherwise that batch is skipped.

    Args:
        n_elems: Number of samples that can be drawn.
        batch_size: Number of indices per batch.
        initial_p: Optional 1-D sequence of ``n_elems`` non-negative
            weights. It is copied and normalized to sum to 1. When
            ``None``, every sample starts with the same probability.
        epoch_p_reset: When true, ``self.p`` is restored to the initial
            distribution once an iteration has yielded all its batches.

    Raises:
        ValueError: If ``initial_p`` does not have shape ``(n_elems,)``,
            contains negative or non-finite values, or sums to zero.
    """

    def __init__(self, n_elems, batch_size,
                 initial_p=None, epoch_p_reset=False):
        # NOTE: the base-class initializer is not invoked, exactly as in the
        # original code.
        self.n_elems = n_elems
        self.batch_size = batch_size
        self.epoch_p_reset = epoch_p_reset
        self.n_batches = math.ceil(self.n_elems / self.batch_size)

        if initial_p is None:
            # Start with uniform probability for each sample
            start_p = np.ones(self.n_elems) / self.n_elems
        else:
            # Copy as float so in-place updates neither touch the caller's
            # array nor fail on integer weights.
            start_p = np.array(initial_p, dtype=float)
            if start_p.shape != (self.n_elems,):
                raise ValueError(
                    'initial_p must have shape (%d,), got %s'
                    % (self.n_elems, start_p.shape))
            if not np.isfinite(start_p).all() or (start_p < 0).any():
                raise ValueError(
                    'initial_p must contain finite, non-negative values')
            total = start_p.sum()
            if not total > 0:
                raise ValueError('initial_p must have a positive sum')
            # np.random.choice requires probabilities that sum to 1
            start_p = start_p / total

        # Keep a pristine copy for the optional end-of-epoch reset
        self._initial_p = start_p
        self.p = start_p.copy()

    def update_p(self, sample_idxs, sample_scores):
        """Add ``sample_scores`` to ``p[sample_idxs]`` and renormalize.

        Args:
            sample_idxs: Indices of the samples to update.
            sample_scores: Values added to the current probabilities of
                those samples before the whole vector is renormalized.
        """
        # Cumulatively update scores for samples
        self.p[sample_idxs] += sample_scores
        # Normalize probability distribution
        self.p /= self.p.sum()

    def __iter__(self):
        """Yield ``n_batches`` index arrays, accepting scores via ``send()``.

        Yields:
            A NumPy array of ``batch_size`` distinct sample indices.
        """
        ctr = 0
        while ctr < self.n_batches:
            bidxs = np.random.choice(self.n_elems, self.batch_size,
                                     replace=False, p=self.p)
            # A plain next() resumes here with None; send(x) resumes with x
            scores = (yield bidxs)
            if scores is not None:
                print(' Received scores')
                self.update_p(bidxs, scores)
            # Increment batch counter
            ctr += 1

        if self.epoch_p_reset:
            # Drop everything learned during this epoch
            self.p = self._initial_p.copy()

    def __len__(self):
        """Returns how many batches are inside."""
        return self.n_batches
def main():
    """Demo: iterate over a sampler, occasionally sending scores back.

    The generator is driven exclusively through ``next()``/``send()``.
    ``send()`` advances the generator and returns the following batch, so
    mixing it with a ``for`` loop would silently drop that batch and would
    let ``StopIteration`` escape when scores are sent for the last batch.
    """
    sampler = WeightedBatchSampler(100000, 32)
    gen = iter(sampler)
    idx = 0
    # Prime the generator; None means there were no batches at all
    batch = next(gen, None)
    while batch is not None:
        scores = None
        # Flip a biased coin and update scores
        if np.random.binomial(1, p=0.05):
            print('Sending scores back to generator at iteration %d' % idx)
            # high is exclusive, so every score drawn here equals 1
            scores = np.random.randint(low=1, high=2, size=batch.size)
        try:
            # send(None) behaves like next(); either way we get the next batch
            batch = gen.send(scores)
        except StopIteration:
            batch = None
        idx += 1


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment