Skip to content

Instantly share code, notes, and snippets.

Created March 14, 2012 02:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/2033568 to your computer and use it in GitHub Desktop.
Save anonymous/2033568 to your computer and use it in GitHub Desktop.
Weighted Selection
import scala.util.Random
/**
 * Weighted random selection with replacement: draws a requested number of
 * selections from a list of weighted items by sampling, item by item, how many
 * of the remaining selections land on that item (a binomial draw).
 *
 * @author Andrew Conway
 */
object WeightedRandomSelection {

  /**
   * Get the number of times an event with probability p occurs in N samples,
   * i.e. draw from the binomial distribution with parameters (N, p).
   *
   * If R is the result, then P(R=n) = p^n q^(N-n) N! / n! / (N-n)!  where q = 1-p.
   * This has the properties that P(R=0) = q^N and
   * P(R=n+1) = p/q (N-n)/(n+1) P(R=n).
   * The method first decides whether the result is 0; if not, whether it is 1
   * given that it is at least 1; and so on, using
   * P(R=n | R>=n) = P(R=n) / (1 - P(R<n)).
   *
   * For p > 0.5 the complementary event is sampled instead, so the loop always
   * runs with q >= 0.5. If q^N is too small to be represented as a normal
   * double, the N samples are split into two halves that are drawn
   * independently and added; without this the starting probability would be 0
   * and the loop would always run to N.
   *
   * @param p probability of the event in a single sample; values <= 0 give 0
   * @param N number of samples
   * @param r source of randomness
   * @return the number of samples (between 0 and N for non-negative N) in which the event occurred
   */
  def numEntries(p: Double, N: Int, r: Random): Int =
    if (p > 0.5) N - numEntries(1.0 - p, N, r)
    else if (p <= 0.0) 0
    else {
      val q = 1.0 - p
      val probNone = math.pow(q, N)
      if (N > 1 && probNone < java.lang.Double.MIN_NORMAL) {
        // q^N underflowed (or lost precision as a subnormal): the recurrence below
        // cannot start from it. A binomial(N, p) draw is the sum of independent
        // binomial(N1, p) and binomial(N - N1, p) draws, so split and recurse.
        val firstHalf = N / 2
        numEntries(p, firstHalf, r) + numEntries(p, N - firstHalf, r)
      } else {
        var n = 0
        var prstop = probNone // P(R = n)
        var cumulative = 0.0  // P(R < n)
        // Stop at n with probability P(R = n) / (1 - P(R < n)).
        while (n < N && r.nextDouble() * (1 - cumulative) >= prstop) {
          cumulative += prstop
          prstop *= p * (N - n) / (q * (n + 1))
          n += 1
        }
        n
      }
    }

  /** An item paired with its (relative, non-normalised) selection weight. */
  final case class WeightedItem[T](item: T, weight: Double)

  /**
   * Compute a weighted selection from the given items.
   *
   * cumulativeSum must be the same length as items (or longer), with the ith
   * element containing the sum of all weights from item i to the end of the
   * list. The sums are precomputed rather than obtained by adding everything up
   * and then subtracting, in order to prevent rounding errors from causing a
   * variety of subtle problems (in particular, the last item with non-zero
   * weight gets probability exactly 1 of taking every remaining selection).
   *
   * Works iteratively, so the number of items is not limited by stack depth.
   *
   * @throws NoSuchElementException if the items run out while selections are
   *         still outstanding (no items, or no positive total weight)
   */
  private def weightedSelectionWithCumSum[T](
      items: Seq[WeightedItem[T]],
      cumulativeSum: List[Double],
      numSelections: Int,
      r: Random): Seq[T] = {
    val chosen = List.newBuilder[T]
    val candidates = items.iterator
    val suffixSums = cumulativeSum.iterator
    var remaining = numSelections
    while (remaining > 0) {
      if (!candidates.hasNext)
        throw new NoSuchElementException(
          "weighted selection ran out of items with " + remaining +
            " selection(s) still to make; items must be non-empty with a positive total weight")
      val candidate = candidates.next()
      // Probability that one of the remaining selections is this item, given
      // that it is not any of the earlier items.
      val copies = numEntries(candidate.weight / suffixSums.next(), remaining, r)
      chosen ++= Iterator.fill(copies)(candidate.item)
      remaining -= copies
    }
    chosen.result()
  }

  /**
   * Select numSelections items at random (with replacement), each item being
   * chosen with probability proportional to its weight. The result is grouped
   * by item, in the order the items were given; it is not shuffled.
   *
   * @param items         the candidate items and their weights
   * @param numSelections how many selections to make; zero or less gives an empty result
   * @param r             source of randomness
   * @return the selected items, numSelections of them
   * @throws NoSuchElementException if numSelections is positive but items is
   *         empty or has no positive total weight
   */
  def weightedSelection[T](items: Seq[WeightedItem[T]], numSelections: Int, r: Random): Seq[T] = {
    // Element i is the sum of the weights of items i..end; the final element is 0.
    val cumsum = items.scanRight(0.0)((wi, tailSum) => wi.weight + tailSum).toList
    weightedSelectionWithCumSum(items, cumsum, numSelections, r)
  }

  /**
   * Run weightedSelection many times and print, for each item, the observed
   * mean and variance of the number of times it was selected next to the values
   * expected from a binomial distribution.
   *
   * @param items         the candidate items and their weights (items are assumed distinct)
   * @param numSelections number of selections per run
   * @param r             source of randomness
   */
  def testRandomness[T](items: Seq[WeightedItem[T]], numSelections: Int, r: Random): Unit = {
    val runs = 10000
    val indexOfItem = items.zipWithIndex.map { case (wi, ind) => wi.item -> ind }.toMap
    val numItems = items.length
    val bucketSums = new Array[Double](numItems)
    val bucketSumSqs = new Array[Double](numItems)
    for (_ <- 0 until runs) {
      // Tally how often each item was picked in this run, then accumulate the
      // per-item sum and sum of squares across runs.
      val runresult = weightedSelection(items, numSelections, r)
      val buckets = new Array[Double](numItems)
      for (picked <- runresult) buckets(indexOfItem(picked)) += 1
      for (i <- 0 until numItems) {
        val count = buckets(i)
        bucketSums(i) += count
        bucketSumSqs(i) += count * count
      }
    }
    val sumWeights = items.map(_.weight).sum
    for ((item, ind) <- items.zipWithIndex) {
      val p = item.weight / sumWeights
      val mean = bucketSums(ind) / runs
      val variance = bucketSumSqs(ind) / runs - mean * mean
      val expectedMean = numSelections * p
      val expectedVariance = numSelections * p * (1 - p)
      val expectedErrorInMean = math.sqrt(expectedVariance / runs)
      val text = "Item %10s Mean %.3f Expected %.3f±%.3f Variance %.3f expected %.3f".format(item.item,mean,expectedMean,expectedErrorInMean,variance,expectedVariance)
      println(text)
    }
  }

  /** Demo: print one weighted selection of six colours, then the statistics over many runs. */
  def main(args: Array[String]): Unit = {
    val items = Seq(WeightedItem("Red", 1d/6), WeightedItem("Blue", 2d/6), WeightedItem("Green", 3d/6))
    println(weightedSelection(items, 6, new Random()))
    testRandomness(items, 6, new Random())
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment