Created
September 22, 2022 23:10
-
-
Save robfitzgerald/0682ae2cb5145cb78f0f75f9cd5de14c to your computer and use it in GitHub Desktop.
implementation of Efraimidis & Spirakis' Algorithm A-ExpJ in vanilla Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.util.Random | |
import scala.annotation.tailrec | |
object ReservoirSampling { | |
/** | |
* Use Efraimidis and Spirakis' "Algorithm A-ExpJ" for weighted reservoir sampling | |
* without replacement. | |
* | |
* [[see https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-ExpJ]] | |
* | |
* @param rng random generator | |
* @param population items with their non-negative numeric weight values | |
* @param k number of items to sample | |
* @return the population partitioned into a sample population and the remaining un-sampled population | |
*/ | |
def aExpJ[T]( | |
rng: Random, | |
population: List[(T, Double)], | |
k: Int | |
): (List[(T, Double)], List[(T, Double)]) = { | |
if (population.lengthCompare(k) <= 0) (population, List.empty) | |
else { | |
// create a min priority queue over the weighted population | |
implicit val ord: Ordering[(T, Double)] = Ordering.by { case (_, w) => -w } | |
val reservoir = scala.collection.mutable.PriorityQueue[(T, Double)]() | |
// compute a key value for entry into the reservoir from an item weight | |
def computeKey(w: Double, t: Option[Double] = None): Double = { | |
val randomValue = t match { | |
case None => rng.nextDouble // random(0, 1) case | |
case Some(value) => | |
val tValueTruncated = math.max(0.0, math.min(1.0, value)) | |
if (tValueTruncated == 1.0) 1.0 | |
else rng.between(tValueTruncated, 1.0) // random(t, 1) case | |
} | |
val pow = if (w == 0.0) 1.0 else 1.0 / w | |
math.pow(randomValue, pow) | |
} | |
// (re-)compute X via the min key value | |
def recomputeX(): Double = { | |
val (_, minKey) = reservoir.head | |
val x = math.log(rng.nextDouble) / math.log(math.max(1.0, minKey)) | |
x | |
} | |
// fill the reservoir with an initial solution | |
val initialSolutionEntries = population.take(k).map { case (t, w) => (t, computeKey(w)) } | |
reservoir.addAll(initialSolutionEntries) | |
// step through the remaining population and make updates to the solution queue, keeping | |
// track of all discarded values in the "notSampled" collection | |
@tailrec | |
def _solve( | |
x: Double, | |
remaining: List[(T, Double)], | |
notSampled: List[(T, Double)] = List.empty | |
): List[(T, Double)] = { | |
remaining match { | |
case Nil => notSampled | |
case (nextItem, nextWeight) :: tail => | |
val nextX = x - nextWeight | |
if (nextX > 0) _solve(nextX, tail, (nextItem, nextWeight) +: notSampled) | |
else { | |
// replace the min-weight value in the reservoir with this entry and reset X | |
val (_, minKey) = reservoir.head | |
val t = math.pow(minKey, nextWeight) | |
val nextKey = computeKey(nextWeight, Some(t)) | |
val removedItem = reservoir.dequeue() | |
reservoir.addOne((nextItem, nextWeight)) | |
val resetX = recomputeX() | |
_solve(resetX, tail, removedItem +: notSampled) | |
} | |
} | |
} | |
val notSampled = _solve(recomputeX(), population.drop(k)) | |
(reservoir.toList, notSampled) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment