Last active February 3, 2023 11:50
Spark UDAF to calculate the most common element in a column, i.e. the statistical mode of that column. Written and tested in Spark 2.1.0.
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scalaz.Scalaz._
/**
  * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to the
  * statistical mode. When two or more values are tied for the highest frequency, this function returns an arbitrary
  * one of them, whereas the statistical mode would consist of all the tied values together.
  *
  * Usage:
  *
  * DataFrame / DataSet DSL
  * val mostCommonValue = new MostCommonValue
  * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
  *
  * Spark SQL:
  * sqlContext.udf.register("mode", new MostCommonValue)
  * %sql
  * -- Use a group by clause and call the UDAF.
  * select group_id, mode(id) from table group by group_id
  *
  * Reference:
  *
  * Created by anish on 26/05/17.
  */
class MostCommonValue extends UserDefinedAggregateFunction {

  // Input fields of the aggregate function.
  // We use StringType because the mode can also be meaningfully applied to nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Internal fields kept while computing the aggregate.
  // The map stores the frequency of every distinct element seen so far for the given attribute.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output type of the aggregate function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // Initial value of the buffer: an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Updates the buffer with one input row by adding 1 to the count of its value.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
  }

  // Merges two buffers by summing the counts of each key.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

  // Produces the final output: the key with the highest count in the final buffer.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
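For readers unfamiliar with scalaz: the `|+|` operator used in `update` and `merge` above combines two maps by summing the values of the keys they share. A minimal plain-Scala sketch of that behaviour for `Map[String, Long]` could look like the following. The names `FrequencyMaps` and `combine` are made up for illustration and are not part of scalaz or Spark.

```scala
object FrequencyMaps {
  // Combines two frequency maps by adding the counts of keys present in both,
  // which mirrors what scalaz's `|+|` does for Map[String, Long].
  def combine(a: Map[String, Long], b: Map[String, Long]): Map[String, Long] =
    b.foldLeft(a) { case (acc, (key, count)) =>
      acc.updated(key, acc.getOrElse(key, 0L) + count)
    }
}

object FrequencyMapsDemo extends App {
  val merged = FrequencyMaps.combine(Map("a" -> 2L, "b" -> 1L), Map("a" -> 1L, "c" -> 4L))
  println(merged) // contains a -> 3, b -> 1, c -> 4
}
```

Keys that appear in only one of the two maps are carried over unchanged, so merging with an empty map is a no-op.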
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
  * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to the
  * statistical mode. When two or more values are tied for the highest frequency, this function returns an arbitrary
  * one of them, whereas the statistical mode would consist of all the tied values together.
  *
  * Usage:
  *
  * DataFrame / DataSet DSL
  * val mostCommonValue = new MostCommonValue
  * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
  *
  * Spark SQL:
  * sqlContext.udf.register("mode", new MostCommonValue)
  * %sql
  * -- Use a group by clause and call the UDAF.
  * select group_id, mode(id) from table group by group_id
  *
  * Reference:
  *
  * This version doesn't use the ScalaZ library.
  *
  * Created by anish on 26/05/17.
  */
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction {

  // Input fields of the aggregate function.
  // We use StringType because the mode can also be meaningfully applied to nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Internal fields kept while computing the aggregate.
  // The map stores the frequency of every distinct element seen so far for the given attribute.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output type of the aggregate function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // Initial value of the buffer: an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Updates the buffer with one input row by adding 1 to the count of its value.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val inpString = input.getAs[String](0)
    val existingMap = buffer.getAs[Map[String, Long]](0)
    buffer(0) = existingMap + (inpString -> (existingMap.getOrElse(inpString, 0L) + 1L))
  }

  // Merges two buffers by adding the counts from buffer2 onto those already in buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = map1 ++ map2.map { case (k, v) => k -> (v + map1.getOrElse(k, 0L)) }
  }

  // Produces the final output: the key with the highest count in the final buffer.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
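As the doc comment notes, when several values share the highest count, `maxBy` returns just one of them, and which one depends on the iteration order of the map. If that matters for your use case, the final step could either return every tied value or break ties deterministically. Below is a plain-Scala sketch of both options, independent of Spark. `ModeUtil`, `allModes` and `smallestMode` are hypothetical names, and picking the lexicographically smallest value is only one possible tie-breaking rule.

```scala
object ModeUtil {
  // Returns every value that shares the highest count (the full statistical mode).
  // An empty frequency map yields an empty set instead of throwing, unlike maxBy.
  def allModes(freq: Map[String, Long]): Set[String] =
    if (freq.isEmpty) Set.empty[String]
    else {
      val top = freq.values.max
      freq.collect { case (value, count) if count == top => value }.toSet
    }

  // Deterministic tie-break (an assumption of this sketch): among the tied
  // values, pick the lexicographically smallest one.
  def smallestMode(freq: Map[String, Long]): Option[String] =
    allModes(freq).toSeq.sorted.headOption
}
```

Returning an `Option` also makes the empty-map case explicit, which `maxBy` would otherwise turn into an exception.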

ronenlh commented Aug 3, 2022

Thanks for the feedback. I'm not familiar with Scala and only wrote it this way in lieu of a PySpark API for UDAF.
This is the deprecation warning I got from your original implementation (in case others look it up):

warning: class UserDefinedAggregateFunction in package expressions is deprecated (since 3.0.0): Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction

I had to wrap Map[String, Long] in an object for Encoders.product to work.
I filtered empty string as I read that it was the safest way to filter out null. I had an issue with maxBy on null in some cases.
I edited the code to rename "Mode" into something else and removed the filters, leaving out the minimal implementation.
I'm not familiar with generics in Scala.
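For anyone following the deprecation message above, here is a rough sketch of what the `Aggregator` route with `functions.udaf` can look like on Spark 3.x. It is not the code from this thread: `ValueCounts`, `ModeAggregator` and `ModeUdafExample` are hypothetical names, the sample data is invented, and skipping null inputs is an assumption of this sketch rather than a requirement.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Buffer type, wrapped in a case class so that Encoders.product can derive an encoder.
case class ValueCounts(counts: Map[String, Long])

// Hypothetical aggregator: counts each non-null value and returns the most frequent one.
object ModeAggregator extends Aggregator[String, ValueCounts, String] {
  def zero: ValueCounts = ValueCounts(Map.empty[String, Long])

  // Null inputs are skipped (assumption) so that null never becomes a map key.
  def reduce(buffer: ValueCounts, input: String): ValueCounts =
    if (input == null) buffer
    else ValueCounts(buffer.counts.updated(input, buffer.counts.getOrElse(input, 0L) + 1L))

  // Adds the counts from b2 onto the counts already in b1.
  def merge(b1: ValueCounts, b2: ValueCounts): ValueCounts =
    ValueCounts(b2.counts.foldLeft(b1.counts) { case (acc, (k, v)) =>
      acc.updated(k, acc.getOrElse(k, 0L) + v)
    })

  // Returns null when no non-null value was seen, instead of letting maxBy throw.
  def finish(buffer: ValueCounts): String =
    if (buffer.counts.isEmpty) null else buffer.counts.maxBy(_._2)._1

  def bufferEncoder: Encoder[ValueCounts] = Encoders.product
  def outputEncoder: Encoder[String] = Encoders.STRING
}

object ModeUdafExample {
  // Computes the most frequent city per group on a small in-memory DataFrame.
  def modesByGroup(spark: SparkSession): Map[String, String] = {
    import spark.implicits._
    val df = Seq(("g1", "x"), ("g1", "x"), ("g1", "y"), ("g2", "z")).toDF("group_id", "city")

    // Wrap the typed Aggregator as an untyped UDAF (available since Spark 3.0).
    val mode = udaf(ModeAggregator)
    spark.udf.register("mode", mode) // makes it callable from Spark SQL as mode(...)

    df.groupBy("group_id").agg(mode(col("city")).as("mode_city"))
      .collect()
      .map(row => row.getString(0) -> row.getString(1))
      .toMap
  }
}
```

Building a new `ValueCounts` in `reduce` and `merge` rather than mutating the buffer keeps the case class immutable, at the cost of some extra allocation.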


Using the new Aggregator, I get the exception below on Spark 3. Kindly help me resolve the issue.

22/10/23 06:28:33 ERROR TaskSetManager: Task 3 in stage 0.0 failed 4 times; aborting job
Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 0.0 failed 4 times, most recent failure: Lost task 3.3 in stage 0.0 (TID 22,, executor 1): java.lang.RuntimeException: Error while encoding: java.lang.RuntimeException: Cannot use null as map key!
externalmaptocatalyst(lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), true, false), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), knownnotnull(assertnotnull(input[0, com.verizon.ZeroEsDataParser.utils.FrenquencyMap, true])).frequencyMap) AS frequencyMap#2320
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:215)
at org.apache.spark.sql.execution.aggregate.ScalaAggregator.serialize(udaf.scala:509)
at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.serializeAggregateBufferInPlace(interfaces.scala:591)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3(ObjectAggregationMap.scala:89)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3$adapted(ObjectAggregationMap.scala:87)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.dumpToExternalSorter(ObjectAggregationMap.scala:87)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:178)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2(ObjectHashAggregateExec.scala:129)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2$adapted(ObjectHashAggregateExec.scala:107)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:859)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:859)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:444)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
at org.apache.spark.executor.Executor$
at java.util.concurrent.ThreadPoolExecutor.runWorker(
at java.util.concurrent.ThreadPoolExecutor$


Thanks @ronenlh for that code. Following up on @Chandraprabu: depending on the use case, we can avoid the null map key error (or a NullPointerException) by treating null as the string "null", like this:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class FrequencyMap(frequencyMap: Map[String, Long])

object MostCommonValue extends Aggregator[String, FrequencyMap, String] {

  def zero: FrequencyMap = FrequencyMap(Map[String, Long]())

  def reduce(buffer: FrequencyMap, input: String): FrequencyMap = {
    // A null input is counted under the string key "null". Adjust this part if you don't want
    // null to be a candidate for the most frequent value.
    val key = input match {
      case null => "null"
      case _ => input
    }
    FrequencyMap(buffer.frequencyMap + (key -> (buffer.frequencyMap.getOrElse(key, 0L) + 1L)))
  }

  def merge(b1: FrequencyMap, b2: FrequencyMap): FrequencyMap = {
    // Add the counts from b2 onto the counts already in b1.
    FrequencyMap(b1.frequencyMap ++ b2.frequencyMap.map { case (k, v) => k -> (v + b1.frequencyMap.getOrElse(k, 0L)) })
  }

  def finish(buffer: FrequencyMap): String = buffer.frequencyMap.maxBy(_._2)._1

  def bufferEncoder: Encoder[FrequencyMap] = Encoders.product
  def outputEncoder: Encoder[String] = Encoders.STRING
}
