Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Sketch of a reasonably efficient algorithm for weighted sampling without replacement
/**
 * Builds a pairwise-sum tree over the weights, stored level by level in one flat array.
 *
 * The leaf level is padded with zeros to a power-of-two size `base`. Leaves occupy
 * slots `[0, base)`; each following level holds the sums of adjacent pairs of the
 * level before it, so the overall total ends up at `tree.length - 2`. The final slot
 * is never written and stays 0.
 *
 * @param wt the weights, one per item
 * @return the flat tree of length `2 * base` (length 0 for empty input, 2 for one weight)
 */
def cuml(wt: Array[Double]) = {
  val count = wt.length
  // A power of two (or zero) is used as-is; anything else is rounded up.
  val alreadyPowerOfTwo = (count & (count - 1)) == 0
  val base = if (alreadyPowerOfTwo) count else java.lang.Integer.highestOneBit(count) * 2
  val tree = new Array[Double](base * 2)
  System.arraycopy(wt, 0, tree, 0, count)

  var levelStart = 0   // offset of the level being read
  var levelSize = base // number of nodes in that level
  while (levelSize > 1) {
    val parentStart = levelStart + levelSize
    val pairs = levelSize >> 1
    var j = 0
    while (j < pairs) {
      val left = levelStart + (j << 1)
      tree(parentStart + j) = tree(left) + tree(left + 1)
      j += 1
    }
    levelStart = parentStart
    levelSize = pairs
  }
  tree
}
/**
 * Finds the leaf whose cumulative-weight interval contains `p` by walking a level-ordered
 * pairwise-sum tree (leaves first, then each level of pair sums, total at `tree.length - 2`)
 * from the root's two children down to the leaves.
 *
 * At every level the left node of the current pair is compared with `p`: if `p` is strictly
 * smaller the search descends into the left node, otherwise the left node's sum is subtracted
 * from `p` and the search descends into the right node. Because the comparison is strict,
 * `p` is expected to lie in `[0, total)`; a `p` at or beyond the total falls through to the
 * right-most slot.
 *
 * @param tree   the flat sum tree
 * @param p      offset into the cumulative weight
 * @param zero   array offset of the level being examined; defaults to the level holding the
 *               root's two children
 * @param index  position within that level of the left node of the pair being examined
 * @param stride number of nodes in the level being examined; the level below starts
 *               `stride << 1` slots earlier
 * @return the index of the selected leaf
 */
def seek(tree: Array[Double], p: Double)(zero: Int = tree.length-4, index: Int = 0, stride: Int = 2): Int = {
  // A tree built from a single weight has only two slots, so the default `zero` is negative:
  // there is no pair to choose between and the only candidate is `index` itself.
  if (zero < 0) index
  // Leaf level: pick the left or right leaf of the final pair.
  else if (zero == 0) index + (if (p < tree(index)) 0 else 1)
  else if (p < tree(zero + index)) seek(tree, p)(zero - (stride << 1), index << 1, stride << 1)
  else seek(tree, p - tree(zero + index))(zero - (stride << 1), (index << 1) + 2, stride << 1)
}
/**
 * Subtracts `value` from the node at `index` and from each of its ancestors up to the root,
 * so that no sum in the tree includes it any longer. With the default arguments the node's
 * entire current value is removed, leaving a leaf at exactly zero.
 *
 * The tree is modified in place.
 *
 * @param tree  level-ordered pairwise-sum tree (leaves first, total at `tree.length - 2`)
 * @param index slot of the node to reduce
 * @param value amount to subtract; defaults to the node's current value
 * @param width number of nodes in the level that `index` belongs to; defaults to the leaf
 *              level's size, and the walk stops once it reaches 1 (the root)
 */
def wipe(tree: Array[Double], index: Int)(value: Double = tree(index), width: Int = tree.length >> 1): Unit = {
  tree(index) -= value
  // The parent of slot i in this layout is at (tree.length + i) / 2.
  if (width > 1) wipe(tree, (tree.length+index) >> 1)(value, width >> 1)
}
/**
 * Draws `k` distinct indices from `wt` without replacement, each draw choosing among the
 * not-yet-chosen items with probability proportional to weight.
 *
 * A pairwise-sum tree is built once; each draw scales a random number by the remaining total,
 * searches the tree for the matching leaf, and then removes that leaf's weight from the tree
 * so it cannot be chosen again.
 *
 * @param r  random number generator; every call must return a value in [0, 1), as
 *           `scala.util.Random.nextDouble()` does. The tree search uses strict `<`
 *           comparisons, so a value of exactly 1.0 would walk to the right-most slot, which
 *           may be a padding slot or a zero-weight item.
 * @param wt item weights; assumed to be non-negative
 * @param k  number of indices to draw; must not exceed the number of strictly positive weights
 * @return the chosen indices, in the order they were drawn
 * @throws IllegalArgumentException if `k` exceeds the number of strictly positive weights
 */
def sample(r: () => Double, wt: Array[Double], k: Int): Array[Int] = {
  // Once every positive weight has been removed the remaining total is zero and further
  // draws could only return already-chosen or padding slots, so reject such requests up front.
  val drawable = wt.count(_ > 0.0)
  require(k <= drawable, s"cannot draw $k distinct indices from $drawable positively weighted items")
  val indices = new Array[Int](k)
  val tree = cuml(wt)
  var i = 0
  while (i < k) {
    // tree(tree.length-2) holds the total weight still available.
    val index = seek(tree, r()*tree(tree.length-2))()
    wipe(tree, index)()
    indices(i) = index
    i += 1
  }
  indices
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment