import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
/**
 * aggregate_by_key_example demonstrates how to use aggregateByKey to sum up
 * values by key and also keep the values themselves.
 *
 * Given the input:
 *
 *   Seq(("001", 1), ("001", 2), ("001", 3), ("002", 0), ("002", 7))
 *
 * We want the output:
 *
 *   (001,1,6)
 *   (001,2,6)
 *   (001,3,6)
 *   (002,0,7)
 *   (002,7,7)
 *
 * As asked on SO: http://stackoverflow.com/questions/36455419/spark-reducebykey-and-keep-other-columns
 *
 * Why use aggregateByKey instead of groupByKey? As the Spark API docs note
 * (http://spark.apache.org/docs/1.6.1/api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions):
 *
 * > Note: This operation may be very expensive. If you are grouping in order
 * > to perform an aggregation (such as a sum or average) over each key,
 * > using PairRDDFunctions.aggregateByKey or PairRDDFunctions.reduceByKey
 * > will provide much better performance.
 *
 * Of course we are not doing a simple aggregation such as a sum here, so the
 * performance benefit over groupByKey may not materialise; benchmarking both
 * approaches on real data is required. For comparison, a groupByKey-based
 * sketch follows this function.
 */
def aggregate_by_key_example(sc: SparkContext) = {
  // The input as given by the OP here: http://stackoverflow.com/questions/36455419/spark-reducebykey-and-keep-other-columns
  val table = sc.parallelize(Seq(("001", 1), ("001", 2), ("001", 3), ("002", 0), ("002", 7)))

  // zero is the initial value into which we will aggregate things.
  // The first element is the list of values which contributed to the sum.
  // The second element is the sum itself.
  val zero = (List.empty[Int], 0)

  // sequencer receives an accumulator and a value.
  // The accumulator starts from 'zero' for each key (within each partition).
  // Here we add the value to the sum and prepend it to the list because we
  // want to keep both.
  // This can be thought of as the "map" stage in classic map/reduce.
  def sequencer(acc: (List[Int], Int), value: Int) = {
    val (values, sum) = acc
    (value :: values, sum + value)
  }

  // combiner merges two partial (list, sum) accumulators into one.
  // The sequencer may run in different partitions and thus produce partial
  // results; this step combines those partials into the final result.
  // This can be thought of as the "reduce" stage in classic map/reduce.
  def combiner(left: (List[Int], Int), right: (List[Int], Int)) = {
    (left._1 ++ right._1, left._2 + right._2)
  }

  // Wiring it all together.
  // Note the type of the result: each key gets the list of values which
  // contributed to the sum, and the sum itself.
  val result: RDD[(String, (List[Int], Int))] = table.aggregateByKey(zero)(sequencer, combiner)

  // To turn this into a flat list, use flatMap to produce (key, value, sum) triples.
  val flatResult: RDD[(String, Int, Int)] = result.flatMap {
    case (key, (values, sum)) =>
      for (value <- values) yield (key, value, sum)
  }

  // Collect and print.
  flatResult.collect().foreach(println)
}
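
// For comparison with the groupByKey approach discussed in the doc comment
// above, here is a minimal sketch (not from the original gist; the function
// name group_by_key_example is illustrative) that produces the same
// (key, value, sum) triples by grouping all values per key and summing them.
// groupByKey materialises every value for a key before the sum can be taken,
// which is the cost the quoted API docs warn about.
def group_by_key_example(sc: SparkContext) = {
  val table = sc.parallelize(Seq(("001", 1), ("001", 2), ("001", 3), ("002", 0), ("002", 7)))

  val flatResult: RDD[(String, Int, Int)] = table.groupByKey().flatMap {
    case (key, values) =>
      val sum = values.sum
      for (value <- values) yield (key, value, sum)
  }

  flatResult.collect().foreach(println)
}
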
// In spark-shell, run: aggregate_by_key_example(sc)
//
// This should produce:
// (001,1,6)
// (001,2,6)
// (001,3,6)
// (002,0,7)
// (002,7,7)
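
// The quoted docs mention reduceByKey as another alternative. A sketch of the
// same result via reduceByKey (again not from the original gist; the function
// name reduce_by_key_example is illustrative): pre-map each value into a
// single-element (List(value), value) pair, then merge pairs per key much
// like the combiner above.
def reduce_by_key_example(sc: SparkContext) = {
  val table = sc.parallelize(Seq(("001", 1), ("001", 2), ("001", 3), ("002", 0), ("002", 7)))

  val result: RDD[(String, (List[Int], Int))] = table
    .mapValues(value => (List(value), value))
    .reduceByKey { case ((leftValues, leftSum), (rightValues, rightSum)) =>
      (leftValues ++ rightValues, leftSum + rightSum)
    }

  val flatResult: RDD[(String, Int, Int)] = result.flatMap {
    case (key, (values, sum)) =>
      for (value <- values) yield (key, value, sum)
  }

  flatResult.collect().foreach(println)
}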