@animeshtrivedi · Created May 4, 2017
Spark shuffle benchmark with variations of the groupBy test
// This benchmark is meant to be pasted into spark-shell (e.g. via :paste),
// which provides `sc` (SparkContext) and `spark` (SparkSession); neither is
// constructed here.
import java.util.Random
import org.apache.spark.SparkContext

object GroupByTest {

  // RDD variant: generates random (Int, Array[Byte]) pairs and shuffles them
  // with RDD groupByKey.
  def test(numMappers: Int = 10, numKVPairs: Int = 100, valSize: Int = 1024, numReducers: Int = 10) {
    val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
      val ranGen = new Random
      val arr1 = new Array[(Int, Array[Byte])](numKVPairs)
      for (i <- 0 until numKVPairs) {
        // a fresh value buffer for every pair, keyed by a random non-negative Int
        val byteArr = new Array[Byte](valSize)
        ranGen.nextBytes(byteArr)
        arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
      }
      arr1
    }.cache()
    // Force the pairs to be computed and cached before the shuffle
    pairs1.count()
    println(pairs1.groupByKey(numReducers).count())
  }
  // Dataset variant: the same shuffle through the Dataset groupByKey/reduceGroups
  // path. Note that numReducers is unused here: Dataset groupByKey takes no
  // partition count.
  def testDS(numMappers: Int = 10, numKVPairs: Int = 100, valSize: Int = 1024, numReducers: Int = 10) {
    import spark.implicits._
    val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
      val ranGen = new Random
      val arr1 = new Array[(Int, Array[Byte])](numKVPairs)
      // unlike test(), one value buffer is shared by all pairs of a partition
      val byteArr = new Array[Byte](valSize)
      ranGen.nextBytes(byteArr)
      for (i <- 0 until numKVPairs) {
        arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
      }
      arr1
    }.cache().toDS()
    // Caching pins the random data: without it, every action would regenerate
    // the pairs and observe different values.
    println(" ----------> " + pairs1.count)
    val s = System.nanoTime()
    val gb = pairs1.groupByKey(k => k._1)
    // Reduce each group to a single byte (the summed value lengths, truncated
    // to a Byte) so that the reduce output stays small.
    val rb = gb.reduceGroups((v1, v2) => (v2._1, Seq((v1._2.length + v2._2.length).toByte).toArray))
    val entries = rb.count
    val end = System.nanoTime() - s
    println(" Execution time: " + (end / 1000000) + " msecs, entries: " + entries)
  }
  // Large-key variant: every record is just a keySize-byte random array that
  // doubles as the grouping key, so the shuffle cost is dominated by the keys.
  // numReducers is likewise unused in this Dataset variant.
  def testLargeKey(numMappers: Int = 10, numKeys: Int = 100, keySize: Int = 1024, numReducers: Int = 10) {
    import spark.implicits._
    val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
      val ranGen = new Random(System.nanoTime())
      val arr1 = new Array[Array[Byte]](numKeys)
      for (i <- 0 until numKeys) {
        val key = new Array[Byte](keySize)
        ranGen.nextBytes(key)
        arr1(i) = key
      }
      arr1
    }.cache().toDS()
    // Caching pins the random data: without it, every action would regenerate
    // the keys and observe different values.
    println(pairs1.count)
    val s = System.nanoTime()
    val gb = pairs1.groupByKey(k => k) // the record itself is the key; there is no separate value
    val rb = gb.reduceGroups((v1, v2) => Seq((v1.length + v2.length).toByte).toArray)
    val entries = rb.count
    val end = System.nanoTime() - s
    println(" Execution time: " + (end / 1000000) + " msecs, entries: " + entries)
  }
}
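
For reference, a minimal spark-shell session that exercises the three variants
could look like the sketch below; the parameter values are illustrative only,
not taken from the gist.

// after pasting the object above into spark-shell (e.g. via :paste):
GroupByTest.test(numMappers = 20, numKVPairs = 10000, valSize = 1024, numReducers = 20) // RDD groupByKey
GroupByTest.testDS(numMappers = 20, numKVPairs = 10000, valSize = 1024)                 // Dataset groupByKey/reduceGroups
GroupByTest.testLargeKey(numMappers = 20, numKeys = 10000, keySize = 4096)              // large-key Dataset variant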