Skip to content

Instantly share code, notes, and snippets.

@tempredirect
Last active August 29, 2015 13:57
Show Gist options
  • Save tempredirect/9853163 to your computer and use it in GitHub Desktop.
Save tempredirect/9853163 to your computer and use it in GitHub Desktop.
Random weighted selection
Iterations:100
{0=14, 1=22, 2=25, 3=27, 4=7, 5=5}
0 - prop:1.260 weight:1.000 diff:0.26000
1 - prop:1.980 weight:2.000 diff:-0.02000
2 - prop:2.250 weight:2.000 diff:0.25000
3 - prop:2.430 weight:3.000 diff:-0.57000
4 - prop:0.630 weight:0.500 diff:0.13000
5 - prop:0.450 weight:0.500 diff:-0.05000
squared error: 0.474800
Iterations:10000
{0=1154, 1=2219, 2=2210, 3=3331, 4=562, 5=524}
0 - prop:1.039 weight:1.000 diff:0.03860
1 - prop:1.997 weight:2.000 diff:-0.00290
2 - prop:1.989 weight:2.000 diff:-0.01100
3 - prop:2.998 weight:3.000 diff:-0.00210
4 - prop:0.506 weight:0.500 diff:0.00580
5 - prop:0.472 weight:0.500 diff:-0.02840
squared error: 0.002464
Iterations:1000000
{0=111753, 1=222214, 2=222173, 3=332759, 4=55595, 5=55506}
0 - prop:1.006 weight:1.000 diff:0.00578
1 - prop:2.000 weight:2.000 diff:-0.00007
2 - prop:2.000 weight:2.000 diff:-0.00044
3 - prop:2.995 weight:3.000 diff:-0.00517
4 - prop:0.500 weight:0.500 diff:0.00035
5 - prop:0.500 weight:0.500 diff:-0.00045
squared error: 0.000061
Iterations:100000000
{0=11116046, 1=22220971, 2=22215013, 3=33337796, 4=5553095, 5=5557079}
0 - prop:1.000 weight:1.000 diff:0.00044
1 - prop:2.000 weight:2.000 diff:-0.00011
2 - prop:1.999 weight:2.000 diff:-0.00065
3 - prop:3.000 weight:3.000 diff:0.00040
4 - prop:0.500 weight:0.500 diff:-0.00022
5 - prop:0.500 weight:0.500 diff:0.00014
squared error: 0.000001
import java.util.*;
import java.util.concurrent.*;
/**
 * Random selection of array indices in proportion to their weights, plus a small driver that
 * checks the observed selection frequencies against the configured weights.
 */
public class RandomWeightedSelection {

    /** Runs the frequency check at several sample sizes, from 100 up to 100 million draws. */
    public static void main(String[] args) {
        runIterations(100);
        runIterations(10_000);
        runIterations(1_000_000);
        runIterations(100_000_000);
    }

    /**
     * Draws {@code iterations} weighted samples from a fixed weight table and prints the results.
     *
     * <p>The output has three parts:
     * <ul>
     *   <li>the raw per-index counts;</li>
     *   <li>for each index, the weight implied by its observed frequency ("prop") next to the
     *       configured weight and their difference;</li>
     *   <li>the sum of the squared differences.</li>
     * </ul>
     *
     * @param iterations number of samples to draw; must be positive
     * @throws IllegalArgumentException if {@code iterations} is not positive
     */
    public static void runIterations(int iterations) {
        if (iterations <= 0) {
            // Dividing by a zero sample count below would only print NaN rows.
            throw new IllegalArgumentException("iterations must be positive: " + iterations);
        }
        System.out.println("Iterations:" + iterations);
        double[] weights = {1.0, 2.0, 2.0, 3.0, 0.5, 0.5};
        // Validate and sum the weights once, instead of once per draw.
        double totalWeight = totalWeight(weights);

        // Primitive counters avoid allocating a boxed Integer for every draw.
        int[] counts = new int[weights.length];
        for (int i = 0; i < iterations; i++) {
            counts[pick(weights, totalWeight)]++;
        }

        // Same "{index=count, ...}" layout that printing a Map<Integer,Integer> produces.
        StringJoiner summary = new StringJoiner(", ", "{", "}");
        for (int i = 0; i < counts.length; i++) {
            summary.add(i + "=" + counts[i]);
        }
        System.out.println(summary);

        double error = 0.0;
        for (int i = 0; i < weights.length; i++) {
            double weight = weights[i];
            // Scale the observed frequency back into weight units so it compares directly.
            double prop = (double) counts[i] * totalWeight / iterations;
            double diff = prop - weight;
            error += diff * diff;
            System.out.printf("%d - prop:%.3f weight:%.3f diff:%.5f%n", i, prop, weight, diff);
        }
        System.out.printf(" squared error: %.6f%n", error);
    }

    /**
     * Selects a random index of {@code weights}.
     *
     * <p>Index {@code i} is chosen with probability {@code weights[i] / sum(weights)}. Indices
     * whose weight is zero are never selected.
     *
     * @param weights non-negative, finite weights; at least one must be positive
     * @return the selected index, in {@code [0, weights.length)}
     * @throws NullPointerException if {@code weights} is null
     * @throws IllegalArgumentException in any of these cases:
     *     <ul>
     *       <li>{@code weights} is empty;</li>
     *       <li>it contains a negative, NaN or infinite value;</li>
     *       <li>its sum is zero or overflows to infinity.</li>
     *     </ul>
     */
    public static int select(double[] weights) {
        return pick(weights, totalWeight(weights));
    }

    /** Validates {@code weights} and returns their left-to-right sum. */
    private static double totalWeight(double[] weights) {
        Objects.requireNonNull(weights, "weights");
        if (weights.length == 0) {
            throw new IllegalArgumentException("weights must not be empty");
        }
        double total = 0.0;
        for (int i = 0; i < weights.length; i++) {
            double w = weights[i];
            // Written as !(w >= 0 && finite) so that NaN is rejected as well.
            if (!(w >= 0.0 && Double.isFinite(w))) {
                throw new IllegalArgumentException(
                        "weights[" + i + "] must be finite and non-negative: " + w);
            }
            total += w;
        }
        if (!(total > 0.0) || Double.isInfinite(total)) {
            throw new IllegalArgumentException(
                    "weights must sum to a positive finite value: " + total);
        }
        return total;
    }

    /**
     * Maps one uniform draw in {@code [0, total)} onto the cumulative weights.
     *
     * <p>{@code total} must be the value {@link #totalWeight} returned for {@code weights}. That
     * makes the final running sum equal {@code total} exactly, so a draw below it always
     * terminates inside the loop.
     */
    private static int pick(double[] weights, double total) {
        double seed = ThreadLocalRandom.current().nextDouble(total);
        double acc = 0.0;
        for (int i = 0; i < weights.length; i++) {
            acc += weights[i];
            if (seed < acc) {
                return i;
            }
        }
        // Safety net only. Fall back to the last index that can legitimately be chosen,
        // never to a trailing zero-weight index.
        for (int i = weights.length - 1; i >= 0; i--) {
            if (weights[i] > 0.0) {
                return i;
            }
        }
        return weights.length - 1;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment