Skip to content

Instantly share code, notes, and snippets.

@darrenxyli
Forked from hsleep/HashMapParam.scala
Created November 28, 2015 02:27
Show Gist options
  • Save darrenxyli/acffe708bef47cedc903 to your computer and use it in GitHub Desktop.
HashMapParam
/**
* Created by hsleep on 14. 12. 22..
*/
import org.apache.spark.{ AccumulableParam, SparkConf }
import org.apache.spark.serializer.JavaSerializer
import scala.collection.mutable.{ HashMap => MutableHashMap }
/*
* Allows a mutable HashMap[String, Int] to be used as an accumulator in Spark.
* Whenever we try to put (k, v2) into an accumulator that already contains (k, v1), the result
* will be a HashMap containing (k, v1 + v2).
*
* Would have been nice to extend GrowableAccumulableParam instead of redefining everything, but it's
* private to the spark package.
*/
object HashMapParam extends AccumulableParam[MutableHashMap[String, Int], (String, Int)] {

  /**
   * Merges a single (key, count) pair into the accumulator.
   * If the key is already present, its count is incremented by the new value;
   * otherwise the pair is inserted. The (mutated) accumulator is returned.
   */
  def addAccumulator(acc: MutableHashMap[String, Int], elem: (String, Int)): MutableHashMap[String, Int] = {
    val (key, count) = elem
    // O(1) hash lookup. The original used `acc.find(_._1 == key)`, a linear
    // scan over every entry, which made each add O(n) and defeated the point
    // of using a HashMap.
    acc += key -> (acc.getOrElse(key, 0) + count)
    acc
  }

  /*
   * Merges acc2 into acc1, summing counts for shared keys.
   * This method is allowed to modify and return the first value for efficiency.
   *
   * @see org.apache.spark.GrowableAccumulableParam.addInPlace(r1: R, r2: R): R
   */
  def addInPlace(acc1: MutableHashMap[String, Int], acc2: MutableHashMap[String, Int]): MutableHashMap[String, Int] = {
    acc2.foreach(elem => addAccumulator(acc1, elem))
    acc1
  }

  /*
   * Returns an empty map of the same concrete type as `initialValue`.
   * The serialize/deserialize round-trip produces an independent copy (so the
   * caller's map is never cleared), which is then emptied to serve as the
   * zero element — mirroring Spark's private GrowableAccumulableParam.
   *
   * @see org.apache.spark.GrowableAccumulableParam.zero(initialValue: R): R
   */
  def zero(initialValue: MutableHashMap[String, Int]): MutableHashMap[String, Int] = {
    val ser = new JavaSerializer(new SparkConf(false)).newInstance()
    val copy = ser.deserialize[MutableHashMap[String, Int]](ser.serialize(initialValue))
    copy.clear()
    copy
  }
}
test("hashmap accumulator") {
  // Input contains "a" twice so the merged count is non-trivial.
  val words = Seq("a", "b", "c", "a")
  sc = new SparkContext("local[*]", "test")
  // Accumulate word counts across partitions via HashMapParam.
  val countAcc = sc.accumulable(new mutable.HashMap[String, Int])(HashMapParam)
  sc.parallelize(words).foreach(word => countAcc += ((word, 1)))
  val counts = countAcc.value
  assert(counts("a") == 2)
  resetSparkContext()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment