Last active
August 31, 2016 15:11
-
-
Save sorokod/a53e4781c0586477523a to your computer and use it in GitHub Desktop.
Sample a population of T according to provided distribution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap; | |
import java.util.Map; | |
import java.util.Random; | |
import static java.util.Arrays.binarySearch; | |
/**
 * Samples a population of T according to the distribution built up by consecutive calls to update().
 * <p>
 * Example invocation:
 * <pre>{@code
 * Sampler<String> sampler = new Sampler<>( random );
 * sampler.update("A");
 * sampler.update("C");
 * sampler.update("B");
 * sampler.update("C");
 *
 * sampler.fix();
 *
 * sampler.sample(); // returns A, B or C with probability 1/4, 1/4 and 1/2 respectively
 *
 * }</pre>
 * <p>
 * Created by David Soroko
 */
public class Sampler<T> {

    /** Source of randomness used by {@link #sample()}. */
    private final Random random;

    /** Occurrence count per element; set to null once {@link #fix()} has run. */
    private Map<T, Integer> countMap = new HashMap<>();

    /** Distinct elements in the map's iteration order; null until {@link #fix()} has run. */
    private T[] words;

    /** Running totals: wordCount[i] is the combined count of words[0..i]. */
    private int[] wordCount;

    /** Total number of recorded occurrences, i.e. the last entry of wordCount. */
    private int sum;

    /**
     * Creates an empty sampler.
     *
     * @param random source of randomness consulted by {@link #sample()}
     */
    public Sampler(Random random) {
        this.random = random;
    }

    /**
     * Records one more occurrence of the given element.
     *
     * @param t the element observed
     * @throws IllegalStateException if {@link #fix()} has already been called
     */
    public void update(T t) {
        if (countMap == null) {
            throw new IllegalStateException("update() must not be called after fix()");
        }
        countMap.merge(t, 1, Integer::sum);
    }

    /**
     * Freezes the distribution collected so far and prepares the lookup tables
     * used by {@link #sample()}. After this call no further updates are accepted.
     *
     * @throws IllegalStateException if this method has already been called
     * @throws ArithmeticException   if the total number of occurrences overflows an int
     */
    public void fix() {
        if (countMap == null) {
            throw new IllegalStateException("fix() has already been called");
        }

        final int size = countMap.size();
        // Safe: the array never leaves this class and only ever holds keys of type T.
        @SuppressWarnings("unchecked")
        final T[] distinct = (T[]) new Object[size];
        final int[] cumulative = new int[size];

        int runningTotal = 0;
        int i = 0;
        for (Map.Entry<T, Integer> entry : countMap.entrySet()) {
            // addExact: fail loudly rather than build a corrupt (non-monotonic) table.
            runningTotal = Math.addExact(runningTotal, entry.getValue().intValue());
            cumulative[i] = runningTotal;
            distinct[i] = entry.getKey();
            i++;
        }

        // Publish the new state only after the tables were built successfully.
        words = distinct;
        wordCount = cumulative;
        sum = runningTotal;
        countMap = null; // the counts are no longer needed
    }

    /**
     * Draws one element, each with probability proportional to the number of
     * times it was passed to {@link #update(Object)}.
     *
     * @return the sampled element
     * @throws IllegalStateException if {@link #fix()} has not been called yet,
     *                               or if no element was ever recorded
     */
    public T sample() {
        if (words == null) {
            throw new IllegalStateException("fix() must be called before sample()");
        }
        if (words.length == 0) {
            throw new IllegalStateException("cannot sample: no elements were recorded via update()");
        }
        if (words.length == 1) {
            // Only one possible outcome; deliberately does not consume a random number.
            return words[0];
        }

        int rand = random.nextInt(sum) + 1; // rand is in [1..sum]
        int index = binarySearch(wordCount, rand);
        if (index < 0) {
            // Not an exact hit: binarySearch returned -(insertionPoint) - 1, and the
            // insertion point is the first bucket whose running total exceeds rand.
            index = -index - 1;
        }
        return words[index];
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment