Skip to content

Instantly share code, notes, and snippets.

@squito
Last active March 15, 2019 06:34
Show Gist options
  • Save squito/2f7cc02c313e4c9e7df4 to your computer and use it in GitHub Desktop.
Save squito/2f7cc02c313e4c9e7df4 to your computer and use it in GitHub Desktop.
Accumulator Examples
import scala.collection.mutable.Map
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
import org.apache.spark.scheduler.{SparkListenerStageCompleted, SparkListener}
import org.apache.spark.SparkContext._
/**
 * Listener that dumps the accumulator values reported for each completed stage.
 *
 * Only *named* accumulators are reported through the stage info, so unnamed
 * accumulators never show up in this output.
 */
object PrintingAccumulatorListener extends SparkListener {

  /** Prints one `id:name:value` line per accumulator attached to the finished stage. */
  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
    println("Stage accumulator values:")
    val reported = event.stageInfo.accumulables
    reported.foreach { entry =>
      val (id, accInfo) = entry
      println(s"$id:${accInfo.name}:${accInfo.value}")
    }
  }
}
/**
 * Runs an arbitrary callback for an accumulator when a stage completes and the
 * accumulator's value differs from the value seen at the previous check.
 *
 * Accumulators have to be created through [[accumulator]] so that the listener
 * knows about them and about the callback to invoke.
 *
 * @param sc the context used to create the tracked accumulators
 */
class FancyAccumulatorListener(sc: SparkContext) extends SparkListener {

  /** Last value observed (or the initial value) for every tracked accumulator. */
  val accumulatorsToPrevValue: Map[Accumulator[_], Any] = Map()

  /** Callback registered for every tracked accumulator. */
  val accumulatorsToCallback: Map[Accumulator[_], _ => Unit] = Map()

  // really you would want to add in more variants of this function -- accumulable,
  // accumulableCollection, and all allowing a name as well
  /**
   * Creates an accumulator and registers `callback` to be run with the new value
   * whenever a completed stage has changed it.
   *
   * @param initialValue starting value of the accumulator
   * @param callback     invoked with the accumulator's current value after a change
   * @param param        defines how values of type `T` are merged
   * @return the newly created, tracked accumulator
   */
  def accumulator[T](
      initialValue: T,
      callback: T => Unit
  )(implicit param: AccumulatorParam[T]): Accumulator[T] = {
    val acc = sc.accumulator(initialValue)(param)
    accumulatorsToPrevValue(acc) = initialValue
    accumulatorsToCallback(acc) = callback
    acc
  }

  /**
   * Checks every tracked accumulator and fires its callback if the value changed.
   *
   * The stage info in `event` is deliberately ignored, b/c
   * (a) accumulableInfo only gives us the toString of the accumulator and
   * (b) we also want info on unnamed accumulators.
   *
   * Note that this assumes there aren't *too* many accumulators, and that equals is cheap.
   */
  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
    // Iterate over a snapshot: the loop body writes back into the same mutable map.
    accumulatorsToPrevValue.toList.foreach { case (acc, prevValue) =>
      // Read the value exactly once. The accumulator may be merged into while this
      // handler runs; comparing, reporting and storing three separate reads could
      // record a value that was never passed to the callback.
      val current = acc.value
      if (current != prevValue) {
        // The cast is safe: `accumulator` stored a `T => Unit` next to an `Accumulator[T]`.
        accumulatorsToCallback(acc).asInstanceOf[Any => Unit](current)
        accumulatorsToPrevValue(acc) = current
      }
    }
  }
}
/** Small driver programs demonstrating the two accumulator listeners. */
object ExampleProgram {

  /**
   * Sums 1 to 100 into one named and one unnamed accumulator while
   * `PrintingAccumulatorListener` is attached.
   *
   * After the job runs, an update for the named accumulator is printed; the
   * unnamed one is ignored by the listener.
   */
  def runSimplePrint(sc: SparkContext): Unit = {
    sc.addSparkListener(PrintingAccumulatorListener)
    val data = sc.parallelize(1 to 100)
    val namedAcc = sc.accumulator(0L, "my accumulator")
    val unnamedAcc = sc.accumulator(0L)
    data.foreach { x =>
      namedAcc += x
      unnamedAcc += x
    }
  }

  /**
   * Registers two callback-backed accumulators and updates each one in its own
   * stage, pausing in between so the listener output is easy to follow.
   */
  def runFancy(sc: SparkContext): Unit = {
    val accRegister = new FancyAccumulatorListener(sc)
    sc.addSparkListener(accRegister)
    val data = sc.parallelize(1 to 100)

    val reportFirst: Long => Unit =
      x => println("accumulator #1 has been updated to value " + x)
    val reportSecond: Long => Unit =
      x => println("Numero dos updated! now its " + x)
    val acc1 = accRegister.accumulator(0L, reportFirst)
    val acc2 = accRegister.accumulator(0L, reportSecond)

    println("running stage 1")
    data.foreach(x => acc1 += x)
    println("done w/ stage 1, sleeping")
    // we should see the update for acc1
    Thread.sleep(5000) // just to make it clear what is going on w/ the listener

    println("running stage 2")
    data.foreach(x => acc2 += x)
    // now we'll see the update for acc2
  }
}
// counts errors, and keep one error record for debugging
// WARNING: b/c this uses accumulators, the semantics around counting are *extremely* confusing
// if the RDD ever gets recomputed, due to shared lineage, cache eviction, or stage retries.
// use with caution.
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
/**
 * Mutable tally of successes and errors that also keeps one error record for debugging.
 *
 * The constructor is private; instances are created through the companion object.
 *
 * @param name label for this tracker
 * @tparam T type of the record kept as the error sample
 */
class ErrorTracker[T] private(val name: String) extends Serializable {
  var errorCounts = 0L
  var successCounts = 0L
  // most recently recorded error, if any
  var errorSample: Option[T] = None

  /** Records one successful item. */
  def ok(): Unit = successCounts += 1

  /** Records one failed item and keeps `t` as the sample error (replacing any earlier one). */
  def error(t: T): Unit = {
    errorCounts += 1
    errorSample = Some(t)
  }

  /**
   * Human-readable summary: error count, total count, error fraction and, when
   * available, the sample error.
   */
  override def toString(): String = {
    val total = errorCounts + successCounts
    // Guard the empty tracker: 0.0 / 0 is NaN, which would print as "(NaN)".
    val frac = if (total == 0L) 0.0 else errorCounts.toDouble / total
    val sample = errorSample.fold("")(e => s"One random error: $e")
    f"$errorCounts%d errors / $total%d total ($frac%2.2f). " + sample
  }
}
object ErrorTracker {

  /**
   * Creates a named accumulator holding a fresh, empty [[ErrorTracker]].
   *
   * @param name used both as the tracker's name and as the accumulator's name
   * @param sc   context that owns the accumulator
   */
  def apply[T](name: String, sc: SparkContext): Accumulator[ErrorTracker[T]] = {
    sc.accumulator(new ErrorTracker[T](name), name)(new ErrorTrackerAccumulator[T])
  }

  /**
   * Builds a new tracker with the same name as `template` and nothing recorded.
   * Lives here because the tracker's constructor is private to the class and
   * this companion.
   */
  def emptyLike[T](template: ErrorTracker[T]): ErrorTracker[T] =
    new ErrorTracker[T](template.name)
}

/** Merge rules that let an [[ErrorTracker]] be used as an accumulator value. */
private class ErrorTrackerAccumulator[T] extends AccumulatorParam[ErrorTracker[T]] {

  /**
   * Folds `r2` into `r1` in place and returns `r1`: counts are summed, and
   * `r1`'s error sample is kept unless it has none.
   */
  override def addInPlace(r1: ErrorTracker[T], r2: ErrorTracker[T]): ErrorTracker[T] = {
    r1.errorCounts += r2.errorCounts
    r1.successCounts += r2.successCounts
    if (r1.errorSample.isEmpty) {
      r1.errorSample = r2.errorSample
    }
    r1
  }

  /**
   * Returns a separate, empty tracker rather than `initialValue` itself.
   *
   * `addInPlace` mutates its first argument, so handing back `initialValue` would
   * make the "zero" shipped to tasks the very object that accumulates the merged
   * totals; later tasks would then start from those totals and count them again.
   */
  override def zero(initialValue: ErrorTracker[T]): ErrorTracker[T] =
    ErrorTracker.emptyLike(initialValue)
}
// uses algebird to help define accumulators over complex types.
// saves you the headache of defining lots of accumulator params
// https://github.com/twitter/algebird
// WARNING -- this makes it really tempting to do things that are NOT scalable.
// accumulators get merged on the *driver*, so if you make big accumulators or have
// a lot of tasks, merging the accumulators on the driver will become a bottleneck
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
import com.twitter.algebird._
import com.twitter.algebird.Operators._
/**
 * Accumulator param whose merge and zero come from the implicit algebird `Monoid` for `T`.
 *
 * @param t not read by this class; passing a value lets `T` be inferred at the call site
 */
class MonoidAccumulator[T: Monoid](t: T) extends AccumulatorParam[T] {

  /** The monoid instance resolved for `T`. */
  val monoid: Monoid[T] = implicitly[Monoid[T]]

  /** Combines two partial results with the monoid's `plus`. */
  override def addInPlace(r1: T, r2: T): T = {
    monoid.plus(r1, r2)
  }

  /** Ignores `initialValue` and returns the monoid's own zero. */
  override def zero(initialValue: T): T = {
    monoid.zero
  }
}
// Example session (expects a SparkContext named `sc` in scope).
val data = sc.parallelize(1 to 100)

// algebird lets you easily define different ways of merging things. The first
// accumulator is a map whose values are summed per key ...
// (NB: algebird requires these to be immutable Maps)
val summingMapZero = Map.empty[Int, Int]
val summingMapAcc = sc.accumulator(summingMapZero)(new MonoidAccumulator(summingMapZero))

// ... and the second one keeps the max of the values per key
val maxMapZero = Map.empty[Int, Max[Int]]
val maxMapAcc = sc.accumulator(maxMapZero)(new MonoidAccumulator(maxMapZero))

// key every number by its last digit and feed it to both accumulators
data.foreach { x =>
  val bucket = x % 10
  summingMapAcc += Map(bucket -> x)
  maxMapAcc += Map(bucket -> Max(x))
}

println("sums:")
summingMapAcc.value.foreach(entry => println(entry))
println("maxs:")
maxMapAcc.value.foreach(entry => println(entry))
@mmartsen
Copy link

Unfortunately PrintingAccumulatorListener will work in spark-shell only. Seems StageInfo does not contain user defined accumulators in other modes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment