Skip to content

Instantly share code, notes, and snippets.

@szhem
Created May 4, 2018 11:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save szhem/52a26ada4bbeb1a3e762710adc3f94ef to your computer and use it in GitHub Desktop.
Spark :: AccumulatorV2 vs AccumulableParam (V1)
package foo.bar
import java.{lang => jl}
import org.apache.spark.OpenAccumulatorContext
import org.apache.spark.util.LongAccumulator
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.SparkException
import org.apache.spark.TaskContext
import org.scalatest.BeforeAndAfterEach
import org.scalatest.Matchers
import org.scalatest.WordSpec
/**
 * Demonstrates the difference between the deprecated V1 accumulator API
 * (`SparkContext.accumulator`) and the V2 API (`AccumulatorV2`) with respect
 * to registering an accumulator more than once on the same SparkContext.
 */
class AccumulatorsSpec extends WordSpec with BeforeAndAfterEach with Matchers {
// SparkContext recreated before every test and stopped after it.
var sc: SparkContext = _
override protected def beforeEach(): Unit = {
super.beforeEach()
val conf = new SparkConf()
.setMaster("local[1,2]") // single threaded that can fail for 2 times
.setAppName(this.getClass.getSimpleName)
.set("spark.ui.enabled", "false")
.set("spark.logLineage", "true")
sc = new SparkContext(conf)
}
override protected def afterEach(): Unit = {
sc.stop()
super.afterEach()
}
"Accumulator V1" should {
"be able to be registered twice" in {
import org.apache.spark.AccumulatorParam._
val numbers = 1 to 10
val rdd = sc.makeRDD(numbers, 5)
val param = IntAccumulatorParam
// First registration: sum the partitions through the accumulator.
val first = sc.accumulator(0, "acc1")(param)
rdd.foreach(n => first.add(n))
first.value shouldEqual numbers.sum
OpenAccumulatorContext.names should contain ("acc1")
// Registering a second V1 accumulator with the same param works fine.
val second = sc.accumulator(0, "acc2")(param)
rdd.foreach(n => second.add(n))
second.value shouldEqual numbers.sum
OpenAccumulatorContext.names should contain ("acc2")
}
}
"Accumulator V2" should {
"be able to be registered twice" in {
val numbers = 1 to 10
val rdd = sc.makeRDD(numbers, 5)
val acc = new LongAccumulator
sc.register(acc, "acc1")
rdd.foreach(n => acc.add(n))
acc.value shouldEqual numbers.sum
OpenAccumulatorContext.names should contain ("acc1")
// the test will fail registering the same accumulator one more time until you uncomment the next line
// OpenAccumulatorContext.reset(acc)
sc.register(acc, "acc2")
rdd.foreach(n => acc.add(n))
acc.value shouldEqual numbers.sum
OpenAccumulatorContext.names should contain ("acc2")
}
}
}
package org.apache.spark
import org.apache.spark.util.AccumulatorContext
import org.apache.spark.util.AccumulatorV2
/**
 * Test-only window into Spark's package-private [[AccumulatorContext]] registry.
 * Lives in `org.apache.spark` so it can reach `private[spark]` members.
 */
object OpenAccumulatorContext {
/**
 * Snapshot of all currently registered accumulators, keyed by accumulator id.
 *
 * NOTE(review): the upper bound comes from `AccumulatorContext.newId()`, which
 * presumably allocates (and thereby leaks) one fresh id per call — acceptable in
 * a test helper, but confirm against the Spark version in use before reusing.
 */
def accumulators: Map[Long, AccumulatorV2[_, _]] = {
val idUpperBound = AccumulatorContext.newId()
(0L until idUpperBound).flatMap { id =>
// NOTE(review): `get` can throw IllegalStateException (e.g. for an accumulator
// that has been garbage collected) — such ids are simply treated as absent.
val registered =
try AccumulatorContext.get(id)
catch {
case _: IllegalStateException => None
}
registered.map(id -> _)
}.toMap
}
// Distinct names of every registered accumulator that carries a name.
def names: Set[String] = accumulators.values.flatMap(_.name).toSet
/**
 * Unregisters the accumulator, zeroes its value, and wipes its metadata so it
 * can be registered again; returns the same (now pristine) instance.
 */
def reset(acc: AccumulatorV2[_, _]): AccumulatorV2[_, _] = {
AccumulatorContext.remove(acc.id)
acc.reset()
acc.metadata = null
acc
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment