Skip to content

Instantly share code, notes, and snippets.

@felipecrv
Created June 30, 2022 06:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save felipecrv/96d814bdaaeacdfcba973c3ffe250c57 to your computer and use it in GitHub Desktop.
Save felipecrv/96d814bdaaeacdfcba973c3ffe250c57 to your computer and use it in GitHub Desktop.
Sampling and Shuffling
package rs.felipe.random
import java.util.Random
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
/**
  * Random sampling and shuffling helpers driven by a caller-supplied `java.util.Random`.
  */
object Sampling {

  /**
    * Upper bound on the initial capacity reserved for a reservoir. `k` is only an upper
    * bound on the sample size, so reserving `k` slots up front could allocate far more
    * memory than the input ever needs (or fail outright for a very large `k`).
    */
  private final val MaxSizeHint = 1024

  /**
    * Randomly sample up to k items from a sequence.
    *
    * Uses [[smallSample]] when `k` is less than half the sequence length and
    * [[reservoirSample]] otherwise.
    *
    * @param sequence items to sample from
    * @param k        maximum number of items to return
    * @param rng      source of randomness
    * @return a sample of `min(k, n)` items (empty if `k <= 0`); the order is only
    *         guaranteed to be random on the `smallSample` path
    */
  def sample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    if (k < n / 2) {
      smallSample(sequence, k, rng)
    } else {
      // Clamp k to n: asking for more items than exist must not inflate the reservoir.
      reservoirSample(sequence.iterator, math.min(k, n), rng)
    }
  }

  /**
    * Randomly sample up to k items from a sequence and guarantee that the sample is shuffled.
    *
    * @param sequence items to sample from
    * @param k        maximum number of items to return
    * @param rng      source of randomness
    * @return a shuffled sample of `min(k, n)` items (empty if `k <= 0`)
    */
  def shuffledSample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    if (k < n / 2) {
      // smallSample returns items in the (random) order they were drawn.
      smallSample(sequence, k, rng)
    } else {
      val sample = reservoirSample(sequence.iterator, math.min(k, n), rng)
      shuffle(sample, rng)
      sample
    }
  }

  /**
    * Randomly sample k items from an iterator.
    *
    * Algorithm R -- See https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_R
    * Vitter, Jeffrey S. "Random sampling with a reservoir", 1985.
    *
    * The iterator is always consumed to the end.
    *
    * @param iterator items to sample from
    * @param k        maximum number of items to return
    * @param rng      source of randomness
    * @return Sample of up to k items. Not properly shuffled.
    */
  def reservoirSample[T](iterator: Iterator[T], k: Int, rng: Random): ArrayBuffer[T] = {
    // The size hint is capped; the buffer still grows to k if the iterator is long enough.
    val sample = new ArrayBuffer[T](math.max(0, math.min(k, MaxSizeHint)))
    var seen = 0
    // Fill the reservoir with the first k items.
    while (seen < k && iterator.hasNext) {
      sample += iterator.next()
      seen += 1
    }
    // Every later item replaces a random slot with probability k / seen.
    while (iterator.hasNext) {
      seen += 1
      val elem = iterator.next()
      val r = rng.nextInt(seen)
      if (r < k) {
        sample(r) = elem
      }
    }
    sample
  }

  /**
    * Randomly sample k items from a sequence of length n.
    * This method is preferable if k is much smaller than n: it draws random indices and
    * rejects repeats, so it slows down as k approaches n.
    *
    * @param sequence items to sample from
    * @param k        maximum number of items to return
    * @param rng      source of randomness
    * @return Shuffled sample of `min(k, n)` items (empty if `k <= 0`), in the order in
    *         which their indices were drawn.
    */
  def smallSample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    val target = math.max(0, math.min(k, n))
    val sample = new ArrayBuffer[T](target)
    val chosenIndices = mutable.Set[Int]()
    while (sample.length < target) {
      val r = rng.nextInt(n)
      // `add` is true only for a new index. Appending right here keeps the draw order;
      // iterating the hash set afterwards would yield a content-determined order instead.
      if (chosenIndices.add(r)) {
        sample += sequence(r)
      }
    }
    sample
  }

  /**
    * In-place Knuth-Fisher-Yates shuffle.
    *
    * See https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
    *
    * @param arr buffer whose elements are permuted in place
    * @param rng source of randomness
    */
  def shuffle[T](arr: ArrayBuffer[T], rng: Random): Unit = {
    // Walk from the last slot down to 1, swapping each slot with a random slot at or below it.
    for (i <- arr.length - 1 to 1 by -1) {
      val r = rng.nextInt(i + 1)
      val tmp = arr(i)
      arr(i) = arr(r)
      arr(r) = tmp
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment