// Spark Accumulator Metrics
import scala.collection.mutable.Map
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
import org.apache.spark.scheduler.{SparkListenerStageCompleted, SparkListener}
import org.apache.spark.SparkContext._
/**
 * Just prints out the values for all accumulators from the stage.
 * Note: you will only get updates for *named* accumulators this way.
 */
object PrintingAccumulatorListener extends SparkListener {
  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
    println("Stage accumulator values:")
    event.stageInfo.accumulables.foreach { case (id, accInfo) =>
      println(s"$id:${accInfo.name}:${accInfo.value}")
    }
  }
}
/**
 * Runs an arbitrary callback for each registered accumulator when a stage completes.
 */
class FancyAccumulatorListener(sc: SparkContext) extends SparkListener {
  val accumulatorsToPrevValue: Map[Accumulator[_], Any] = Map()
  val accumulatorsToCallback: Map[Accumulator[_], _ => Unit] = Map()

  // really you would want to add in more variants of this function -- accumulable,
  // accumulableCollection, and all allowing a name as well (one named variant is sketched below)
  def accumulator[T](
      initialValue: T,
      callback: T => Unit
  )(implicit param: AccumulatorParam[T]): Accumulator[T] = {
    val acc = sc.accumulator(initialValue)(param)
    accumulatorsToPrevValue(acc) = initialValue
    accumulatorsToCallback(acc) = callback
    acc
  }
  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
    // ignore the stageInfo, b/c
    // (a) accumulableInfo only gives us the toString of the accumulator and
    // (b) we also want info on unnamed accumulators
    // Loop through all the accumulators we know about, and see if they were updated.
    // If so, run their callback.
    // Note that this assumes there aren't *too* many accumulators, and equals is cheap.
    accumulatorsToPrevValue.foreach { case (acc, prevValue) =>
      if (acc.value != prevValue) {
        accumulatorsToCallback(acc).asInstanceOf[Any => Unit](acc.value)
        accumulatorsToPrevValue(acc) = acc.value
      }
    }
  }
}
object ExampleProgram {

  def runSimplePrint(sc: SparkContext): Unit = {
    sc.addSparkListener(PrintingAccumulatorListener)
    val data = sc.parallelize(1 to 100)
    val namedAcc = sc.accumulator(0L, "my accumulator")
    val unnamedAcc = sc.accumulator(0L)
    data.foreach { x =>
      namedAcc += x
      unnamedAcc += x
    }
    // After this runs, you'll see an update printed for the named accumulator,
    // but the unnamed one is ignored by the listener.
  }

  def runFancy(sc: SparkContext): Unit = {
    val accRegister = new FancyAccumulatorListener(sc)
    sc.addSparkListener(accRegister)
    val data = sc.parallelize(1 to 100)
    val acc1 = accRegister.accumulator(0L, { x: Long =>
      println("accumulator #1 has been updated to value " + x)
    })
    val acc2 = accRegister.accumulator(0L, { x: Long =>
      println("Numero dos updated! now it's " + x)
    })
    println("running stage 1")
    data.foreach { x =>
      acc1 += x
    }
    println("done w/ stage 1, sleeping")
    // we should see the update for acc1
    Thread.sleep(5000) // just to make it clear what is going on w/ the listener
    println("running stage 2")
    data.foreach { x =>
      acc2 += x
    }
    // now we'll see the update for acc2
  }
}
// Counts errors, and keeps one error record for debugging.
// WARNING: b/c this uses accumulators, the semantics around counting are *extremely* confusing
// if the RDD ever gets recomputed, due to shared lineage, cache eviction, or stage retries.
// Use with caution.
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}

class ErrorTracker[T] private(val name: String) extends Serializable {
  var errorCounts = 0L
  var successCounts = 0L
  var errorSample: Option[T] = None

  def ok(): Unit = successCounts += 1

  def error(t: T): Unit = {
    errorCounts += 1
    errorSample = Some(t)
  }

  override def toString(): String = {
    val total = errorCounts + successCounts
    val frac = errorCounts.toDouble / total
    f"$errorCounts%d errors / $total%d total ($frac%2.2f). " +
      s"${errorSample.map { e => s"One random error: $e" }.getOrElse("")}"
  }
}

object ErrorTracker {
  def apply[T](name: String, sc: SparkContext): Accumulator[ErrorTracker[T]] = {
    sc.accumulator(new ErrorTracker[T](name), name)(new ErrorTrackerAccumulator[T])
  }
}

private class ErrorTrackerAccumulator[T] extends AccumulatorParam[ErrorTracker[T]] {
  override def addInPlace(r1: ErrorTracker[T], r2: ErrorTracker[T]): ErrorTracker[T] = {
    r1.errorCounts += r2.errorCounts
    r1.successCounts += r2.successCounts
    if (r1.errorSample.isEmpty) {
      r1.errorSample = r2.errorSample
    }
    r1
  }
  override def zero(initialValue: ErrorTracker[T]): ErrorTracker[T] = initialValue
}
// Uses algebird to help define accumulators over complex types.
// Saves you the headache of defining lots of accumulator params.
// https://github.com/twitter/algebird
// WARNING -- this makes it really tempting to do things that are NOT scalable.
// Accumulators get merged on the *driver*, so if you make big accumulators or have
// a lot of tasks, merging the accumulators on the driver will become a bottleneck.
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
import com.twitter.algebird._
import com.twitter.algebird.Operators._

class MonoidAccumulator[T: Monoid](t: T) extends AccumulatorParam[T] {
  val monoid: Monoid[T] = implicitly[Monoid[T]]
  def addInPlace(r1: T, r2: T): T = monoid.plus(r1, r2)
  def zero(initialValue: T): T = monoid.zero
}
// The snippet below assumes an existing SparkContext named `sc` (e.g. in the spark-shell).
val data = sc.parallelize(1 to 100)
// algebird lets you easily define different ways of merging things. We'll have one
// map that sums the values ...
// (NB: algebird requires these to be immutable Maps)
val summingMapZero = Map[Int, Int]()
val summingMapAcc = sc.accumulator(summingMapZero)(new MonoidAccumulator(summingMapZero))
// ... and another one that takes the max of the values
val maxMapZero = Map[Int, Max[Int]]()
val maxMapAcc = sc.accumulator(maxMapZero)(new MonoidAccumulator(maxMapZero))
data.foreach { x =>
  summingMapAcc += Map((x % 10) -> x)
  maxMapAcc += Map((x % 10) -> Max(x))
}
println("sums:")
summingMapAcc.value.foreach(println)
println("maxs:")
maxMapAcc.value.foreach(println)
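// Another hedged sketch (not in the original gist): algebird also provides monoids for
// tuples of numeric types, so a (count, sum) pair accumulator comes for free and a mean
// can be computed on the driver afterwards. Names here are illustrative.
val countSumZero = (0L, 0L)
val countSumAcc = sc.accumulator(countSumZero)(new MonoidAccumulator(countSumZero))
data.foreach { x =>
  countSumAcc += ((1L, x.toLong)) // adds element-wise: count += 1, sum += x
}
val (count, sum) = countSumAcc.value
println(s"mean: ${sum.toDouble / count}")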