Skip to content

Instantly share code, notes, and snippets.

@xLaszlo
Created December 19, 2020 12:11
Show Gist options
  • Save xLaszlo/ecb13c90e6cb1d1e6f911b321f638584 to your computer and use it in GitHub Desktop.
Fast alias sampling using Vose's initialisation
# Alias-method sampler using Vose's table construction.
# Adapted from http://www.keithschwarz.com/darts-dice-coins/ by Keith Schwarz (htiek@cs.stanford.edu)
# and https://github.com/asmith26/Vose-Alias-Method/blob/master/vose_sampler/vose_sampler.py
import numpy as np


class VoseAlias:
    """Draw indices ``0..N-1`` from a fixed discrete distribution.

    The constructor builds the two alias tables with Vose's algorithm. After
    that, each draw in :meth:`sample` costs one uniform integer and one
    uniform float.

    Attributes:
        N: number of outcomes.
        rng: the ``numpy.random.Generator`` used by :meth:`sample`.
        probs: float64 array; ``probs[i]`` is the chance of keeping column
            ``i`` instead of jumping to ``aliases[i]``.
        aliases: int64 array; the outcome returned when column ``i`` is
            rejected.
    """

    def __init__(self, probs, seed=42, rng=None):
        """Build the alias tables.

        Args:
            probs: 1-D sequence or array of outcome probabilities. Expected to
                be non-negative and to sum to 1; this is neither checked nor
                normalised here.
            seed: seed for ``np.random.default_rng``; only used when ``rng``
                is None.
            rng: optional ``numpy.random.Generator`` to draw from.
        """
        # Convert before scaling: for a plain list, ``probs * N`` would repeat
        # the list N times rather than multiply its entries.
        weights = np.asarray(probs, dtype=np.float64)
        self.N = len(weights)
        self.rng = rng if rng is not None else np.random.default_rng(seed)
        self.aliases = np.zeros(self.N, dtype=np.int64)
        self.probs = np.zeros(self.N)

        scaled_arr = weights * self.N
        small = np.flatnonzero(scaled_arr < 1.0).tolist()
        large = np.flatnonzero(scaled_arr >= 1.0).tolist()
        # The pairing loop is scalar work; Python floats do the same IEEE
        # double arithmetic as numpy float64 scalars.
        scaled = scaled_arr.tolist()

        while small and large:
            small_idx = small.pop()
            large_idx = large.pop()
            self.probs[small_idx] = scaled[small_idx]
            self.aliases[small_idx] = large_idx
            # large_idx donates (1 - scaled[small_idx]) to fill small_idx's column.
            remaining = (scaled[large_idx] + scaled[small_idx]) - 1.0
            scaled[large_idx] = remaining
            if remaining < 1.0:
                small.append(large_idx)
            else:
                large.append(large_idx)

        # Columns left over would be exactly 1.0 in exact arithmetic; any
        # deviation is floating-point drift, so make them always-accept.
        for idx in small:
            self.probs[idx] = 1.0
        for idx in large:
            self.probs[idx] = 1.0

    def sample(self, n=None):
        """Draw outcome indices.

        Args:
            n: number of draws. ``None`` (the default) draws a single index.

        Returns:
            A single numpy integer when ``n`` is None, otherwise an integer
            array of ``n`` indices.
        """
        if n is None:
            col = self.rng.integers(self.N)
            # Strict '<' so a column whose probs is 0 never returns itself,
            # even if random() yields exactly 0.0.
            return col if self.rng.random() < self.probs[col] else self.aliases[col]
        cols = self.rng.integers(low=0, high=self.N, size=n)
        keep = self.rng.random(size=n) < self.probs[cols]
        return np.where(keep, cols, self.aliases[cols])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment