Skip to content

Instantly share code, notes, and snippets.

@groupsky
Last active January 4, 2017 17:05
Show Gist options
  • Save groupsky/2f59e3c2415f3d305273df316554d55f to your computer and use it in GitHub Desktop.
Save groupsky/2f59e3c2415f3d305273df316554d55f to your computer and use it in GitHub Desktop.
import java.lang.Math;
import java.lang.System;
/**
 * Small experiment that performs {@code nSelect} weighted ("roulette wheel")
 * selections among {@code n} elements in two different ways and prints the
 * per-element tallies and the wall-clock time (in nanoseconds) of each way.
 *
 * <ul>
 *   <li><b>Batched variant</b> (results in {@code counter2}, time in
 *       {@code perf2}): while more than {@code n*n} selections remain, a
 *       candidate element and a candidate batch size are drawn uniformly and
 *       accepted with probability
 *       {@code comb(left, count) * p^count * (1-p)^(left-count)}, where
 *       {@code p = weight / sumWeight}. The remaining selections are done one
 *       at a time by plain rejection sampling.</li>
 *   <li><b>Plain variant</b> (results in {@code counter}, time in
 *       {@code perf}): every selection is done one at a time by rejection
 *       sampling against {@code weight / maxWeight}.</li>
 * </ul>
 *
 * With the settings in {@link #main} (3 elements, all weights equal to 1,
 * 1000 selections) the printed "control" value for every element is
 * {@code (int)(1 * 1000 / 3) = 333}.
 */
public class fast_roulette {

    /**
     * Computes the binomial coefficient "n choose k" in floating point.
     *
     * @param n size of the set
     * @param k number of chosen elements
     * @return the number of k-element subsets of an n-element set as a
     *         {@code double}; {@code 0} when {@code k} is negative or
     *         greater than {@code n}
     */
    static double comb(long n, long k) {
        // No subsets exist for a negative or too-large k.
        if (k < 0 || k > n) {
            return 0;
        }
        // Use the symmetry C(n, k) == C(n, n-k) to keep the loop short.
        if (k > n - k) {
            k = n - k;
        }
        double r = 1;
        for (long d = 1; d <= k; ++d) {
            // Multiply and divide alternately so the running value stays
            // as small as possible.
            r *= n--;
            r /= d;
        }
        return r;
    }

    /**
     * Runs both selection variants and prints their timings and tallies to
     * standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        long nSelect = 1000L;
        int n = 3;

        // All elements currently get the same weight of 1.
        double[] weight = new double[n];
        double sumWeight = 0;
        double maxWeight = 0;
        for (int i = n; i-- > 0;) {
            weight[i] = 1;
            sumWeight += weight[i];
            maxWeight = Math.max(maxWeight, weight[i]);
        }

        long[] counter = new long[n];
        long[] counter2 = new long[n];
        long selected = 0;
        long selected2 = 0;

        // ---- Batched variant -------------------------------------------
        long perf2 = System.nanoTime();
        long leftSelect = nSelect;
        while (leftSelect > 0) {
            // Starts at 1 and is bumped before every draw, so the printed
            // value is (number of draws + 1); kept to preserve the output.
            long attempts = 1;
            long perf3 = System.nanoTime();
            int index;
            long count;
            while (true) {
                attempts++;
                index = (int) (n * Math.random());
                double c;
                if (leftSelect > (long) n * n) {
                    // Candidate batch size, uniform in 1..leftSelect. The
                    // cast is to long (not int) so that large remaining
                    // counts are not truncated.
                    count = (long) (leftSelect * Math.random()) + 1;
                    double p = weight[index] / sumWeight;
                    c = comb(leftSelect, count)
                            * Math.pow(p, count)
                            * Math.pow(1 - p, leftSelect - count);
                } else {
                    // Few selections left: pick a single element by plain
                    // rejection sampling.
                    count = 1;
                    c = weight[index] / maxWeight;
                }
                if (Math.random() < c) {
                    break;
                }
            }
            counter2[index] += count;
            selected2 += count;
            perf3 = System.nanoTime() - perf3;
            System.out.println("count =" + count + "/" + leftSelect + " " + attempts + " " + perf3);
            leftSelect -= count;
        }
        perf2 = System.nanoTime() - perf2;

        // ---- Plain variant (skipped for very large nSelect) --------------
        long perf = System.nanoTime();
        if (nSelect < 1000000) {
            for (long i = 0; i < nSelect; i++) {
                int index;
                while (true) {
                    index = (int) (n * Math.random());
                    if (Math.random() < weight[index] / maxWeight) {
                        break;
                    }
                }
                counter[index]++;
                selected++;
            }
        }
        perf = System.nanoTime() - perf;

        // ---- Report -----------------------------------------------------
        System.out.println("perf =" + perf);
        System.out.println("perf2=" + perf2);
        System.out.println("selected =" + selected);
        System.out.println("selected2=" + selected2);
        for (int i = 0; i < n; i++) {
            System.out.println("control [" + i + "]=" + (int) (weight[i] * nSelect / sumWeight));
            System.out.println("counter [" + i + "]=" + counter[i]);
            System.out.println("counter2[" + i + "]=" + counter2[i]);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment