Last active
December 9, 2023 12:34
-
-
Save inutamago-dogegg/80dcc6809c0cf98ec349bf03b2da6cc2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
/// <summary>
/// Weighted random index sampler based on the alias method.
/// Call <see cref="Constructor"/> once with the weights to build the tables,
/// then call <see cref="Roll"/> as often as needed; every roll costs one
/// integer draw and one floating-point draw.
/// </summary>
public class AliasMethod {
    //https://en.wikipedia.org/wiki/Alias_method
    //https://stackoverflow.com/a/39199014
    //https://www.keithschwarz.com/darts-dice-coins/

    // A single generator per instance. Creating a new Random for each draw is
    // wasteful, and on runtimes where the parameterless constructor is
    // clock-seeded (e.g. .NET Framework) two generators created back-to-back
    // yield the same sequence, which would correlate the column pick with the
    // coin flip in Roll().
    private readonly Random _random = new Random();

    // Number of columns in the alias table (equals the length of the weights array).
    private int _n;

    // _probabilities[i]: chance that column i yields i itself rather than its alias.
    private float[] _probabilities;

    // _alias[i]: index returned when column i is picked but its coin flip fails.
    private int[] _alias;

    /// <summary>
    /// Creates a sampler with the zero threshold disabled.
    /// </summary>
    public AliasMethod() { }

    /// <summary>
    /// Creates a sampler with the given zero threshold.
    /// </summary>
    /// <param name="zeroThreshold">Value in [0, 1]; see <see cref="ZeroThreshold"/>.</param>
    /// <exception cref="ArgumentOutOfRangeException">The value is NaN or outside [0, 1].</exception>
    public AliasMethod(float zeroThreshold) {
        SetZeroThreshold(zeroThreshold);
    }

    /// <summary>
    /// Entries whose probability (weight divided by the sum of all weights) is
    /// at or below this value are treated as having probability 0.
    /// A value of 0 disables the filtering.
    /// </summary>
    public float ZeroThreshold { get; private set; } = 0f;

    /// <summary>
    /// True when <see cref="ZeroThreshold"/> is greater than 0, i.e. filtering is active.
    /// </summary>
    public bool ApplyZeroThreshold => ZeroThreshold > 0f;

    /// <summary>
    /// Sets the threshold for treating a probability as 0.
    /// Only affects tables built by later calls to <see cref="Constructor"/>.
    /// </summary>
    /// <param name="zeroThreshold">Value in [0, 1].</param>
    /// <exception cref="ArgumentOutOfRangeException">The value is NaN or outside [0, 1].</exception>
    public void SetZeroThreshold(float zeroThreshold) {
        // NaN fails every relational pattern, so it has to be rejected explicitly.
        if (float.IsNaN(zeroThreshold) || (zeroThreshold is < 0f or > 1f)) {
            throw new ArgumentOutOfRangeException(nameof(zeroThreshold));
        }
        ZeroThreshold = zeroThreshold;
    }

    /// <summary>
    /// Builds the probability and alias tables from the given weights.
    /// Indices returned by <see cref="Roll"/> always refer to positions in
    /// <paramref name="weights"/>; entries removed by the zero threshold keep
    /// their slot but are never returned.
    /// If this method throws, previously built tables are left untouched.
    /// </summary>
    /// <param name="weights">Non-negative, finite weights; they need not sum to 1.</param>
    /// <exception cref="ArgumentNullException"><paramref name="weights"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="weights"/> is empty, contains a negative, NaN or infinite
    /// value, does not have a positive finite sum, or no entry survives the zero threshold.
    /// </exception>
    public void Constructor(float[] weights) {
        if (weights == null) {
            throw new ArgumentNullException(nameof(weights));
        }

        var n = weights.Length;
        if (n == 0) {
            throw new ArgumentException("At least one weight is required.", nameof(weights));
        }

        var total = 0f;
        foreach (var w in weights) {
            if (float.IsNaN(w) || float.IsInfinity(w) || w < 0f) {
                throw new ArgumentException("Weights must be finite and non-negative.", nameof(weights));
            }
            total += w;
        }
        if (total <= 0f || float.IsInfinity(total)) {
            throw new ArgumentException("The weights must have a positive, finite sum.", nameof(weights));
        }

        // Apply the zero threshold by zeroing entries in place instead of
        // removing them, so that table indices keep matching the caller's array.
        var useThreshold = ApplyZeroThreshold;
        var kept = new float[n];
        var keptTotal = 0f;
        for (var i = 0; i < n; i++) {
            var w = weights[i];
            if (useThreshold && !(w / total > ZeroThreshold)) {
                continue; // kept[i] stays 0
            }
            kept[i] = w;
            keptTotal += w;
        }
        if (keptTotal <= 0f) {
            throw new ArgumentException("No weight has a probability above the zero threshold.", nameof(weights));
        }

        // Scale so that the average column height is exactly 1.
        var scaled = new float[n];
        var small = new Queue<int>();
        var large = new Queue<int>();
        for (var i = 0; i < n; i++) {
            scaled[i] = kept[i] / keptTotal * n;
            if (scaled[i] < 1) {
                small.Enqueue(i);
            }
            else {
                large.Enqueue(i);
            }
        }

        // Freshly allocated arrays are already zero-filled.
        var prob = new float[n];
        var alias = new int[n];

        // Pair each under-full column with an over-full one: the small column
        // keeps its own mass and borrows the remainder from the large column.
        while (small.Count > 0 && large.Count > 0) {
            var l = small.Dequeue();
            var g = large.Dequeue();
            prob[l] = scaled[l];
            alias[l] = g;
            scaled[g] = scaled[g] + scaled[l] - 1;
            if (scaled[g] < 1) {
                small.Enqueue(g);
            }
            else {
                large.Enqueue(g);
            }
        }

        // Whatever remains has height 1 up to rounding error, so it never needs an alias.
        while (large.Count > 0) {
            prob[large.Dequeue()] = 1;
        }
        while (small.Count > 0) {
            prob[small.Dequeue()] = 1;
        }

        _n = n;
        _probabilities = prob;
        _alias = alias;
    }

    /// <summary>
    /// Draws one random index according to the distribution built by <see cref="Constructor"/>.
    /// </summary>
    /// <returns>An index into the weights array passed to <see cref="Constructor"/>.</returns>
    /// <exception cref="InvalidOperationException"><see cref="Constructor"/> has not been called successfully yet.</exception>
    public int Roll() {
        if (_probabilities == null) {
            throw new InvalidOperationException("Constructor(weights) must be called before Roll().");
        }

        // Pick a column uniformly, then flip its biased coin.
        var column = _random.Next(_n);
        return _random.NextDouble() < _probabilities[column] ? column : _alias[column];
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment