@anish749
Last active February 3, 2023 11:50
Spark UDAF to calculate the most common element in a column, i.e. the statistical mode of that column. Written and tested in Spark 2.1.0.
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scalaz.Scalaz._
/**
 * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to the
 * statistical mode. When two values are tied for the highest frequency, this function returns any one of them,
 * whereas the statistical mode would consider both values together to be the mode.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a group by statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue extends UserDefinedAggregateFunction {

  // These are the input fields for the aggregate function.
  // We use StringType, because mode can also be meaningfully applied to nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // These are the internal fields kept while computing the aggregate.
  // We store the frequency of every distinct element we encounter for the given attribute in this map.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // This is the output type of the aggregation function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // This is the initial value for the aggregation buffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // This is how the buffer is updated for a given input row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
  }

  // This is how two buffers of the bufferSchema type are merged.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

  // This is where the final value is produced, given the final value of the buffer.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to the
 * statistical mode. When two values are tied for the highest frequency, this function returns any one of them,
 * whereas the statistical mode would consider both values together to be the mode.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a group by statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * This version doesn't use the Scalaz library.
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction {

  // These are the input fields for the aggregate function.
  // We use StringType, because mode can also be meaningfully applied to nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // These are the internal fields kept while computing the aggregate.
  // We store the frequency of every distinct element we encounter for the given attribute in this map.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // This is the output type of the aggregation function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // This is the initial value for the aggregation buffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // This is how the buffer is updated for a given input row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val inpString = input.getAs[String](0)
    val existingMap = buffer.getAs[Map[String, Long]](0)
    buffer(0) = existingMap + (
      if (existingMap.contains(inpString)) inpString -> (existingMap(inpString) + 1)
      else inpString -> 1L
    )
  }

  // This is how two buffers of the bufferSchema type are merged.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = map1 ++ map2.map { case (k, v) => k -> (v + map1.getOrElse(k, 0L)) }
  }

  // This is where the final value is produced, given the final value of the buffer.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
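
A minimal usage sketch for either class above (a sketch only: it assumes a running SparkSession named spark, and the sample data and column names are made up):

import org.apache.spark.sql.functions.col
import org.anish.spark.mostcommonvalue.MostCommonValue

// Hypothetical sample data: one row per (group_id, city) observation
val df = spark.createDataFrame(Seq(
  (1, "London"), (1, "London"), (1, "Paris"),
  (2, "Berlin"), (2, "Madrid"), (2, "Berlin")
)).toDF("group_id", "city")

val mostCommonValue = new MostCommonValue

// Expected output: group 1 -> "London", group 2 -> "Berlin"
df.groupBy("group_id")
  .agg(mostCommonValue(col("city")).as("most_common_city"))
  .show()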
@ronenlh

ronenlh commented Aug 3, 2022

Thanks for the feedback. I'm not familiar with Scala and only wrote it this way in lieu of a PySpark API for UDAF.
This is the deprecation warning I got from your original implementation (in case others look it up):

warning: class UserDefinedAggregateFunction in package expressions is deprecated (since 3.0.0): Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction

I had to wrap Map[String, Long] in an object for Encoders.product to work.
I filtered out empty strings, as I read that was the safest way to filter out nulls. I had an issue with maxBy on null in some cases.
I edited the code to rename "Mode" into something else and removed the filters, leaving just the minimal implementation.
I'm not familiar with generics in Scala.
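
For reference, the registration path that this warning points to looks roughly like the following. This is only a sketch: it assumes a SparkSession named spark and an Aggregator-based implementation called MostCommonValue, such as the one posted further down in this thread.

import org.apache.spark.sql.functions

// Wrap the Aggregator in a regular UDF and register it under the name "mode"
spark.udf.register("mode", functions.udaf(MostCommonValue))

// It can then be called from Spark SQL just like the old UDAF ("some_table" is a placeholder)
spark.sql("select group_id, mode(city) from some_table group by group_id")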

@Chandraprabu

Using the new Aggregator, I am getting the below exception on Spark 3. Kindly help me resolve the issue.

22/10/23 06:28:33 ERROR TaskSetManager: Task 3 in stage 0.0 failed 4 times; aborting job
Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 0.0 failed 4 times, most recent failure: Lost task 3.3 in stage 0.0 (TID 22, 10.233.101.132, executor 1): java.lang.RuntimeException: Error while encoding: java.lang.RuntimeException: Cannot use null as map key!
externalmaptocatalyst(lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), true, false), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), knownnotnull(assertnotnull(input[0, com.verizon.ZeroEsDataParser.utils.FrenquencyMap, true])).frequencyMap) AS frequencyMap#2320
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:215)
at org.apache.spark.sql.execution.aggregate.ScalaAggregator.serialize(udaf.scala:509)
at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.serializeAggregateBufferInPlace(interfaces.scala:591)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3(ObjectAggregationMap.scala:89)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3$adapted(ObjectAggregationMap.scala:87)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.dumpToExternalSorter(ObjectAggregationMap.scala:87)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:178)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2(ObjectHashAggregateExec.scala:129)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2$adapted(ObjectHashAggregateExec.scala:107)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:859)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:859)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
at org.apache.spark.scheduler.Task.run(Task.scala:127)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:444)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:447)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)

@kennyishihara

Thanks @ronenlh for that code. Following up on @Chandraprabu: depending on the use case, we can get rid of the null / NullPointerException issue by treating null as a string, like this:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// frequencyMap is declared as a var so that the += / ++= reassignments below compile
case class FrequencyMap(var frequencyMap: Map[String, Long])

object MostCommonValue extends Aggregator[String, FrequencyMap, String] {

  def zero: FrequencyMap = FrequencyMap(Map[String, Long]())

  def reduce(buffer: FrequencyMap, input: String): FrequencyMap = {
    input match {
      // if the input is null, count it under the key "null"; adjust this part if you don't want null to be a candidate for the most frequent value
      case null => buffer.frequencyMap += (
        if (buffer.frequencyMap.contains("null"))
          "null" -> (buffer.frequencyMap("null") + 1)
        else
          "null" -> 1L
        )
      case _ => buffer.frequencyMap += (
        if (buffer.frequencyMap.contains(input))
          input -> (buffer.frequencyMap(input) + 1)
        else
          input -> 1L
        )
    }

    buffer
  }

  def merge(b1: FrequencyMap, b2: FrequencyMap): FrequencyMap = {
    b1.frequencyMap ++= b2.frequencyMap.map{ case (k,v) => k -> (v + b1.frequencyMap.getOrElse(k, 0L)) }
    b1
  }

  def finish(buffer: FrequencyMap): String = buffer.frequencyMap.maxBy(_._2)._1

  def bufferEncoder: Encoder[FrequencyMap] = Encoders.product
  def outputEncoder: Encoder[String] = Encoders.STRING
}
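
A usage sketch for the Aggregator above through the DataFrame DSL (assuming Spark 3.x, a SparkSession named spark, and a DataFrame df with group_id and city columns):

import org.apache.spark.sql.functions.{col, udaf}

// udaf(...) turns the Aggregator into a function that returns a Column, usable inside agg()
val mode = udaf(MostCommonValue)

df.groupBy("group_id")
  .agg(mode(col("city")).as("most_common_city"))
  .show()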
