-
-
Save szhem/52a26ada4bbeb1a3e762710adc3f94ef to your computer and use it in GitHub Desktop.
Spark :: AccumulatorV2 vs AccumulatorParam (V1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package foo.bar | |
import java.{lang => jl} | |
import org.apache.spark.OpenAccumulatorContext | |
import org.apache.spark.util.LongAccumulator | |
import org.apache.spark.SparkConf | |
import org.apache.spark.SparkContext | |
import org.apache.spark.SparkException | |
import org.apache.spark.TaskContext | |
import org.scalatest.BeforeAndAfterEach | |
import org.scalatest.Matchers | |
import org.scalatest.WordSpec | |
/**
 * Demonstrates the registration behaviour of Spark's V1 accumulators
 * (`AccumulatorParam`) versus V2 accumulators (`AccumulatorV2`): V1 values can
 * be created freely, while a V2 instance cannot be registered a second time
 * unless its registry entry and metadata are cleared first.
 */
class AccumulatorsSpec extends WordSpec with BeforeAndAfterEach with Matchers {

  // Fresh SparkContext per test case; torn down in afterEach.
  var sc: SparkContext = _

  override protected def beforeEach(): Unit = {
    super.beforeEach()
    // "local[1,2]": a single worker thread, and each task may fail up to 2
    // times before the job is aborted.
    val conf = new SparkConf()
      .setMaster("local[1,2]")
      .setAppName(this.getClass.getSimpleName)
      .set("spark.ui.enabled", "false")
      .set("spark.logLineage", "true")
    sc = new SparkContext(conf)
  }

  override protected def afterEach(): Unit = {
    sc.stop()
    super.afterEach()
  }

  "Accumulator V1" should {
    "be able to be registered twice" in {
      import org.apache.spark.AccumulatorParam._

      val numbers = 1 to 10
      val rdd = sc.makeRDD(numbers, 5)
      val param = IntAccumulatorParam

      // First V1 accumulator, registered under the name "acc1".
      val acc1 = sc.accumulator(0, "acc1")(param)
      rdd.foreach(n => acc1.add(n))
      acc1.value shouldEqual numbers.sum
      OpenAccumulatorContext.names should contain ("acc1")

      // A second V1 accumulator can be created without any cleanup in between.
      val acc2 = sc.accumulator(0, "acc2")(param)
      rdd.foreach(n => acc2.add(n))
      acc2.value shouldEqual numbers.sum
      OpenAccumulatorContext.names should contain ("acc2")
    }
  }

  "Accumulator V2" should {
    "be able to be registered twice" in {
      val numbers = 1 to 10
      val rdd = sc.makeRDD(numbers, 5)

      val acc = new LongAccumulator
      sc.register(acc, "acc1")
      rdd.foreach(n => acc.add(n))
      acc.value shouldEqual numbers.sum
      OpenAccumulatorContext.names should contain ("acc1")

      // the test will fail registering the same accumulator one more time until you uncomment the next line
      // OpenAccumulatorContext.reset(acc)
      sc.register(acc, "acc2")
      rdd.foreach(n => acc.add(n))
      acc.value shouldEqual numbers.sum
      OpenAccumulatorContext.names should contain ("acc2")
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark | |
import org.apache.spark.util.AccumulatorContext | |
import org.apache.spark.util.AccumulatorV2 | |
/**
 * Backdoor into Spark's internal accumulator registry.
 *
 * NOTE(review): this object appears to live in the `org.apache.spark` package
 * so it can reach `private[spark]` members such as `acc.metadata` — confirm
 * against the Spark version in use.
 */
object OpenAccumulatorContext {

  /**
   * Snapshot of every accumulator still tracked by [[AccumulatorContext]],
   * keyed by accumulator id.
   *
   * The registry does not expose its id counter, so the current upper bound is
   * obtained by calling `newId()` — which burns one id per invocation of this
   * method (harmless for tests, but a side effect nonetheless).
   */
  def accumulators: Map[Long, AccumulatorV2[_, _]] = {
    val upperBound = AccumulatorContext.newId()
    (0L until upperBound).flatMap { id =>
      val lookup =
        try {
          AccumulatorContext.get(id)
        } catch {
          // Thrown when the accumulator has already been garbage collected.
          case _: IllegalStateException => None
        }
      lookup.map(id -> _)
    }.toMap
  }

  /** Names of all currently registered accumulators that carry a name. */
  def names: Set[String] = accumulators.values.flatMap(_.name).toSet

  /**
   * Fully de-registers `acc` so it can be registered with a SparkContext
   * again: removes its registry entry, zeroes its value and clears its
   * metadata. Returns the same instance for chaining.
   */
  def reset(acc: AccumulatorV2[_, _]): AccumulatorV2[_, _] = {
    AccumulatorContext.remove(acc.id)
    acc.reset()
    acc.metadata = null
    acc
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment