Skip to content

Instantly share code, notes, and snippets.

@kampersanda
Last active June 2, 2023 07:58
Show Gist options
  • Save kampersanda/8df2c2583457aa18a3efb6c55234ab67 to your computer and use it in GitHub Desktop.
Save kampersanda/8df2c2583457aa18a3efb6c55234ab67 to your computer and use it in GitHub Desktop.
Walker's Alias Method in Python
import collections
import random
from collections.abc import Iterable
class WalkerAlias:
    """Sample integer indices with Walker's alias method.

    Index ``i`` is returned with probability ``weights[i] / sum(weights)``.
    Building the tables is linear in ``len(weights)``. Each call to
    :meth:`sample` does constant work: one uniform bucket choice and one
    uniform float.

    Reference (Japanese): https://qiita.com/kaityo256/items/1656597198cbfeb7328c
    """

    def __init__(
        self,
        weights: Iterable[float],
        *,
        rng: "random.Random | None" = None,
    ):
        """Build the alias tables.

        Args:
            weights: Non-negative relative weights, one per index. Any
                iterable is accepted; it is consumed exactly once.
            rng: Optional ``random.Random`` instance used by :meth:`sample`.
                When omitted, the module-level ``random`` functions (and
                their global state) are used.

        Raises:
            ValueError: If ``weights`` is empty, contains a negative or NaN
                weight, or does not sum to a positive finite value.
        """
        weights = list(weights)
        if len(weights) == 0:
            raise ValueError('weights must be non-empty')
        for w in weights:
            # Written as not(>=) so that NaN (which compares False) is rejected too.
            if not (w >= 0.0):
                raise ValueError(f'weights must be non-negative, got {w!r}')
        total = sum(weights)
        if not (0.0 < total < float('inf')):
            raise ValueError('weights must have a positive, finite sum')

        mean = total / len(weights)
        # Rescale so the average weight is exactly 1.0: buckets at or below 1
        # need topping up ("small"), buckets above 1 have surplus ("large").
        scaled = [w / mean for w in weights]
        small: collections.deque[int] = collections.deque()
        large: collections.deque[int] = collections.deque()
        for i, w in enumerate(scaled):
            if w <= 1.0:
                small.append(i)
            else:
                large.append(i)

        # index[i] is the alias returned when bucket i's own coin flip fails.
        # It defaults to i, so a bucket left unpaired (threshold ~1.0 up to
        # float rounding) always yields itself.
        index = list(range(len(scaled)))
        while large and small:
            j = small.pop()
            k = large[-1]
            index[j] = k
            # Donor k fills the remaining (1 - scaled[j]) share of bucket j.
            scaled[k] -= 1.0 - scaled[j]
            if scaled[k] <= 1.0:
                small.append(k)
                large.pop()

        self._index: list[int] = index
        self._threshold: list[float] = scaled
        self._rng = rng

    def sample(self) -> int:
        """Return an index drawn with probability proportional to its weight."""
        rng = random if self._rng is None else self._rng
        r = rng.randrange(len(self._index))
        # Keep bucket r with probability threshold[r]; otherwise use its alias.
        if self._threshold[r] > rng.random():
            return r
        return self._index[r]
if __name__ == "__main__":
    # Demo: build a sampler over ten weights, dump its alias tables, then draw
    # many samples. Each printed value is the empirical frequency of index i
    # rescaled by sum(weights), so it should land close to weights[i].
    demo_weights = [3.0, 6.0, 9.0, 1.0, 2.0, 3.0, 7.0, 7.0, 4.0, 8.0]
    sampler = WalkerAlias(demo_weights)
    print(sampler._index)
    print(sampler._threshold)

    num_trials = 1_000_000
    step = 1.0 / num_trials
    frequencies = [0.0 for _ in demo_weights]
    for _ in range(num_trials):
        frequencies[sampler.sample()] += step

    n = sum(demo_weights)
    for i, r in enumerate(frequencies):
        print(f"{i}: {r * n}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment