Last active
March 15, 2019 06:34
-
-
Save squito/2f7cc02c313e4c9e7df4 to your computer and use it in GitHub Desktop.
Accumulator Examples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.mutable.Map | |
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext} | |
import org.apache.spark.scheduler.{SparkListenerStageCompleted, SparkListener} | |
import org.apache.spark.SparkContext._ | |
/**
 * Prints the value of every accumulator that Spark reports for a completed stage.
 * Only *named* accumulators show up in the stage info, so unnamed ones are never printed.
 */
object PrintingAccumulatorListener extends SparkListener {
  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
    println("Stage accumulator values:")
    val reported = event.stageInfo.accumulables
    for ((id, accInfo) <- reported) {
      println(s"$id:${accInfo.name}:${accInfo.value}")
    }
  }
}
/**
 * A listener that can run an arbitrary callback for accumulators when a stage completes.
 *
 * Create accumulators through [[accumulator]] (instead of `sc.accumulator`) so this listener
 * knows about them. After every completed stage, each registered accumulator whose value has
 * changed since the last check gets its callback invoked with the new value.
 *
 * Spark calls `onStageCompleted` from its listener-bus thread, while [[accumulator]] is called
 * from the driver program's thread, so every access to the two maps below is guarded by `this`.
 *
 * Change detection uses `!=` against the previously seen value. It therefore only works for
 * accumulators whose value is replaced (numbers, immutable collections, ...). A value that is
 * mutated in place compares equal to itself, so its callback would never fire.
 *
 * @param sc context used to create the accumulators
 */
class FancyAccumulatorListener(sc: SparkContext) extends SparkListener {
  /** Last value observed for each registered accumulator. Access only under `this.synchronized`. */
  val accumulatorsToPrevValue: Map[Accumulator[_], Any] = Map()
  /**
   * Callback for each registered accumulator. The function's argument type matches the key's
   * element type. Access only under `this.synchronized`.
   */
  val accumulatorsToCallback: Map[Accumulator[_], _ => Unit] = Map()

  // really you would want to add in more variants of this function -- accumulable,
  // accumulableCollection, and all allowing a name as well
  /**
   * Creates an accumulator via `sc.accumulator` and registers `callback` for it.
   *
   * @param initialValue starting value; it is also the baseline for the first change check
   * @param callback invoked with the accumulator's new value after a stage in which it changed
   * @param param how values of type `T` are accumulated
   * @return the newly created accumulator
   */
  def accumulator[T](
      initialValue: T,
      callback: T => Unit
  )(implicit param: AccumulatorParam[T]): Accumulator[T] = {
    val acc = sc.accumulator(initialValue)(param)
    synchronized {
      accumulatorsToPrevValue(acc) = initialValue
      accumulatorsToCallback(acc) = callback
    }
    acc
  }

  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
    // ignore the stageInfo, b/c
    // (a) accumulableInfo only gives us the toString of the accumulator and
    // (b) we also want info on unnamed accumulators
    // loop through all the accumulators we know about, and see if they were updated.
    // if so, run their callback.
    // Note that this is assuming there aren't *too* many accumulators, and equals is cheap

    // Snapshot the changed accumulators under the lock. Taking `toList` first means we never
    // mutate the map while iterating over it, and each value is read exactly once, so the
    // value we compare is the same one we pass to the callback and record.
    val changed: List[(Accumulator[_], Any => Unit, Any)] = synchronized {
      accumulatorsToPrevValue.toList.flatMap { case (acc, prevValue) =>
        val current: Any = acc.value
        if (current != prevValue) {
          // erased at runtime; the callback was registered for this accumulator's element type
          List((acc, accumulatorsToCallback(acc).asInstanceOf[Any => Unit], current))
        } else {
          Nil
        }
      }
    }
    // Run callbacks outside the lock so a slow callback does not block registration. The new
    // value is recorded only after its callback succeeds, as the original loop did; if a
    // callback throws, that accumulator (and any later ones) is retried after the next stage.
    changed.foreach { case (acc, callback, current) =>
      callback(current)
      synchronized {
        accumulatorsToPrevValue(acc) = current
      }
    }
  }
}
/** Small demos of [[PrintingAccumulatorListener]] and [[FancyAccumulatorListener]]. */
object ExampleProgram {
  /**
   * Adds [[PrintingAccumulatorListener]] to `sc` and runs one job that updates a named and an
   * unnamed accumulator.
   *
   * NOTE(review): every call registers the listener again. Spark presumably does not
   * de-duplicate listeners, so repeated calls on one context would print each stage several
   * times. Verify before reusing.
   */
  def runSimplePrint(sc: SparkContext): Unit = {
    sc.addSparkListener(PrintingAccumulatorListener)
    val data = sc.parallelize(1 to 100)
    // uppercase `L`: a lowercase `l` suffix is easily misread as `1` (deprecated in Scala 2.13)
    val namedAcc = sc.accumulator(0L, "my accumulator")
    val unnamedAcc = sc.accumulator(0L)
    data.foreach { x =>
      namedAcc += x
      unnamedAcc += x
    }
    // after this is run, you'll see an update for the named accumulator printed out,
    // but the unnamed one is ignored by the listener
  }

  /**
   * Registers a [[FancyAccumulatorListener]] and runs two stages. Each stage updates a
   * different callback-backed accumulator, so the two callbacks fire one after the other.
   */
  def runFancy(sc: SparkContext): Unit = {
    val accRegister = new FancyAccumulatorListener(sc)
    sc.addSparkListener(accRegister)
    val data = sc.parallelize(1 to 100)
    val acc1 = accRegister.accumulator(0L, { x: Long =>
      println("accumulator #1 has been updated to value " + x)
    })
    val acc2 = accRegister.accumulator(0L, { x: Long =>
      println("Numero dos updated! now its " + x)
    })
    println("running stage 1")
    data.foreach { x =>
      acc1 += x
    }
    println("done w/ stage 1, sleeping")
    // we should see the update for acc1
    Thread.sleep(5000) // just to make it clear what is going on w/ the listener
    println("running stage 2")
    data.foreach { x =>
      acc2 += x
    }
    // now we'll see the update for acc2. Listener events are delivered asynchronously,
    // so the message may appear after this method has returned.
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Counts errors, and keeps one error record for debugging.
// WARNING: because this uses accumulators, the counting semantics are *extremely* confusing
// if the RDD ever gets recomputed, due to shared lineage, cache eviction, or stage retries.
// Use with caution.
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext} | |
/**
 * Running tally of successful and failed records, plus one failing record kept as a sample.
 *
 * Instances are created by the companion's `apply`, which wraps them in a Spark accumulator.
 *
 * @param name label for this tracker (also used as the accumulator's name by `apply`)
 * @tparam T type of the records being tracked
 */
class ErrorTracker[T] private(val name: String) extends Serializable {
  /** Number of records reported through [[error]]. */
  var errorCounts = 0L
  /** Number of records reported through [[ok]]. */
  var successCounts = 0L
  /** The most recent error reported to this instance, if any. */
  var errorSample: Option[T] = None

  /** Records one successful record. */
  def ok(): Unit = successCounts += 1

  /** Records one failed record and keeps it as the sample, replacing any earlier one. */
  def error(t: T): Unit = {
    errorCounts += 1
    errorSample = Some(t)
  }

  /**
   * Summary such as `"3 errors / 10 total (0.30). One random error: <record>"`.
   *
   * The error fraction is reported as 0.00 when nothing has been recorded yet. Previously it
   * computed 0.0 / 0 and printed NaN. Despite the wording, the sample is not chosen at random;
   * it is whatever `errorSample` currently holds.
   */
  override def toString(): String = {
    val total = errorCounts + successCounts
    val frac = if (total == 0L) 0.0 else errorCounts.toDouble / total
    f"$errorCounts%d errors / $total%d total ($frac%2.2f). " +
      errorSample.map { e => s"One random error: $e" }.getOrElse("")
  }
}
/** Factory for accumulator-backed [[ErrorTracker]]s. */
object ErrorTracker {
  /**
   * Registers a new, empty tracker as a named accumulator on `sc`.
   *
   * @param name name of both the tracker and the accumulator
   * @param sc context to register the accumulator with
   * @return the accumulator wrapping the tracker
   */
  def apply[T](name: String, sc: SparkContext): Accumulator[ErrorTracker[T]] = {
    val emptyTracker = new ErrorTracker[T](name)
    val mergeParam = new ErrorTrackerAccumulator[T]
    sc.accumulator(emptyTracker, name)(mergeParam)
  }
}
/**
 * Merge rules for [[ErrorTracker]] accumulators. Counts are summed, and the left-hand side's
 * error sample is kept unless it is empty.
 */
private class ErrorTrackerAccumulator[T] extends AccumulatorParam[ErrorTracker[T]] {
  /** Merges `r2` into `r1` in place (the AccumulatorParam contract allows this) and returns `r1`. */
  override def addInPlace(r1: ErrorTracker[T], r2: ErrorTracker[T]): ErrorTracker[T] = {
    r1.errorCounts += r2.errorCounts
    r1.successCounts += r2.successCounts
    if (r1.errorSample.isEmpty) {
      r1.errorSample = r2.errorSample
    }
    r1
  }

  /**
   * Returns a fresh, empty tracker with the same name as `initialValue`.
   *
   * This must not return `initialValue` itself. Spark uses `initialValue` as the driver-side
   * running value, and `addInPlace` above mutates it, so an aliased zero would carry
   * already-merged counts into later tasks and double count. `ErrorTracker`'s constructor is
   * private to its companion, so, like Spark's own `GrowableAccumulableParam`, we copy it
   * through a Java serialization round trip and then clear the copy.
   */
  override def zero(initialValue: ErrorTracker[T]): ErrorTracker[T] = {
    val buffer = new java.io.ByteArrayOutputStream()
    val out = new java.io.ObjectOutputStream(buffer)
    try out.writeObject(initialValue) finally out.close()

    // Resolve classes through the tracker's own loader so REPL / user-jar classes are found.
    val loader = initialValue.getClass.getClassLoader
    val in = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(buffer.toByteArray)) {
      override def resolveClass(desc: java.io.ObjectStreamClass): Class[_] =
        try Class.forName(desc.getName, false, loader)
        catch { case _: ClassNotFoundException => super.resolveClass(desc) }
    }
    val copy = try in.readObject().asInstanceOf[ErrorTracker[T]] finally in.close()

    copy.errorCounts = 0L
    copy.successCounts = 0L
    copy.errorSample = None
    copy
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// uses algebird to help define accumulators over complex types. | |
// saves you the headache of defining lots of accumulator params | |
// https://github.com/twitter/algebird | |
// WARNING -- this makes it really tempting to do things that are NOT scalable. | |
// accumulators get merged on the *driver*, so if you make big accumulators or have | |
// a lot of tasks, merging the accumulators on the driver will become a bottleneck | |
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext} | |
import com.twitter.algebird._ | |
import com.twitter.algebird.Operators._ | |
/**
 * An AccumulatorParam backed by an algebird [[Monoid]].
 *
 * Merging delegates to `monoid.plus`. The zero is always the monoid's identity, and the
 * initial value passed to `zero` is ignored.
 *
 * @param t not used by this class; call sites pass the zero value here, which lets the
 *          compiler infer `T`
 */
class MonoidAccumulator[T: Monoid](t: T) extends AccumulatorParam[T] {
  /** The monoid instance supplied through the context bound. */
  val monoid: Monoid[T] = implicitly[Monoid[T]]

  override def zero(initialValue: T): T =
    monoid.zero

  override def addInPlace(r1: T, r2: T): T =
    monoid.plus(r1, r2)
}
val data = sc.parallelize(1 to 100)

// algebird lets you easily define different ways of merging things. We'll have one
// map that sums the values ...
// (NB: algebird requires these to be immutable Maps)
val summingMapZero = Map.empty[Int, Int]
val summingMapAcc = sc.accumulator(summingMapZero)(new MonoidAccumulator(summingMapZero))

// ... and another one that takes the max of the values
val maxMapZero = Map.empty[Int, Max[Int]]
val maxMapAcc = sc.accumulator(maxMapZero)(new MonoidAccumulator(maxMapZero))

data.foreach { x =>
  val bucket = x % 10
  summingMapAcc += Map(bucket -> x)
  maxMapAcc += Map(bucket -> Max(x))
}

println("sums:")
summingMapAcc.value.foreach(println)
println("maxs:")
maxMapAcc.value.foreach(println)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Unfortunately, PrintingAccumulatorListener only works in spark-shell. It seems StageInfo does not contain user-defined accumulators in other modes.