Skip to content

Instantly share code, notes, and snippets.

@pchng
Last active October 24, 2020 15:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pchng/f972831980278d7ba63b6c8f1bcae0e1 to your computer and use it in GitHub Desktop.
Save pchng/f972831980278d7ba63b6c8f1bcae0e1 to your computer and use it in GitHub Desktop.
Binomial and Multinomial Generators
package com.peterchng.binomial;
import java.util.Random;
/**
 * Draws samples from a binomial distribution by simulating each trial
 * individually.
 *
 * NOTE: Do not use this code in production unless you have thoroughly tested
 * it and are comfortable that you know what it does!
 */
public class BinomialSampler {

    private final Random _random;

    /**
     * Creates a sampler backed by the given source of randomness.
     *
     * @param random generator that is queried once per simulated trial
     */
    public BinomialSampler(final Random random) {
        _random = random;
    }

    /**
     * Draws one sample from a Binomial(n, p) distribution, i.e. the number of
     * successes observed across {@code n} independent trials.
     *
     * @param n number of trials; values of zero or less yield 0
     * @param p probability that a single trial succeeds; values of 1 or more
     *          yield {@code n}, values of 0 or less yield 0
     * @return the number of successful trials
     */
    public int sample(final int n, final double p) {
        // Degenerate parameters have a fixed answer, so no randomness is consumed for them.
        if (n <= 0 || p <= 0) {
            return 0;
        }
        if (p >= 1) {
            return n;
        }

        // Simulate the trials one at a time and tally the successes.
        int hits = 0;
        for (int pending = n; pending > 0; pending--) {
            hits += binomialTrial(p) ? 1 : 0;
        }
        return hits;
    }

    /**
     * Runs a single success/failure trial.
     *
     * NOTE: {@code p} is not validated here.
     *
     * @param p probability of success
     * @return {@code true} when the trial succeeded
     */
    public boolean binomialTrial(final double p) {
        // nextDouble() is uniform over [0.0, 1.0), so the draw lands below p
        // with probability p.
        return _random.nextDouble() < p;
    }
}
package com.peterchng;
import com.peterchng.binomial.BinomialSampler;
import com.peterchng.multinomial.MultinomialRecurrenceSampler;
import com.peterchng.multinomial.SimpleMultinomialSampler;
import org.openjdk.jmh.annotations.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
/**
 * JMH benchmark comparing the two multinomial samplers on the same workload:
 * {@link #NUM_TRIALS} trials over a fixed four-outcome distribution.
 */
// NOTE(review): per the JMH @Fork Javadoc, `warmups` is the number of forks whose
// results are discarded, not the number of warm-up iterations inside a fork.
// Confirm that five throw-away forks is what was intended here.
@BenchmarkMode(Mode.AverageTime)
@Fork(value = 1, warmups = 5)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public class MultinomialBenchmark {

    /** Number of trials drawn per benchmark invocation. */
    private static final int NUM_TRIALS = 1_000_000;

    /** Per-thread fixture for the recurrence (binomial-based) sampler. */
    @State(Scope.Thread)
    public static class MultinomialRecurrenceSamplerState {
        // NOTE: Just using a constant set of probabilities.
        public MultinomialRecurrenceSampler sampler =
                new MultinomialRecurrenceSampler(
                        new BinomialSampler(ThreadLocalRandom.current()));

        public double[] probabilities = {0.1, 0.2, 0.3, 0.4};
    }

    /** Per-thread fixture for the trial-by-trial sampler. */
    @State(Scope.Thread)
    public static class SimpleMultinomialSamplerState {
        // NOTE: Just using a constant set of probabilities.
        public SimpleMultinomialSampler sampler =
                new SimpleMultinomialSampler(
                        ThreadLocalRandom.current(),
                        new double[]{0.1, 0.2, 0.3, 0.4});
    }

    /**
     * Measures one multinomial draw using the recurrence sampler.
     *
     * @param state per-thread sampler and probabilities
     * @return the sampled counts
     */
    @Benchmark
    public int[] testMultinomialRecurrenceSampler(final MultinomialRecurrenceSamplerState state) {
        // The counts are returned so that JMH consumes them; this keeps the
        // sampling work from being optimised away as dead code.
        return state.sampler.multinomialSample(NUM_TRIALS, state.probabilities);
    }

    /**
     * Measures one multinomial draw using the trial-by-trial sampler.
     *
     * @param state per-thread sampler (probabilities are fixed at construction)
     * @return the sampled counts
     */
    @Benchmark
    public int[] testSimpleMultinomialSampler(final SimpleMultinomialSamplerState state) {
        // The counts are returned so that JMH consumes them; this keeps the
        // sampling work from being optimised away as dead code.
        return state.sampler.multinomialSample(NUM_TRIALS);
    }
}
package com.peterchng.multinomial;
import com.peterchng.binomial.BinomialSampler;
/**
 * Generates samples from a multinomial distribution by sampling each component
 * from an appropriate (conditional) binomial distribution.
 *
 * Based on "Non-Uniform Random Variate Generation" by Devroye (1986). See:
 * - Chapter 11, Page 558: http://www.nrbook.com/devroye/Devroye_files/chapter_eleven.pdf
 * - Errata: Page 559: http://www.nrbook.com/devroye/Devroye_files/errors.pdf
 */
public class MultinomialRecurrenceSampler {

    private final BinomialSampler _binomialSampler;

    /**
     * Creates a sampler that delegates every per-category draw to the given
     * binomial sampler.
     *
     * @param binomialSampler source of binomial samples
     */
    public MultinomialRecurrenceSampler(final BinomialSampler binomialSampler) {
        _binomialSampler = binomialSampler;
    }

    /**
     * Returns the result of a multinomial sample of n trials.
     * The value of k (the number of possible outcomes) is determined by the
     * number of probabilities passed in.
     * <p>
     * This assumes that {@code probabilities} sums to 1.0. Under that
     * assumption the last category's conditional probability is exactly 1, so
     * it is not sampled: it simply receives every trial that the earlier
     * categories did not claim. Sampling it instead would be exposed to
     * floating-point drift in the running remainder (the ratio can land just
     * below 1.0), which could make the counts sum to less than n, and it would
     * spend random draws on a value that is already determined.
     *
     * @param n             number of trials
     * @param probabilities probability of each outcome; assumed to sum to 1.0
     * @return an array of length {@code probabilities.length} holding the count
     *         for each outcome, or an empty array when {@code n <= 0} or
     *         {@code probabilities} is null or empty
     */
    public int[] multinomialSample(final int n, final double[] probabilities) {
        // Basic error checking:
        if (n <= 0 || probabilities == null || probabilities.length == 0) {
            return new int[0];
        }

        final int[] result = new int[probabilities.length];
        final int last = probabilities.length - 1;
        double remainingSum = 1.0;
        int remainingTrials = n;

        // Every category except the last is drawn from a binomial over the trials still
        // unassigned, using its probability conditioned on the mass not yet consumed.
        for (int i = 0; i < last; i++) {
            result[i] = _binomialSampler.sample(remainingTrials, probabilities[i] / remainingSum);
            remainingSum -= probabilities[i];
            remainingTrials -= result[i];

            // Stop early once nothing is left to distribute. The remaining mass can only reach
            // zero (or dip below it through rounding) right after a draw whose conditional
            // probability was >= 1; this assumes the binomial sampler hands back all trials in
            // that case, so remainingTrials is expected to be 0 here as well.
            if (remainingSum <= 0 || remainingTrials == 0) {
                break;
            }
        }

        // The final category absorbs whatever is left, so the counts add up to n as long as
        // each binomial draw returned a value between 0 and the trials it was given.
        result[last] = remainingTrials;
        return result;
    }
}
package com.peterchng.multinomial;
import java.util.Random;
/**
 * Get samples from a multinomial distribution by simulating every trial
 * against a table of cumulative probabilities.
 *
 * Based on: https://github.com/tedunderwood/LDA/blob/master/Multinomial.java
 *
 * NOTE: Do not use this code in production unless you have thoroughly tested
 * it and are comfortable that you know what it does!
 */
public class SimpleMultinomialSampler {

    private final Random _random;

    // Cumulative, normalized probabilities: entry i is the upper bound of outcome i's slice of
    // [0, 1). The last entry is always exactly 1.0.
    private final double[] _probabilities;

    /**
     * Creates a sampler over the given outcome weights.
     *
     * NOTE: Ideally should pass in a ThreadLocalRandom.
     *
     * @param random        source of uniform random values
     * @param probabilities non-negative weights, one per outcome; they are
     *                      normalized so they need not sum to 1.0, and the
     *                      array itself is not modified or retained
     * @throws IllegalArgumentException if {@code probabilities} is null, has
     *         fewer than two entries, contains a negative or non-finite value,
     *         or does not have a positive, finite sum
     */
    public SimpleMultinomialSampler(final Random random, final double[] probabilities) {
        _random = random;
        // Normalize all the passed in "probabilities" so that they sum to 1.0.
        _probabilities = normalizeProbabilitiesToCumulative(probabilities);
    }

    /**
     * Returns the result of a multinomial sample of n trials.
     * The value of k (the number of possible outcomes) is determined by the
     * number of probabilities passed into the constructor.
     *
     * @param n number of trials; zero or less yields all-zero counts
     * @return an array of length k holding the count for each outcome
     */
    public int[] multinomialSample(final int n) {
        // The length of `_probabilities` is the number of possible outcomes.
        final int[] result = new int[_probabilities.length];
        // Get the result of each trial and increment the count for that outcome.
        for (int i = 0; i < n; i++) {
            result[multinomialTrial()]++;
        }
        return result;
    }

    /**
     * Runs a single trial and returns the index of the outcome it produced.
     *
     * The `_probabilities` field is an array of "cumulative" probabilities.
     * The first element has the value p1, the second has p1 + p2, the third has p1 + p2 + p3, etc.
     * The last bin always has a value of 1.0.
     *
     * @return index of the sampled outcome, in {@code [0, k)}
     */
    public int multinomialTrial() {
        final double sample = _random.nextDouble(); // Between [0, 1)
        for (int i = 0; i < _probabilities.length; ++i) {
            // Find the first bucket whose upper bound is above the sampled value.
            if (sample < _probabilities[i]) {
                return i;
            }
        }
        // Not expected to be reached: the last bound is 1.0 and the sample is below 1.0.
        // Present so that the method compiles.
        return _probabilities.length - 1;
    }

    /**
     * Given an array of raw, non-negative weights, this returns a NEW array
     * (the input is left untouched) built in the following manner:
     * 1. The sum of the values is computed.
     * 2. Each value is divided by the sum, normalizing them so that the sum is roughly 1.0.
     * 3. The values are converted into cumulative values.
     *
     * Example: Input is: [0.5, 0.5, 1.0]
     * 1. Sum is 2.0
     * 2. After normalization: [0.25, 0.25, 0.5]
     * 3. After converting to cumulative values: [0.25, 0.5, 1.0]
     *
     * The form in (3) is useful for converting a uniformly-sampled value between [0, 1) into a
     * multinomial sample, because the values now represent the upper bounds of the sub-ranges of
     * [0, 1) that represent the probability of each outcome. Thus, given a uniformly-sampled
     * value between [0, 1), we just need to find the first/lowest bin whose upper bound is more
     * than the sampled value.
     *
     * @param probabilities raw weights, at least two of them
     * @return cumulative normalized probabilities; the last entry is exactly 1.0
     * @throws IllegalArgumentException if {@code probabilities} is null, has
     *         fewer than two entries, contains a negative or non-finite value,
     *         or does not have a positive, finite sum
     */
    public static double[] normalizeProbabilitiesToCumulative(final double[] probabilities) {
        if (probabilities == null || probabilities.length < 2) {
            throw new IllegalArgumentException("probabilities must have more than one value");
        }

        double sum = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            final double value = probabilities[i];
            // A negative, NaN or infinite weight would yield cumulative bounds that are NaN or
            // not monotonically increasing, silently skewing every sample drawn afterwards.
            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0) {
                throw new IllegalArgumentException(
                        "probabilities must be finite and non-negative, but index " + i
                                + " was " + value);
            }
            sum += value;
        }
        // All-zero weights would divide by zero below; a sum that overflowed to infinity would
        // collapse every bound to zero.
        if (sum <= 0.0 || Double.isInfinite(sum)) {
            throw new IllegalArgumentException(
                    "probabilities must have a positive, finite sum, but it was " + sum);
        }

        double cumulative = 0.0;
        final double[] distribution = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            cumulative += probabilities[i];
            distribution[i] = cumulative / sum;
        }
        // To ensure the right-most bin is always 1.0, regardless of rounding in the division.
        distribution[distribution.length - 1] = 1.0;
        return distribution;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment