Created December 19, 2020 12:11
-
-
Save xLaszlo/ecb13c90e6cb1d1e6f911b321f638584 to your computer and use it in GitHub Desktop.
Fast alias sampling using Vose's initialisation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Alias-method sampler adapted from
# http://www.keithschwarz.com/darts-dice-coins/ by Keith Schwarz (htiek@cs.stanford.edu)
# and https://github.com/asmith26/Vose-Alias-Method/blob/master/vose_sampler/vose_sampler.py
import numpy as np


class VoseAlias:
    """Sample indices from a discrete distribution in O(1) per draw.

    Builds an alias table with Vose's O(N) initialisation. Each draw then
    picks a column uniformly at random and either keeps it or switches to
    that column's alias. A draw uses one uniform integer and one uniform
    float.

    Attributes:
        N: number of outcomes.
        rng: generator used for draws (a ``numpy.random.Generator`` by default).
        aliases: int64 array; ``aliases[i]`` is the outcome returned when
            column ``i`` is not kept.
        probs: float array; ``probs[i]`` is the probability of keeping column ``i``.
    """

    def __init__(self, probs, seed=42, rng=None):
        """Build the alias table.

        Args:
            probs: 1-D array-like of non-negative weights, one per outcome.
                Plain lists are accepted. Weights are normalised to sum to 1,
                so they need not already be probabilities.
            seed: seed for ``np.random.default_rng`` when ``rng`` is not given.
            rng: optional generator providing ``integers`` and ``random``
                with ``numpy.random.Generator`` call signatures.

        Raises:
            ValueError: if ``probs`` is not 1-D, is empty, contains negative
                or non-finite values, or sums to zero.
        """
        # Convert first: ``list * N`` would repeat the list, and an integer
        # array would truncate the float remainders written back below.
        weights = np.asarray(probs, dtype=float)
        if weights.ndim != 1 or weights.size == 0:
            raise ValueError("probs must be a non-empty 1-D sequence")
        if not np.all(np.isfinite(weights)) or np.any(weights < 0):
            raise ValueError("probs must be finite and non-negative")
        total = weights.sum()
        if total <= 0:
            raise ValueError("probs must have a positive sum")

        self.N = len(weights)
        self.rng = rng if rng is not None else np.random.default_rng(seed)
        self.aliases = np.zeros(self.N, dtype=np.int64)
        self.probs = np.zeros(self.N)

        # Scale so the average column height is exactly 1.
        scaled = weights / total * self.N
        small = np.flatnonzero(scaled < 1.0).tolist()
        large = np.flatnonzero(scaled >= 1.0).tolist()
        while small and large:
            s = small.pop()
            g = large.pop()
            self.probs[s] = scaled[s]
            self.aliases[s] = g
            # Computed as (p_g + p_s) - 1 rather than p_g - (1 - p_s), the
            # numerically stabler ordering from Schwarz's write-up. Python
            # evaluates float expressions exactly as grouped, so the ordering
            # is not rewritten.
            remaining = (scaled[g] + scaled[s]) - 1.0
            scaled[g] = remaining
            if remaining < 1.0:
                small.append(g)
            else:
                large.append(g)
        # Columns left in either list are full up to rounding error.
        for i in small + large:
            self.probs[i] = 1.0

    def sample(self, n=None):
        """Draw outcome indices.

        Args:
            n: number of draws (anything accepted as ``size`` by
                ``Generator.integers``). ``None`` draws a single index.

        Returns:
            A single integer index when ``n`` is None, otherwise an int64
            array with shape ``n``.
        """
        if n is None:
            col = self.rng.integers(self.N)
            # Strict ``<``: random() may return exactly 0.0, and a column with
            # probs == 0 must still defer to its alias in that case.
            return col if self.rng.random() < self.probs[col] else self.aliases[col]
        cols = self.rng.integers(low=0, high=self.N, size=n)
        keep = self.rng.random(size=n) < self.probs[cols]
        return np.where(keep, cols, self.aliases[cols])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment