Skip to content

Instantly share code, notes, and snippets.

@emesday
Last active August 11, 2020 17:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save emesday/87e877bd21711dcf1fb8e4a2deed032d to your computer and use it in GitHub Desktop.
Save emesday/87e877bd21711dcf1fb8e4a2deed032d to your computer and use it in GitHub Desktop.
Reservoir Sampling for Scala Spark
import scala.reflect.ClassTag
import scala.util.Random
/**
 * Fixed-capacity reservoir sampler: keeps a uniform random sample of at most `size`
 * elements out of everything fed to it, and can be merged with another sampler of the
 * same capacity (e.g. as the accumulator of a Spark `aggregate`/`treeAggregate`).
 *
 * The sampler is mutable: `+=` and `++=` update it in place and return `this`.
 *
 * NOTE(review): the random generator is built from `seed` and travels with the object
 * when it is serialized, so copies of one instance replay the same random sequence —
 * pass distinct seeds if independent randomness between copies matters.
 *
 * @param size maximum number of elements retained in the sample
 * @param seed seed of the internal random generator (random by default)
 * @tparam T   element type; a `ClassTag` is required to allocate the backing array
 */
class Reservoir[T: ClassTag](
    private val size: Int,
    private val seed: Long = Random.nextLong()) extends Serializable {

  private val rand = new Random(seed)
  private val reservoir = new Array[T](size)
  // Total number of elements observed so far (not the number currently stored).
  private var count = 0L

  /**
   * Offers one element to the sample (classic reservoir scheme, "Algorithm R").
   *
   * The first `size` elements are stored directly; the i-th element afterwards
   * replaces a random slot with probability `size / i`.
   *
   * @param elem the element to offer
   * @return this sampler, to allow chaining
   */
  def +=(elem: T): this.type = {
    count += 1
    if (count <= size) {
      reservoir((count - 1).toInt) = elem
    } else {
      // Uniform index in [0, count); the element is kept only if it lands in a slot.
      val replacementIndex = (rand.nextDouble() * count).toLong
      if (replacementIndex < size) {
        reservoir(replacementIndex.toInt) = elem
      }
    }
    this
  }

  /**
   * Merges another sampler into this one so that the result is a uniform sample of
   * the union of both input streams. `that` is not modified.
   *
   * @param that sampler with the same capacity as this one
   * @return this sampler, to allow chaining
   * @throws IllegalArgumentException if the two capacities differ
   */
  def ++=(that: Reservoir[T]): this.type = {
    require(
      this.size == that.size,
      s"cannot merge reservoirs of different sizes: ${this.size} vs ${that.size}")
    if (this.count + that.count <= size) {
      // Everything seen by both sides still fits: simply append that's elements.
      System.arraycopy(
        that.reservoir, 0, this.reservoir, this.count.toInt, that.count.toInt)
    } else {
      // Random visiting order over the elements each side currently holds.
      val thisIterator = rand.shuffle((0 until this.getSize).toVector).iterator
      val thatIterator = rand.shuffle((0 until that.getSize).toVector).iterator
      // Sizes of the two underlying populations that have not been "drawn" yet.
      // Decrementing them after each draw makes this sampling WITHOUT replacement,
      // so the number of survivors from each side is hypergeometric, as a uniform
      // sample of the combined stream requires. (A fixed per-slot probability would
      // over-represent lopsided outcomes.)
      var thisRemaining = this.count
      var thatRemaining = that.count
      val newReservoir = Array.fill[T](size) {
        val takeThis =
          if (thatRemaining == 0L) true
          else if (thisRemaining == 0L) false
          else rand.nextDouble() * (thisRemaining + thatRemaining) < thisRemaining
        if (takeThis) {
          thisRemaining -= 1
          this.reservoir(thisIterator.next())
        } else {
          thatRemaining -= 1
          that.reservoir(thatIterator.next())
        }
      }
      System.arraycopy(newReservoir, 0, reservoir, 0, newReservoir.length)
    }
    count += that.count
    this
  }

  /** Total number of elements offered so far, including merged-in counts. */
  def getCount: Long = count

  /** Number of elements currently held: `min(size, count)`. */
  def getSize: Int = math.min(size, count).toInt

  /** Copy of the current sample, containing `getSize` elements. */
  def result(): Array[T] = reservoir.take(getSize)
}
@emesday
Copy link
Author

emesday commented Mar 12, 2019

// Placeholder: substitute the RDD you want to sample from.
// (`= _` is only legal for `var` fields, so `???` is used as the stand-in.)
val rdd: RDD[Array[Float]] = ???
// Capacity of the reservoir, i.e. the number of rows to keep in the sample.
val n = 100

// seqOp folds each row into a reservoir; combOp merges two partial reservoirs.
// result() then returns the sampled rows as an array.
rdd.treeAggregate(new Reservoir[Array[Float]](n))(
      seqOp = (agg, arr) => agg += arr,
      combOp = (agg, other) => agg ++= other
    ).result()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment