@tzachz
Last active January 26, 2023 04:31
Apache Spark UserDefinedAggregateFunction combining maps
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Row, SQLContext}
/**
  * UDAF combining maps, merging the values of duplicate keys with the given function
  * @param keyType DataType of the map key
  * @param valueType DataType of the map value
  * @param merge function used to merge values of identical keys
  * @tparam K key type
  * @tparam V value type
  */
class CombineMaps[K, V](keyType: DataType, valueType: DataType, merge: (V, V) => V) extends UserDefinedAggregateFunction {
  // the single input column is a map of the same type the UDAF produces
  override def inputSchema: StructType = new StructType().add("map", dataType)
  // the intermediate buffer has the same shape as the input
  override def bufferSchema: StructType = inputSchema
  override def dataType: DataType = MapType(keyType, valueType)
  override def deterministic: Boolean = true

  // start each buffer from an empty map
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, Map[K, V]())

  // fold the incoming map into the buffer, merging the values of keys present in both
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val map1 = buffer.getAs[Map[K, V]](0)
    val map2 = input.getAs[Map[K, V]](0)
    val result = map1 ++ map2.map { case (k, v) => k -> map1.get(k).map(merge(v, _)).getOrElse(v) }
    buffer.update(0, result)
  }

  // combining two partial buffers is the same operation as folding in an input row
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)
}
object Example {
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.functions._
    val sc: SparkContext = new SparkContext("local", "test")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val input = sc.parallelize(Seq(
      (1, Map("John" -> 12.5, "Alice" -> 5.5)),
      (1, Map("Jim" -> 16.5, "Alice" -> 4.0)),
      (2, Map("John" -> 13.5, "Jim" -> 2.5))
    )).toDF("id", "scoreMap")

    // instantiate a CombineMaps with the relevant types:
    val combineMaps = new CombineMaps[String, Double](StringType, DoubleType, _ + _)

    // groupBy and aggregate
    val result = input.groupBy("id").agg(combineMaps(col("scoreMap")))
    result.show(truncate = false)
    // +---+--------------------------------------------+
    // |id |CombineMaps(scoreMap)                       |
    // +---+--------------------------------------------+
    // |1  |Map(John -> 12.5, Alice -> 9.5, Jim -> 16.5)|
    // |2  |Map(John -> 13.5, Jim -> 2.5)               |
    // +---+--------------------------------------------+
  }
}
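
Not part of the original gist, but as a usage note: a UserDefinedAggregateFunction can also be registered under a name and called from Spark SQL. A minimal sketch, assuming the sqlContext and input DataFrame from the example above (the function name "combine_maps" and table name "scores" are made up for illustration):

// Hypothetical SQL registration -- "combine_maps" and "scores" are illustrative names
sqlContext.udf.register("combine_maps", combineMaps)
input.registerTempTable("scores")
sqlContext.sql("SELECT id, combine_maps(scoreMap) FROM scores GROUP BY id").show(truncate = false)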
@mrbrahman

This is great! Thank you.

@fjavieralba

man, this is great!!
Thanks for sharing.

@d3r1v3d

d3r1v3d commented Apr 9, 2018

Little late to the party, but shouldn't evaluate use the generic parameter types?

override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)

@tzachz
Author

tzachz commented May 8, 2018

oops, @d3r1v3d - you're right! Thanks, fixed 👍

@bradleysmithllc

Very nice example, thank you! I have a question, though. What purpose do the input and buffer schemas serve? I can't seem to get them to do anything. I had expected inputSchema to validate that the correct columns and types were passed in, but that doesn't seem to be the case.
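
For readers with the same question: if I read the Spark API right, bufferSchema is what Spark mainly relies on internally; it describes the layout of the intermediate aggregation buffer that is serialized and shuffled between the partial and final aggregation steps. inputSchema declares the argument types the UDAF expects, and Spark uses it at analysis time to cast inputs where possible rather than to strictly reject mismatched columns, which is why it can look like a no-op. In this gist the two schemas happen to be identical because the buffer is itself a map, but they can differ. A hedged sketch (not from the original gist, using the same imports as above) of a UDAF whose buffer shape deliberately differs from its input, a simple average:

// Illustrative sketch: input is one double column, but the buffer is (sum, count)
class SketchAvg extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("value", DoubleType)
  override def bufferSchema: StructType = new StructType().add("sum", DoubleType).add("count", LongType)
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0L)
  }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
    buffer.update(1, buffer.getLong(1) + 1L)
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
}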

@dedcode

dedcode commented Jun 15, 2020

Some cells can be null, so you probably need to check for that using if (!input.isNullAt(0))
This was very helpful 👍
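
A minimal sketch of the null check @dedcode suggests, applied to the update method of CombineMaps above (an assumption about how one might guard it, not a change in the original gist):

// Sketch: replace CombineMaps.update so that null input maps are skipped
// instead of being folded into the buffer
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  if (!input.isNullAt(0)) {
    val map1 = buffer.getAs[Map[K, V]](0)
    val map2 = input.getAs[Map[K, V]](0)
    val result = map1 ++ map2.map { case (k, v) => k -> map1.get(k).map(merge(v, _)).getOrElse(v) }
    buffer.update(0, result)
  }
}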
