Skip to content

Instantly share code, notes, and snippets.

@sorokod
Last active August 31, 2016 15:11
Show Gist options
  • Save sorokod/a53e4781c0586477523a to your computer and use it in GitHub Desktop.
Save sorokod/a53e4781c0586477523a to your computer and use it in GitHub Desktop.
Sample a population of T according to a provided distribution
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import static java.util.Arrays.binarySearch;
/**
* Sample a population of T given distribution provided by consecutive calls to update().
* <p>
* Example invocation:
* <pre>
* Sampler<String> sampler = new Sampler<>( random );
* sampler.update("A");
* sampler.update("C");
* sampler.update("B");
* sampler.update("C");
*
* sampler.fix();
*
* sampler.sample(); // returns A B or C with prob: 1/4 1/4 1/2
*
* </pre>
* <p>
* Created by David Soroko
*/
public class Sampler<T> {
private final Random random;
private Map<T, Integer> countMap = new HashMap<>();
private T[] words;
private int[] wordCount;
private int sum;
public Sampler(Random random) {
this.random = random;
}
public void update(T t) {
countMap.merge(t, 1, (k, kk) -> ++k);
}
public void fix() {
words = (T[]) new Object[countMap.size()];
if (countMap.size() == 1) {
words[0] = countMap.keySet().iterator().next();
} else {
wordCount = new int[countMap.size()];
int i = 0;
for (Map.Entry<T, Integer> entry : countMap.entrySet()) {
sum += entry.getValue();
wordCount[i] = sum;
words[i] = entry.getKey();
i++;
}
}
countMap = null;
}
public T sample() {
if (words.length == 1) {
return words[0];
}
int rand = random.nextInt(sum) + 1; // rand is in [1..sum]
int index = binarySearch(wordCount, rand);
if (index < 0) {
index = -index - 1;
}
return words[index];
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment