Skip to content

Instantly share code, notes, and snippets.

@martenlienen
Created July 7, 2023 13:01
Show Gist options
  • Save martenlienen/4a0ee62d83f281ecb45be81145546613 to your computer and use it in GitHub Desktop.
Save martenlienen/4a0ee62d83f281ecb45be81145546613 to your computer and use it in GitHub Desktop.
Implementation of the [alias method](https://en.wikipedia.org/wiki/Alias_method) for sampling from a categorical distribution in constant time per sample (after a one-time preprocessing step)
import heapq
import matplotlib.pyplot as pp
import numpy as np
rng = np.random.default_rng()

k = 20  # Number of classes
n = 10_000_000  # Number of samples
cat = rng.uniform(size=k)  # Categorical distribution (not normalized)

### Preprocessing
# Build the alias table: k bins that each hold mass `target` and at most two
# classes. Bin i keeps class i with mass min(w[i], target); whatever is left
# of the bin belongs to the alias class idx[1, i].
w = cat.copy()
target = w.sum() / k

# Row 0 is each bin's own class, row 1 its alias (initially the class itself,
# so a bin that is never topped up can only ever produce its own class).
idx = np.stack((np.arange(k), np.arange(k)))

# Only classes strictly on the respective side of `target` enter a heap, and
# an entry is only ever pushed with the class's current weight, so the heaps
# never contain stale entries. `largest` negates the keys to get a max-heap.
largest = [(-w[i], i) for i in range(k) if w[i] > target]
heapq.heapify(largest)
smallest = [(w[i], i) for i in range(k) if w[i] < target]
heapq.heapify(smallest)

# Pair the heaviest class with the lightest one until one side runs out.
# In exact arithmetic both heaps empty together; with floating-point rounding
# one class may be left marginally off `target`, in which case it simply keeps
# its whole bin (alias == itself) instead of being processed again.
while largest and smallest:
    _, l = heapq.heappop(largest)
    _, s = heapq.heappop(smallest)

    # Fill the rest of bin s with mass taken from class l. w[s] stays as the
    # threshold below which bin s yields its own class.
    w[l] -= target - w[s]
    idx[1, s] = l

    # Re-queue l according to what is left of it; exactly `target` means the
    # class fills its own bin and is finished.
    if w[l] > target:
        heapq.heappush(largest, (-w[l], l))
    elif w[l] < target:
        heapq.heappush(smallest, (w[l], l))
### Sampling
# Pick a bin uniformly at random, then a uniform point inside that bin: if the
# point lies above the bin's threshold w[bin], emit the bin's alias class
# (row 1 of idx), otherwise the bin's own class (row 0).
# np.where selects per element directly, so no 2 x n gather or length-n index
# array has to be materialized.
bin_choice = rng.choice(k, size=n)
use_alias = rng.uniform(size=n) * target > w[bin_choice]
samples = np.where(use_alias, idx[1, bin_choice], idx[0, bin_choice])
### Let's look at some samples
# Draw the normalized input weights and the empirical sample frequencies as
# side-by-side bars; the sample bars are shifted by half a unit so both remain
# visible for every class.
bar_style = dict(lw=0, width=0.5)
class_ids = np.arange(k)
pp.bar(class_ids, cat / cat.sum(), label="weights", **bar_style)

bins, counts = np.unique(samples, return_counts=True)
pp.bar(bins + 0.5, counts / counts.sum(), alpha=0.5, label="samples", **bar_style)
pp.legend()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment