Spark UDAF to calculate the most common element in a column, i.e. the statistical mode of a given column. Written and tested in Spark 2.1.0.

MostCommonValue.scala
package org.anish.spark.mostcommonvalue

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scalaz.Scalaz._

/**
 * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to
 * the statistical mode. When two values are tied for the highest frequency, this function returns an arbitrary
 * one of them, whereas the statistical mode would report both values together.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a GROUP BY statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue extends UserDefinedAggregateFunction {

  // Input schema of the aggregate function.
  // We use StringType, because the mode can also be meaningfully applied to nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Internal fields kept while computing the aggregate.
  // We store the frequency of every distinct element encountered for the given attribute in this map.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output type of the aggregation function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // Initial value of the aggregation buffer: an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Update the buffer with one input row. Scalaz's |+| merges the two maps,
  // summing the counts of keys present in both.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
  }

  // Merge two partial aggregation buffers by summing their frequency maps.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

  // Output the final value: the key with the highest frequency.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
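
The only non-obvious piece above is Scalaz's |+| operator, the semigroup append that update and merge rely on. A minimal sketch of its behaviour on frequency maps, assuming Scalaz 7.x on the classpath:

import scalaz.Scalaz._

// For Map[String, Long], |+| unions the key sets and sums the counts of keys
// that appear on both sides, which is exactly the merge a frequency map needs.
val merged = Map("a" -> 2L, "b" -> 1L) |+| Map("a" -> 1L, "c" -> 5L)
// merged == Map("a" -> 3L, "b" -> 1L, "c" -> 5L)
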
MostCommonValue_NoScalaz.scala

package org.anish.spark.mostcommonvalue

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to
 * the statistical mode. When two values are tied for the highest frequency, this function returns an arbitrary
 * one of them, whereas the statistical mode would report both values together.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a GROUP BY statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * This version does not depend on the Scalaz library.
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction {

  // Input schema of the aggregate function.
  // We use StringType, because the mode can also be meaningfully applied to nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Internal fields kept while computing the aggregate.
  // We store the frequency of every distinct element encountered for the given attribute in this map.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output type of the aggregation function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // Initial value of the aggregation buffer: an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Update the buffer with one input row, incrementing the count of the value just seen.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val inpString = input.getAs[String](0)
    val existingMap = buffer.getAs[Map[String, Long]](0)
    buffer(0) = existingMap + (inpString -> (existingMap.getOrElse(inpString, 0L) + 1L))
  }

  // Merge two partial aggregation buffers by summing the counts of keys present in both maps.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = map1 ++ map2.map { case (k, v) => k -> (v + map1.getOrElse(k, 0L)) }
  }

  // Output the final value: the key with the highest frequency.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
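
For completeness, a minimal end-to-end driver is sketched below. The SparkSession setup, the sample data, and the names (MostCommonValueExample, group_id, city, cities, most_common_city) are illustrative assumptions, not part of the gist; the Scalaz-based MostCommonValue behaves the same.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Hypothetical driver exercising the UDAF through both the DataFrame DSL
// and Spark SQL.
object MostCommonValueExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("most-common-value-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("g1", "red"), ("g1", "blue"), ("g1", "red"),
      ("g2", "green"), ("g2", "green"), ("g2", "blue")
    ).toDF("group_id", "city")

    // DataFrame / Dataset DSL
    val mostCommonValue = new MostCommonValue_NoScalaz
    df.groupBy("group_id")
      .agg(mostCommonValue(col("city")).as("most_common_city"))
      .show()

    // Spark SQL
    spark.udf.register("mode", new MostCommonValue_NoScalaz)
    df.createOrReplaceTempView("cities")
    spark.sql("select group_id, mode(city) from cities group by group_id").show()

    spark.stop()
  }
}
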
@samuelroberth commented Sep 11, 2017

Is there a way to do this without Scalaz?

@HosniAkremi commented Oct 7, 2017

Hello Anish,
This UDAF is amazing, and I'm actually working on a project that needs the same implementation.
My question: is there any way to do it with Spark 1.6 and without Scalaz?

I appreciate your prompt answer :)

@anish749 (owner) commented Oct 13, 2017

I added a file with an implementation that doesn't use Scalaz. Hope that helps. I didn't test it with Spark 1.6, but I believe it will work with Spark 1.6 as well. Let me know if you face any problems.

Permalink to file: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa#file-mostcommonvalue_noscalaz-scala

@siddhartha-chandra commented May 23, 2019

Good stuff, Anish! Wish I had stumbled on this earlier; it would have saved me quite some time writing my own UDAF. Did you try extending it to make it work for multiple types?
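
One hypothetical way to do that is to parameterise the UDAF by its input DataType, so the same frequency-map logic works for any atomic column type. The sketch below is untested, and the class name MostCommonValueOf is made up for illustration; it is not part of the original gist:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical generalisation: the input/output type is supplied by the
// caller, e.g. new MostCommonValueOf(IntegerType). Untested sketch.
class MostCommonValueOf(valueType: DataType) extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", valueType) :: Nil)
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(valueType, LongType)) :: Nil)
  override def dataType: DataType = valueType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Map[Any, Long]()
  // Count values as untyped keys; Catalyst converts them per valueType.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val v = input.get(0)
    val m = buffer.getAs[Map[Any, Long]](0)
    buffer(0) = m + (v -> (m.getOrElse(v, 0L) + 1L))
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val m1 = buffer1.getAs[Map[Any, Long]](0)
    val m2 = buffer2.getAs[Map[Any, Long]](0)
    buffer1(0) = m1 ++ m2.map { case (k, v) => k -> (v + m1.getOrElse(k, 0L)) }
  }
  // Returns the most frequent value, in the external Scala type matching valueType.
  override def evaluate(buffer: Row): Any =
    buffer.getAs[Map[Any, Long]](0).maxBy(_._2)._1
}
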
