Last active
March 31, 2024 00:31
A Lua implementation of the Alias Method, for sampling from an arbitrary distribution.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- Alias-method sampler for drawing indices from a discrete distribution.
-- Build a sampler with `alias_table:new(weights)`; calling the returned
-- object draws one index, with index i chosen in proportion to weights[i].
local alias_table = {}

--- Build the probability and alias tables for a list of weights.
-- @tparam table weights sequence of non-negative numbers; entry i is the
--   relative weight of index i
-- @treturn table callable sampler with fields `prob`, `alias` and `n`
-- @raise assertion error "all weights must be non-negative" if any weight
--   fails `v >= 0`, or "total weight must be positive" if the weights sum
--   to zero (which includes an empty list)
function alias_table:new(weights)
  local total = 0
  for _, v in ipairs(weights) do
    assert(v >= 0, "all weights must be non-negative")
    total = total + v
  end
  assert(total > 0, "total weight must be positive")

  -- Rescale so the weights average to 1: an entry below 1 is "small"
  -- (its column needs topping up), anything else is "big" (it can donate).
  local normalize = #weights / total
  local norm = {}
  local small_stack = {}
  local big_stack = {}
  for i, w in ipairs(weights) do
    norm[i] = w * normalize
    if norm[i] < 1 then
      table.insert(small_stack, i)
    else
      table.insert(big_stack, i)
    end
  end

  local prob = {}
  local alias = {}
  while small_stack[1] and big_stack[1] do -- both non-empty
    -- These must be locals: as bare assignments they would create (or
    -- clobber) globals named `small` and `large` on every call.
    local small = table.remove(small_stack)
    local large = table.remove(big_stack)
    -- Column `small` keeps its own share and is topped up to 1 by `large`.
    prob[small] = norm[small]
    alias[small] = large
    -- `large` gave away (1 - norm[small]); re-file it by what is left.
    norm[large] = norm[large] + norm[small] - 1
    if norm[large] < 1 then
      table.insert(small_stack, large)
    else
      table.insert(big_stack, large)
    end
  end

  -- Whatever remains on either stack gets probability 1 and no alias.
  -- Entries can be left on the small stack only through rounding error.
  for _, v in ipairs(big_stack) do prob[v] = 1 end
  for _, v in ipairs(small_stack) do prob[v] = 1 end

  -- Instances use `self` as metatable: field lookups fall back to it via
  -- `__index`, and its `__call` makes the instance callable.
  self.__index = self
  return setmetatable({alias = alias, prob = prob, n = #weights}, self)
end

--- Draw one index from the distribution.
-- Picks a column uniformly from 1..n, then keeps that column's own index
-- with probability `prob[index]` and otherwise returns its alias.
-- @treturn integer sampled index in 1..n
function alias_table:__call()
  local index = math.random(self.n)
  -- The `and/or` idiom is safe here: `index` is a positive integer, never
  -- nil/false. Columns with prob == 1 have no alias, but math.random() is
  -- always < 1, so the alias branch is never taken for them.
  return math.random() < self.prob[index] and index or self.alias[index]
end
-- Export the class table as this module's value (what `require` returns).
return alias_table
--[[ -- usage:
local alias_table = require "alias_table"
local sample = alias_table:new{10, 20, 15, 2, 2.3, 130} -- assign weights for 1, 2, 3, 4, 5, 6 etc.
-- Seed the generator, then discard the first few values (presumably to work
-- around weak initial output after seeding on some platforms).
math.randomseed(os.time()); math.random(); math.random(); math.random()
print(sample())
print(sample())
print(sample())
print(sample())
print(sample())
--]]
Public domain is fine
Hey, I have a problem. I know it's a stupid question, but I have created items and I just need help with the error "all weights must be non-negative".
Hi, a quick warning to people: I tested this code with 10 million iterations, and the results it produced were very far from the expected distribution.
Never mind, I'm an idiot. Thank you, lol.
I have tested this code with 10 million iterations, and can confirm it works on edge cases 👍
Thank you!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Can I use this under MIT or other open license?