Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Spark UDAF to calculate the most common element in a column, i.e. the statistical mode of that column. Written and tested with Spark 2.1.0.
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scalaz.Scalaz._
/**
 * Spark User Defined Aggregate Function that returns the most frequent value in a column, i.e. the statistical
 * mode of the column's values (values are compared as strings).
 *
 * Semantics:
 *  - `null` inputs are ignored, as Spark's built-in aggregates (e.g. `count(col)`) do.
 *  - If a group contains no non-null values (or the aggregate runs over zero rows), the result is `null`.
 *  - If several values share the highest frequency, the lexicographically smallest one is returned. This keeps
 *    the result independent of how rows were partitioned and merged, as required by `deterministic = true`.
 *    A strict statistical mode would report all tied values; this function always returns exactly one.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a group_by statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue extends UserDefinedAggregateFunction {

  // Input: a single column. StringType is used so the mode can also be computed for nominal (categorical) data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Aggregation buffer: a map from each distinct non-null value seen so far to its number of occurrences.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output: the most frequent value, or null when no non-null value was aggregated.
  override def dataType: DataType = StringType

  // Safe to declare: ties are broken by value ordering in `evaluate`, not by map iteration order.
  override def deterministic: Boolean = true

  // Every group starts with an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Count one occurrence of the input value. Null inputs are skipped so they never become a key
  // in the frequency map.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
    }
  }

  // Combine two partial frequency maps; the scalaz Map semigroup sums the counts of keys present in both.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

  // Return the value with the highest count. On ties the lexicographically smallest value wins:
  // minimizing (-count, value) maximizes count first, then minimizes value.
  // An empty map (zero rows, or only nulls) yields null instead of letting `maxBy`/`minBy` throw.
  override def evaluate(buffer: Row): String = {
    val frequencies = buffer.getAs[Map[String, Long]](0)
    if (frequencies.isEmpty) null
    else frequencies.minBy { case (value, count) => (-count, value) }._1
  }
}
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * Spark User Defined Aggregate Function that returns the most frequent value in a column, i.e. the statistical
 * mode of the column's values (values are compared as strings).
 *
 * Semantics:
 *  - `null` inputs are ignored, as Spark's built-in aggregates (e.g. `count(col)`) do.
 *  - If a group contains no non-null values (or the aggregate runs over zero rows), the result is `null`.
 *  - If several values share the highest frequency, the lexicographically smallest one is returned. This keeps
 *    the result independent of how rows were partitioned and merged, as required by `deterministic = true`.
 *    A strict statistical mode would report all tied values; this function always returns exactly one.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a group_by statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * This version doesn't use the ScalaZ library
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction {

  // Input: a single column. StringType is used so the mode can also be computed for nominal (categorical) data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Aggregation buffer: a map from each distinct non-null value seen so far to its number of occurrences.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output: the most frequent value, or null when no non-null value was aggregated.
  override def dataType: DataType = StringType

  // Safe to declare: ties are broken by value ordering in `evaluate`, not by map iteration order.
  override def deterministic: Boolean = true

  // Every group starts with an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Count one occurrence of the input value. Null inputs are skipped so they never become a key
  // in the frequency map.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getAs[String](0)
      val counts = buffer.getAs[Map[String, Long]](0)
      buffer(0) = counts.updated(value, counts.getOrElse(value, 0L) + 1L)
    }
  }

  // Combine two partial frequency maps. Counts of keys present in both maps are summed; the entries
  // from map2 (carrying the summed counts) override those of map1 in `++`.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = map1 ++ map2.map { case (k, v) => k -> (v + map1.getOrElse(k, 0L)) }
  }

  // Return the value with the highest count. On ties the lexicographically smallest value wins:
  // minimizing (-count, value) maximizes count first, then minimizes value.
  // An empty map (zero rows, or only nulls) yields null instead of letting `maxBy`/`minBy` throw.
  override def evaluate(buffer: Row): String = {
    val frequencies = buffer.getAs[Map[String, Long]](0)
    if (frequencies.isEmpty) null
    else frequencies.minBy { case (value, count) => (-count, value) }._1
  }
}
@samuelroberth

This comment has been minimized.

Show comment Hide comment
@samuelroberth

samuelroberth Sep 11, 2017

Is there a way to do this without Scalaz?

Is there a way to do this without Scalaz?

@HosniAkremi

This comment has been minimized.

Show comment Hide comment
@HosniAkremi

HosniAkremi Oct 7, 2017

Hello Anish,
This UDAF is amazing, and I'm actually working on a project that needs the same functionality.
My question is: is there any way to do it with Spark 1.6 and without Scalaz?

I appreciate your prompt answer :)

Hello Anish,
This UDAF is amazing, and I'm actually working on a project that needs the same functionality.
My question is: is there any way to do it with Spark 1.6 and without Scalaz?

I appreciate your prompt answer :)

@anish749

This comment has been minimized.

Show comment Hide comment
@anish749

anish749 Oct 13, 2017

I added a file with an implementation that doesn't use Scalaz. Hope that helps. I didn't test this with Spark 1.6, but I believe it will work with Spark 1.6 as well. Let me know if you face any problems.

Permalink to file: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa#file-mostcommonvalue_noscalaz-scala

Owner

anish749 commented Oct 13, 2017

I added a file with an implementation that doesn't use Scalaz. Hope that helps. I didn't test this with Spark 1.6, but I believe it will work with Spark 1.6 as well. Let me know if you face any problems.

Permalink to file: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa#file-mostcommonvalue_noscalaz-scala

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment