Binomial and Multinomial Generators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.peterchng.binomial; | |
import java.util.Random; | |
/**
 * Get samples from a binomial distribution by simulating every trial.
 *
 * Each call to {@link #sample(int, double)} performs up to {@code n} draws
 * from the supplied {@link Random}, so the cost is linear in {@code n}.
 *
 * NOTE: Do not use this code in production unless you have thoroughly tested
 * it and are comfortable that you know what it does!
 */
public class BinomialSampler {
    private final Random _random;

    /**
     * Creates a sampler backed by the given random source.
     *
     * @param random source of uniformly-distributed doubles; must not be null
     * @throws NullPointerException if {@code random} is null
     */
    public BinomialSampler(final Random random) {
        // Fail here rather than on the first trial, where the cause is harder to trace.
        if (random == null) {
            throw new NullPointerException("random");
        }
        _random = random;
    }

    /**
     * Returns a sample from a Binomial Distribution with parameters
     * n (number of trials) and p (probability of success of an individual
     * trial).
     *
     * The sample is the number of successes observed.
     *
     * @param n number of trials; values {@code <= 0} yield 0
     * @param p success probability; values {@code <= 0} (and NaN) yield 0,
     *          values {@code >= 1} yield {@code n}
     * @return the number of successful trials, between 0 and {@code n}
     */
    public int sample(final int n, final double p) {
        // `!(p > 0)` is true for p <= 0 and also for NaN. A NaN p can never win a
        // trial, so answer 0 directly instead of spending n random draws on it.
        if (n <= 0 || !(p > 0)) {
            return 0;
        }
        if (p >= 1) {
            return n;
        }

        int successes = 0;
        for (int i = 0; i < n; i++) {
            if (binomialTrial(p)) {
                successes++;
            }
        }
        return successes;
    }

    /**
     * Performs a single trial that succeeds with probability {@code p}.
     *
     * No validation is applied: the result is simply whether a uniform draw
     * from [0.0, 1.0) is strictly below {@code p}.
     *
     * @param p success probability of the trial
     * @return {@code true} if the trial succeeded
     */
    public boolean binomialTrial(final double p) {
        // nextDouble() returns a uniformly-distributed value between 0.0 inclusive
        // and 1.0 exclusive.
        return _random.nextDouble() < p;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.peterchng; | |
import com.peterchng.binomial.BinomialSampler; | |
import com.peterchng.multinomial.MultinomialRecurrenceSampler; | |
import com.peterchng.multinomial.SimpleMultinomialSampler; | |
import org.openjdk.jmh.annotations.*; | |
import java.util.concurrent.ThreadLocalRandom; | |
import java.util.concurrent.TimeUnit; | |
@BenchmarkMode(Mode.AverageTime) | |
@Fork(value = 1, warmups = 5) | |
@OutputTimeUnit(TimeUnit.MILLISECONDS) | |
public class MultinomialBenchmark { | |
private static final int NUM_TRIALS = 1_000_000; | |
@Benchmark | |
public int[] testMultinomialRecurrenceSampler(final MultinomialRecurrenceSamplerState state) { | |
// NOTE: There's no Blackhole for an int[] | |
return state.sampler.multinomialSample(NUM_TRIALS, state.probabilities); | |
} | |
@State(Scope.Thread) | |
public static class MultinomialRecurrenceSamplerState { | |
// NOTE: Just using a constant set of probabilities. | |
public MultinomialRecurrenceSampler sampler = new MultinomialRecurrenceSampler(new BinomialSampler(ThreadLocalRandom.current())); | |
public double[] probabilities = {0.1, 0.2, 0.3, 0.4}; | |
} | |
@Benchmark | |
public int[] testSimpleMultinomialSampler(final SimpleMultinomialSamplerState state) { | |
// NOTE: There's no Blackhole for an int[] | |
return state.sampler.multinomialSample(NUM_TRIALS); | |
} | |
@State(Scope.Thread) | |
public static class SimpleMultinomialSamplerState { | |
// NOTE: Just using a constant set of probabilities. | |
public SimpleMultinomialSampler sampler = new SimpleMultinomialSampler(ThreadLocalRandom.current(), new double[]{0.1, 0.2, 0.3, 0.4}); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.peterchng.multinomial; | |
import com.peterchng.binomial.BinomialSampler; | |
/** | |
* Generates samples from a multinomial distribution by sampling each component | |
* from an appropriate binomial distribution. | |
* | |
* Based on "Non-Uniform Random Variate Generation" by Devroye (1986). See: | |
* - Chapter 11, Page 558: http://www.nrbook.com/devroye/Devroye_files/chapter_eleven.pdf | |
* - Errata: Page 559: http://www.nrbook.com/devroye/Devroye_files/errors.pdf | |
*/ | |
public class MultinomialRecurrenceSampler { | |
private final BinomialSampler _binomialSampler; | |
public MultinomialRecurrenceSampler(final BinomialSampler binomialSampler) { | |
_binomialSampler = binomialSampler; | |
} | |
/** | |
* Returns the result of a multinomial sample of n trials. | |
* The value of k (the number of possible outcomes) is determined by the | |
* number of probabilities passed in. | |
* <p> | |
* This assumes that `probabilities` sums to 1.0. | |
*/ | |
public int[] multinomialSample(final int n, final double[] probabilities) { | |
// Basic error checking: | |
if (n <= 0 || probabilities == null || probabilities.length == 0) { | |
return new int[0]; | |
} | |
final int[] result = new int[probabilities.length]; | |
double remainingSum = 1.0; | |
int remainingTrials = n; | |
for (int i = 0; i < probabilities.length; i++) { | |
// `_binomialSampler` returns a sample from a binomial distribution (n, p) | |
result[i] = _binomialSampler.sample(remainingTrials, probabilities[i] / remainingSum); | |
remainingSum -= probabilities[i]; | |
remainingTrials -= result[i]; | |
// Due to floating-point, even if the probabilities appear to sum to 1.0 the remainingSum | |
// may go ever so slightly negative on the last iteration. In this case, remainingTrials will | |
// also be 0. (But remainingTrials may go to 0 before the last possible iteration) | |
if (remainingSum <= 0 || remainingTrials == 0) { | |
break; | |
} | |
} | |
return result; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.peterchng.multinomial; | |
import java.util.Random; | |
/**
 * Get samples from a multinomial distribution.
 *
 * Based on: https://github.com/tedunderwood/LDA/blob/master/Multinomial.java
 *
 * NOTE: Do not use this code in production unless you have thoroughly tested
 * it and are comfortable that you know what it does!
 */
public class SimpleMultinomialSampler {
    private final Random _random;
    /** Cumulative, normalized probabilities; the last element is always 1.0. */
    private final double[] _probabilities;

    /**
     * Creates a sampler over the given outcome weights.
     *
     * NOTE: Ideally should pass in a ThreadLocalRandom.
     *
     * @param random        source of uniformly-distributed doubles
     * @param probabilities raw weights of each outcome; they are normalized, so
     *                      they need not sum to 1.0. The array is not modified.
     * @throws IllegalArgumentException if {@code probabilities} is rejected by
     *         {@link #normalizeProbabilitiesToCumulative(double[])}
     */
    public SimpleMultinomialSampler(final Random random, final double[] probabilities) {
        _random = random;
        // Normalize all the passed in "probabilities" so that they sum to 1.0.
        _probabilities = normalizeProbabilitiesToCumulative(probabilities);
    }

    /**
     * Returns the result of a multinomial sample of n trials.
     * The value of k (the number of possible outcomes) is determined by the
     * number of probabilities passed into the constructor.
     *
     * @param n number of trials; if {@code n <= 0} every count is zero
     * @return an array of length k holding the number of trials per outcome
     */
    public int[] multinomialSample(final int n) {
        // The length of `_probabilities` is the number of possible outcomes.
        final int[] result = new int[_probabilities.length];
        // Get the result of each trial and increment the count for that outcome.
        for (int i = 0; i < n; i++) {
            result[multinomialTrial()]++;
        }
        return result;
    }

    /**
     * Performs a single trial and returns the index of the outcome drawn.
     *
     * The `_probabilities` field is an array of "cumulative" probabilities.
     * The first element has the value p1, the second has p1 + p2, the third has p1 + p2 + p3, etc.
     * By definition, the last bin should have a value of 1.0.
     *
     * @return index of the selected outcome, in {@code [0, k)}
     */
    public int multinomialTrial() {
        final double sample = _random.nextDouble(); // Between [0, 1)
        for (int i = 0; i < _probabilities.length; ++i) {
            // Find the first bucket whose upper bound is above the sampled value.
            if (sample < _probabilities[i]) {
                return i;
            }
        }
        // Catch-all return statement to ensure code compiles.
        return _probabilities.length - 1;
    }

    /**
     * Given an array of raw probabilities, this returns a NEW array (the input is left untouched)
     * built in the following manner:
     * 1. The sum of the values will be computed.
     * 2. Each value will be divided by the sum, normalizing them so that the sum is roughly 1.0.
     * 3. The values will be converted into cumulative values.
     *
     * Example: Input is: [0.5, 0.5, 1.0]
     * 1. Sum is 2.0
     * 2. After normalization: [0.25, 0.25, 0.5]
     * 3. After converting to cumulative values: [0.25, 0.5, 1.0]
     *
     * The form in (3) is useful for converting a uniformly-sampled value between [0, 1) into a multinomial sample,
     * because the values now represent the upper-bounds of the "range" between [0, 1) that represents the
     * probability of that outcome. Thus, given a uniformly-sampled value between [0, 1), we just need to find the
     * first/lowest bin whose upper-bound is more than the sampled value.
     *
     * @param probabilities raw, non-negative weights; at least two values
     * @return cumulative normalized probabilities, same length as the input, last element exactly 1.0
     * @throws IllegalArgumentException if the array is null, has fewer than two values, contains a
     *         negative, NaN or infinite value, or does not have a positive, finite sum
     */
    public static double[] normalizeProbabilitiesToCumulative(final double[] probabilities) {
        if (probabilities == null || probabilities.length < 2) {
            throw new IllegalArgumentException("probabilities must have more than one value");
        }

        double sum = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            final double value = probabilities[i];
            // `!(value >= 0.0)` also rejects NaN. Any of these values would otherwise
            // produce a NaN or non-monotonic cumulative table and silently skew sampling.
            if (!(value >= 0.0) || Double.isInfinite(value)) {
                throw new IllegalArgumentException(
                        "probabilities[" + i + "] must be a finite, non-negative number but was " + value);
            }
            sum += value;
        }
        // An all-zero input (or an overflowing sum) cannot be normalized.
        if (!(sum > 0.0) || Double.isInfinite(sum)) {
            throw new IllegalArgumentException(
                    "probabilities must have a positive, finite sum but was " + sum);
        }

        double cumulative = 0.0;
        final double[] distribution = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            cumulative += probabilities[i];
            distribution[i] = cumulative / sum;
        }
        // To ensure the right-most bin is always 1.0.
        distribution[distribution.length - 1] = 1.0;
        return distribution;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment