Skip to content

Instantly share code, notes, and snippets.

@teryror
Last active August 22, 2021 09:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save teryror/3d52a64a7081257503dd0787a47c3f21 to your computer and use it in GitHub Desktop.
Save teryror/3d52a64a7081257503dd0787a47c3f21 to your computer and use it in GitHub Desktop.
Const evaluatable Rust implementation of Vose's Alias Method
/// Const evaluatable Rust implementation of Vose's Alias Method, as described
/// by Keith Schwarz at https://www.keithschwarz.com/darts-dice-coins/
///
/// In brief, this is an O(n) precomputation, which allows sampling an arbitrary
/// finite probability distribution in O(1) time, by first simulating a fair
/// n-sided die, followed by a biased coin.
///
/// Because floating point arithmetic cannot be used in const functions, this is
/// built to operate on integer weights, rather than precomputed probabilities.
///
/// Where the standard Alias Method scales the probabilities by a factor of n
/// and uses 1 as a cutoff to partition them into large and small probabilities,
/// this finds the least common multiple of n and the total weight, scales up
/// the weights to match it, and uses the LCM divided by n as the threshold.
///
/// Unlike the original method, this approach is perfectly exact and numerically
/// stable; I only switch to fixed point arithmetic for the final probability
/// calculation, which introduces negligible rounding errors.
///
/// The implementation could be made much more elegant as more language features
/// become available in const fns, most notably for loops, panics, and arguments
/// of mutable reference types.
use rand::{Rng, thread_rng};
use rand::distributions::Distribution;
/// Returns the greatest common divisor of `a` and `b`, computed with
/// Euclid's algorithm.
///
/// If either argument is zero the other one is returned, so `gcd(0, 0)` is 0.
const fn gcd(mut a: u32, mut b: u32) -> u32 {
    // Written as a `while` loop with a temporary because this has to stay
    // usable in const contexts.
    while b != 0 {
        let t = b;
        b = a % b;
        a = t;
    }
    a
}

/// Returns the least common multiple of `a` and `b`.
///
/// The division by the GCD happens *before* the multiplication, so the
/// intermediate value never exceeds the final result. Multiplying first would
/// overflow `u32` for inputs such as `lcm(65_536, 65_536)` even though the
/// answer itself fits comfortably.
///
/// If exactly one argument is zero the result is 0. If both are zero the GCD
/// is zero as well and the division fails (a panic at run time, an error
/// during const evaluation).
const fn lcm(a: u32, b: u32) -> u32 {
    // `gcd(a, b)` divides `a` exactly, so no precision is lost here.
    (a / gcd(a, b)) * b
}
/// Precomputed alias table over `N` outcomes, identified by their index.
///
/// Built by `AliasTable::new` from integer weights and consumed by the
/// `Distribution<usize>` implementation, which picks a slot and then decides
/// between the slot itself and its alias.
pub struct AliasTable<const N: usize> {
/// Per-slot cutoff: `sample` returns the slot's own index when a freshly
/// drawn random `u32` is strictly below this value.
prob: [u32; N],
/// Per-slot fallback: the index `sample` returns when the random `u32` is
/// not below the slot's `prob` value.
alias: [usize; N],
}
impl<const N: usize> AliasTable<N> {
    /// Builds the alias table for `N` outcomes from their integer `weights`.
    ///
    /// Every weight must be non-zero; a zero weight causes a division by zero,
    /// which fails const evaluation (or panics at run time).
    ///
    /// The weights are scaled so that their sum becomes the least common
    /// multiple of the original sum and `N`. A slot is "small" when its scaled
    /// weight is below `lcm / N` and "large" otherwise.
    pub const fn new(mut weights: [u32; N]) -> Self {
        let mut prob = [0; N];
        let mut alias = [0; N];

        // Pass 1: reject zero weights and add everything up.
        let mut total_weight = 0;
        let mut idx = 0;
        while idx < N {
            // TODO(const_panic): assert_ne!(weights[idx], 0, "Weight at position {} is zero!", idx);
            let _ = 1 / weights[idx];
            total_weight += weights[idx];
            idx += 1;
        }

        let slots = N as u32;
        let rescaled_total = lcm(total_weight, slots);
        let weight_factor = rescaled_total / total_weight;
        let weight_threshold = rescaled_total / slots;

        // Vec and friends are unavailable in const fns, so the two work lists
        // are plain arrays used as stacks, each with its own element count.
        // Capacity N - 1 would suffice, but const generics cannot express it.
        let mut small = [0; N];
        let mut small_count = 0;
        let mut large = [0; N];
        let mut large_count = 0;

        // Pass 2: scale each weight and push its index onto the matching stack.
        let mut idx = 0;
        while idx < N {
            weights[idx] *= weight_factor;
            if weights[idx] >= weight_threshold {
                large[large_count] = idx;
                large_count += 1;
            } else {
                small[small_count] = idx;
                small_count += 1;
            }
            idx += 1;
        }

        // Pair the top small slot with the top large slot until one stack
        // runs dry.
        while small_count > 0 && large_count > 0 {
            let l = small[small_count - 1];
            let g = large[large_count - 1];

            // Fixed point fraction weights[l] / weight_threshold, scaled by 2^32.
            let numerator = (weights[l] as u64) << 32;
            prob[l] = (numerator / (weight_threshold as u64)) as u32;
            alias[l] = g;

            // The large slot donates whatever the small slot was missing.
            weights[g] -= weight_threshold - weights[l];

            if weights[g] < weight_threshold {
                // `g` turned small: it leaves the large stack and takes over
                // the position that `l` just vacated on the small stack.
                small[small_count - 1] = g;
                large_count -= 1;
            } else {
                // `g` is still large and stays where it is; only `l` is done.
                small_count -= 1;
            }
        }

        // Whatever is left on the large stack always resolves to itself.
        while large_count > 0 {
            large_count -= 1;
            let g = large[large_count];
            prob[g] = u32::MAX;
            alias[g] = g;
        }

        // TODO(const_panic): assert_eq!(small_count, 0);
        // Leftover small slots should only be possible with floating point
        // arithmetic due to numerical instability; this loop is a safety net.
        while small_count > 0 {
            small_count -= 1;
            let l = small[small_count];
            prob[l] = u32::MAX;
            alias[l] = l;
        }

        Self { prob, alias }
    }
}
impl<const N: usize> Distribution<usize> for AliasTable<N> {
    /// Draws an index in `0..N`.
    ///
    /// First a slot is chosen uniformly, then a random `u32` is compared with
    /// that slot's cutoff to decide between the slot and its alias.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // The die roll must come before the coin flip to keep the RNG stream
        // consumption order unchanged.
        let slot = rng.gen_range(0..N);
        let coin = rng.gen::<u32>();

        if coin >= self.prob[slot] {
            return self.alias[slot];
        }
        slot
    }
}
/// A fixed set of `N` items paired with an `AliasTable` over their indices.
///
/// Sampling draws an index from `distr` and returns a clone of the item at
/// that index.
pub struct PopulationTable<T, const N: usize> {
/// The items that can be drawn, indexed by the values `distr` produces.
items: [T; N],
/// Alias table used to pick an index into `items`.
distr: AliasTable<N>,
}
impl<T, const N: usize> PopulationTable<T, N> {
    /// Creates a table in which `items[i]` is associated with `weights[i]`.
    ///
    /// The weights are handed to `AliasTable::new` unchanged, so its
    /// requirements on them apply here as well.
    pub const fn new(items: [T; N], weights: [u32; N]) -> Self {
        Self {
            items,
            distr: AliasTable::new(weights),
        }
    }
}
impl<T: Clone, const N: usize> Distribution<T> for PopulationTable<T, N> {
    /// Draws an index from the underlying alias table and returns a clone of
    /// the item stored at that index.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
        let picked = self.distr.sample(rng);
        self.items[picked].clone()
    }
}
/// Declares a `const` `PopulationTable` from a list of `weight => item` pairs.
///
/// Accepted form: an optional visibility, the constant's name, a colon, the
/// item type, `=`, and a bracketed list of one or more `weight => item`
/// entries separated by commas (a trailing comma is allowed).
///
/// The table length is taken from the number of weights, so it never has to
/// be written out by hand.
// TODO(macro_metavar_expr): The expansion of this macro will contain the weight array
// twice to automatically determine its length, which is literally redundant work.
macro_rules! population_table {
($v:vis $name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ $(,)? ] ) => {
// The weight array appears once in the type (only for its `.len()`) and
// once as the actual argument to `PopulationTable::new`.
$v const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new(
[$($item),*], [$($weight),*]
);
}
}
// Example table: three names with weights 2 : 1 : 3, built entirely at
// compile time. `main` below checks the resulting alias table.
population_table! {
NAME_TABLE: &'static str = [
2 => "Alice",
1 => "Bob",
3 => "Charlie",
]
}
/// Verifies the precomputed example table, then greets a randomly drawn name.
pub fn main() {
    // Expected contents of the alias table for the weights 2 : 1 : 3.
    let table = &NAME_TABLE.distr;
    assert_eq!(table.alias, [0, 2, 2]);
    assert_eq!(table.prob, [u32::MAX, 1 << 31, u32::MAX]);

    let name = thread_rng().sample(&NAME_TABLE);
    println!("Hello, {}!", name);
}
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `gcd` against hand-computed `(a, b, expected)` triples.
    #[test]
    fn greatest_common_divisor() {
        let cases = [(2, 4, 2), (2, 5, 1), (252, 105, 21)];
        for (a, b, expected) in cases.iter().copied() {
            assert_eq!(gcd(a, b), expected);
        }
    }

    /// Checks `lcm` against hand-computed `(a, b, expected)` triples.
    #[test]
    fn least_common_multiple() {
        let cases = [(2, 4, 4), (2, 5, 10), (18, 12, 36)];
        for (a, b, expected) in cases.iter().copied() {
            assert_eq!(lcm(a, b), expected);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment