@ashelly
Last active June 15, 2020 15:38
Fast Weighted Random Sampling from discrete distributions, i.e. selecting items from a weighted list.
/*
WRS.c
(c) AShelly (github.com/ashelly)
Weighted random sampling with replacement of N items in O(1) time.
(After preparing a O(N) sized buffer in O(N) time.)
The concept is:
Randomly select a buffer index. Each index is selected with probability 1/N.
Each index stores the fraction of hits for which this item should be selected,
and the index of another item, which will be selected if this one is not.
Imagine a histogram containing [2,4,5,6,3]. It has 5 buckets, sum = 20.
We transform it by normalizing, then dividing each bucket's weight by 1/5,
so a bucket holding exactly its fair share comes out as 1.0.
Then we "fill up" the underpopulated buckets from the overpopulated ones:
AA 0.50 -> AACC 0.50 AACC 0.50 AACC 0.50
BBBB 1.00 BBBB 1.00 BBBB 1.00 BBBB 1.00
CCCCC 1.25 -> CCC 0.75(+0.50A) -> CCCD 0.75(+0.50A) CCCD 0.75(+0.50A)
DDDDDD 1.50 DDDDDD 1.50 -> DDDDD 1.25(+0.25C) -> DDDD 1.00(+0.25C+0.25E)
EEE 0.75 EEE 0.75 EEE 0.75 -> EEED 0.75
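Tracing the code below by hand, the finished table for this example works out to
(approximately):
  index:  0     1     2     3     4
  share:  0.50  1.00  0.75  1.00  0.75
  pair :  2     1     3     3     3
e.g. index 0 keeps A half the time and redirects to C otherwise, so
P(A) = 1/5 * 0.50 = 0.10 = 2/20 and P(C) = 1/5 * 0.75 + 1/5 * 0.50 = 0.25 = 5/20.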
Independently invented by me in 2014, only 40 years after the concept was published by
A. J. Walker in 1974 as "New fast method for generating discrete random numbers with
arbitrary frequency distributions". See http://en.wikipedia.org/wiki/Alias_method
*/
#include <stdio.h>
#include <stdlib.h>
//helper for demo only. !!Not an ideal uniform distribution.!!
//Dividing by RAND_MAX+1 keeps the result in [0,1), so wrs_pick never indexes one past the end.
double rand_percent() {
    return ((double)rand())/((double)RAND_MAX + 1.0);
}
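//A possible drop-in alternative on POSIX systems (still not crypto-quality):
//  double rand_percent() { return drand48(); }  /* uniform double in [0,1) */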
//WRS data structure
typedef struct wrs_data {
    size_t N;       //number of items
    double *share;  //fractional share
    size_t *pair;   //remainder here
} wrs_t;
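//share[i] is the fraction of hits on slot i that keep item i; the remaining
//(1 - share[i]) hits are redirected to item pair[i].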
//Create the pre-processed data structure, allowing fast weighted selection.
//Weights do not have to be normalized; that is done here first.
wrs_t* wrs_create(double* weights, size_t N) {
    //make space
    wrs_t* data = malloc(sizeof(wrs_t));
    data->share = malloc(N * sizeof(double));
    data->pair = malloc(N * sizeof(size_t));
    data->N = N;
    double sum = 0;
    size_t i, j, k;
    //Normalize and find what fraction of the ideal distribution is in each bucket.
    //Set each bucket's initial partner to itself. This acts as an 'unprocessed' marker and
    //handles small rounding errors: if there are no big buckets left, excess goes back to self.
    for (i=0; i<N; i++) { sum += weights[i]; }
    for (i=0; i<N; i++) {
        data->share[i] = weights[i] / (sum/N);
        data->pair[i] = i;
    }
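    //At this point each share[i] equals weights[i]*N/sum (they average 1.0) and
    //every bucket is still marked as its own partner.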
    //Find the first overpopulated bucket
    for (j=0; j<N && !(data->share[j] > 1.0); j++) {/*seek*/}
    for (i=0; i<N; i++) {
        k = i;                          // k is the bucket under consideration
        if (data->pair[k]!=i) continue; // reject already-considered buckets
        //If this bucket has fewer samples than a flat distribution,
        //it will be selected more frequently than it should be.
        double excess = 1.0 - data->share[k];
        while (excess > 0) {
            if (j == N) { break; }      // no more partners, close enough.
            printf("moving %.5f from %zu to %zu\n", excess, k, j);
            data->pair[k] = j;          // send excess hits to another bucket
            data->share[j] -= excess;   // account for the increased selection rate
            excess = 1.0 - data->share[j];
            // If the new bucket is now underpopulated, repeat with the next over-full bucket
            if (excess >= 0) {
                for (k=j++; j<N && !(data->share[j] > 1.0); j++) {/*seek*/}
            }
        }
    }
    return data;
}
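//Note: j only moves forward and each bucket's pair is written at most once,
//so the loops above finish the setup in O(N) total work.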
//O(1) weighted random sampling.
//Choose a real number in [0,N). Treat it as a distance into the array.
//If the fractional part is greater than that bucket's allocation, use its paired bucket.
size_t wrs_pick(wrs_t* data)
{
    double pct = rand_percent()*data->N;
    size_t idx = (size_t)pct;
    if (pct-idx > data->share[idx]) { idx = data->pair[idx]; }
    return idx;
}
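//Example with the table traced in the header comment: pct = 2.80 gives idx = 2;
//the fractional part 0.80 exceeds share[2] = 0.75, so pair[2] = 3 is returned instead.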
//Clean up nicely
void wrs_destroy(wrs_t* data) {
    free(data->pair);
    free(data->share);
    free(data);
}
//util to back out the normalized weights from the pre-processed data.
double* wrs_norm(wrs_t* data) {
    double* pct = calloc(data->N, sizeof(double));
    size_t i;
    for (i=0; i<data->N; i++) {
        pct[i] += data->share[i];
        pct[data->pair[i]] += 1.0 - data->share[i];
    }
    return pct;
}
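//Each returned entry is N times the corresponding normalized weight
//(the share it keeps plus everything forwarded to it), so main() divides by NW below.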
/** sample usage **/
//double weights[]= {20,1,4,10,15,10,16,10,8,6};
double weights[]= {2,4,5,6,3};
#define NW (sizeof(weights)/sizeof(weights[0]))
#define TRIALS 1000
int main(int argc, char* argv[]) {
    //pre-process input data
    wrs_t* dist = wrs_create(weights, NW);
    //show normalized weights
    int i;
    double* d = wrs_norm(dist);
    printf("\n");
    for (i=0; i<NW; i++) {
        printf("%.5f ", d[i]/NW);
    }
    free(d);
    //build new histogram
    int samples[NW] = {0};
    for (i=0; i<TRIALS; i++) {
        samples[wrs_pick(dist)]++;
    }
    //show that generated data matches
    printf("\n--------------\n");
    for (i=0; i<NW; i++) { printf("%.5f ", (double)samples[i]/TRIALS); }
    printf("\n");
    //cleanup
    wrs_destroy(dist);
    return 0;
}
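/*
 Roughly the output to expect from this demo (hand-derived; the last line is random and will vary):
   moving 0.50000 from 0 to 2
   moving 0.25000 from 2 to 3
   moving 0.25000 from 4 to 3

   0.10000 0.20000 0.25000 0.30000 0.15000
   --------------
   <sampled frequencies close to the line above>
*/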
zoople commented Jun 15, 2020

Thank you for this :) It's a very clever way to do it. I had a go at it and while my implementation works, it doesn't perform any faster than a linear search. Maybe I'm doing something wrong? I've posted about it here if you are interested: https://stackoverflow.com/questions/62391780/walkers-alias-method-for-weighted-random-selection-isnt-faster-than-a-linear-se
