Skip to content

Instantly share code, notes, and snippets.

@anish749
Last active February 3, 2023 11:50
Show Gist options
  • Star 10 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save anish749/6a815ed281f538068a0d3a20ca9044fa to your computer and use it in GitHub Desktop.
Save anish749/6a815ed281f538068a0d3a20ca9044fa to your computer and use it in GitHub Desktop.
Spark UDAF to calculate the most common element in a column or the Statistical Mode for a given column. Written and tested in Spark 2.1.0
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scalaz.Scalaz._
/**
 * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to
 * the statistical mode. When several values tie for the highest frequency, this function returns exactly one of
 * them (the lexicographically smallest), whereas the statistical mode would consider all tied values together
 * to be the mode.
 *
 * Null inputs are ignored. If a group contains no non-null value, the result is null.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a group_by statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue extends UserDefinedAggregateFunction {
  // Input: a single string column. StringType is used because the mode is also meaningful for nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Buffer: for each distinct non-null value seen so far in the group, the number of times it occurred.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output: the most frequent value (or null, see evaluate).
  override def dataType: DataType = StringType

  // True because tie-breaking in evaluate does not depend on map iteration / merge order.
  override def deterministic: Boolean = true

  // Start every group with an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Add one occurrence of the input value. Nulls are skipped: Spark map keys cannot be null, so storing
  // one would fail as soon as the buffer is serialized.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
    }
  }

  // Combine two partial buffers. Scalaz's |+| on Map[String, Long] unions the keys and sums the counts of
  // keys present in both maps.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

  // Return the value with the highest count, choosing the lexicographically smallest value on ties, or null
  // when the group contained no non-null value (maxBy/minBy would throw on an empty map).
  override def evaluate(buffer: Row): String = {
    val frequencies = buffer.getAs[Map[String, Long]](0)
    if (frequencies.isEmpty) null
    else frequencies.minBy { case (value, count) => (-count, value) }._1
  }
}
package org.anish.spark.mostcommonvalue
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to
 * the statistical mode. When several values tie for the highest frequency, this function returns exactly one of
 * them (the lexicographically smallest), whereas the statistical mode would consider all tied values together
 * to be the mode.
 *
 * Null inputs are ignored. If a group contains no non-null value, the result is null.
 *
 * Usage:
 *
 * DataFrame / DataSet DSL
 * val mostCommonValue = new MostCommonValue
 * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
 *
 * Spark SQL:
 * sqlContext.udf.register("mode", new MostCommonValue)
 * %sql
 * -- Use a group_by statement and call the UDAF.
 * select group_id, mode(id) from table group by group_id
 *
 * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
 *
 * This version doesn't use the Scalaz library.
 *
 * Created by anish on 26/05/17.
 */
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction {
  // Input: a single string column. StringType is used because the mode is also meaningful for nominal data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Buffer: for each distinct non-null value seen so far in the group, the number of times it occurred.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // Output: the most frequent value (or null, see evaluate).
  override def dataType: DataType = StringType

  // True because tie-breaking in evaluate does not depend on map iteration / merge order.
  override def deterministic: Boolean = true

  // Start every group with an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Add one occurrence of the input value. Nulls are skipped: Spark map keys cannot be null, so storing
  // one would fail as soon as the buffer is serialized.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getAs[String](0)
      val counts = buffer.getAs[Map[String, Long]](0)
      buffer(0) = counts + (value -> (counts.getOrElse(value, 0L) + 1L))
    }
  }

  // Combine two partial buffers: keys from both maps are kept and counts of shared keys are summed.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = map1 ++ map2.map { case (k, v) => k -> (v + map1.getOrElse(k, 0L)) }
  }

  // Return the value with the highest count, choosing the lexicographically smallest value on ties, or null
  // when the group contained no non-null value (maxBy/minBy would throw on an empty map).
  override def evaluate(buffer: Row): String = {
    val frequencies = buffer.getAs[Map[String, Long]](0)
    if (frequencies.isEmpty) null
    else frequencies.minBy { case (value, count) => (-count, value) }._1
  }
}
@Developer4190
Copy link

Is there a way to do this without Scalaz?

@HosniAkremi
Copy link

Hello Anish,
This UDAF is amazing, and I'm actually working on a project that needs the same implementation.
My question is there anyway to do it with Spark 1.6 and without Scalaz?

I appreciate your prompt answer :)

@anish749
Copy link
Author

I added a file with an implementation that doesn't use Scalaz. Hope that helps. I didn't test it with Spark 1.6, but I believe it will work with Spark 1.6 as well. Let me know if you are facing problems.

Permalink to file: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa#file-mostcommonvalue_noscalaz-scala

@siddhartha-chandra
Copy link

good stuff Anish! Wish I had stumbled on this earlier, would have saved me quite some time to write my own UDAF. Did you try extending it to make it work for multiple types?

@ruloweb
Copy link

ruloweb commented Apr 23, 2020

There is a NullPointerException when the input value is null, which is fixed by:

   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-    val inpString = input.getAs[String](0)
-    val existingMap = buffer.getAs[Map[String, Long]](0)
-    buffer(0) = existingMap + (if (existingMap.contains(inpString)) inpString -> (existingMap(inpString) + 1) else inpString -> 1L)
+    if (!input.isNullAt(0)) {
+      val inpString = input.getAs[String](0)
+      val existingMap = buffer.getAs[Map[String, Long]](0)
+      buffer(0) = existingMap + (if (existingMap.contains(inpString)) inpString -> (existingMap(inpString) + 1) else inpString -> 1L)
+    }
   }

@ronenlh
Copy link

ronenlh commented Aug 2, 2022

this is the new syntax, using Aggregator[-IN, BUF, OUT] and spark.udf.register

docs: https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html

thanks

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

/**
 * Aggregation buffer: the number of occurrences of each distinct non-null input string.
 * The field is kept as a `var` (and the type name unchanged) so existing code that refers to it still compiles;
 * the aggregator below never mutates a buffer and always returns a new one.
 */
case class FrenquencyMap(var frequencyMap: Map[String, Long])

/**
 * Typed Spark aggregator returning the most frequent non-null string of a group.
 * Ties are broken by choosing the lexicographically smallest value, so the result does not depend on how
 * rows were partitioned. The result is null when the group has no non-null input.
 */
object MostCommonValue extends Aggregator[String, FrenquencyMap, String] {

  // Empty frequency map for a new group.
  def zero: FrenquencyMap = FrenquencyMap(Map[String, Long]())

  // Count one more occurrence of `input`. Nulls are skipped: Spark cannot encode a map with a null key
  // ("Cannot use null as map key!").
  def reduce(buffer: FrenquencyMap, input: String): FrenquencyMap =
    if (input == null) buffer
    else {
      val counts = buffer.frequencyMap
      FrenquencyMap(counts + (input -> (counts.getOrElse(input, 0L) + 1L)))
    }

  // Union both maps, summing the counts of keys present in both.
  def merge(b1: FrenquencyMap, b2: FrenquencyMap): FrenquencyMap =
    FrenquencyMap(b2.frequencyMap.foldLeft(b1.frequencyMap) {
      case (acc, (key, count)) => acc + (key -> (acc.getOrElse(key, 0L) + count))
    })

  // Most frequent value (smallest value wins ties), or null when nothing was counted; maxBy/minBy would
  // throw on an empty map.
  def finish(buffer: FrenquencyMap): String = {
    val counts = buffer.frequencyMap
    if (counts.isEmpty) null
    else counts.minBy { case (value, count) => (-count, value) }._1
  }

  def bufferEncoder: Encoder[FrenquencyMap] = Encoders.product
  def outputEncoder: Encoder[String] = Encoders.STRING
}

// Wrap the typed Aggregator as an untyped UDAF and expose it to SQL under the name "mode".
val modeUdaf = functions.udaf(MostCommonValue)
spark.udf.register("mode", modeUdaf)

// Example query: the most common `bar` value within each `foo` group.
val result = spark.sql("SELECT foo, mode(bar) as most_common_value FROM baz group by foo")

@anish749
Copy link
Author

anish749 commented Aug 2, 2022

this is much nicer with the new Aggregator api, a couple of gotchas about the above implementation..

  • using var might be dangerous for this use case. I am not sure if Spark provides guarantees around what the map will contain in case an executor fails and the task is retried. It is hence advisable to use val and return a new object in the reduce and merge fn.
  • The input and output types can be generic with the new api, not necessarily String.
  • An empty String shouldn't be filtered out; it's a valid value of String
  • The mode isn't necessarily a single value; it can be several values if they share the highest frequency.
  • If the frequency map is empty, the result shouldn't be an empty string; it should reflect that the mode is undefined.

@ronenlh
Copy link

ronenlh commented Aug 3, 2022

Thanks for the feedback. I'm not familiar with Scala and only wrote it this way in lieu of a PySpark API for UDAF.
This is the deprecation warning I got from your original implementation (in case others look it up):

warning: class UserDefinedAggregateFunction in package expressions is deprecated (since 3.0.0): Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction

I had to wrap Map[String, Long] in an object for Encoders.product to work.
I filtered empty string as I read that it was the safest way to filter out null. I had an issue with maxBy on null in some cases.
I edited the code to rename "Mode" into something else and removed the filters, leaving out the minimal implementation.
I'm not familiar with generics in Scala.

@Chandraprabu
Copy link

Using the new Aggregator, I get the exception below on Spark 3. Could you help me resolve the issue?

22/10/23 06:28:33 ERROR TaskSetManager: Task 3 in stage 0.0 failed 4 times; aborting job
Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 0.0 failed 4 times, most recent failure: Lost task 3.3 in stage 0.0 (TID 22, 10.233.101.132, executor 1): java.lang.RuntimeException: Error while encoding: java.lang.RuntimeException: Cannot use null as map key!
externalmaptocatalyst(lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), true, false), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), knownnotnull(assertnotnull(input[0, com.verizon.ZeroEsDataParser.utils.FrenquencyMap, true])).frequencyMap) AS frequencyMap#2320
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:215)
at org.apache.spark.sql.execution.aggregate.ScalaAggregator.serialize(udaf.scala:509)
at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.serializeAggregateBufferInPlace(interfaces.scala:591)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3(ObjectAggregationMap.scala:89)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3$adapted(ObjectAggregationMap.scala:87)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.dumpToExternalSorter(ObjectAggregationMap.scala:87)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:178)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2(ObjectHashAggregateExec.scala:129)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2$adapted(ObjectHashAggregateExec.scala:107)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:859)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:859)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
at org.apache.spark.scheduler.Task.run(Task.scala:127)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:444)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:447)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)

@kennyishihara
Copy link

Thanks @ronenlh for that code. Following up on @Chandraprabu, depending on the use case, we can get rid of the null or NullPointerException by treating null as a string like

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/** Aggregation buffer: number of occurrences of each distinct input, with null inputs counted under "null". */
case class FrequencyMap(frequencyMap: Map[String, Long])

/**
 * Typed Spark aggregator returning the most frequent string of a group, treating null as the string "null".
 * Ties are broken by choosing the lexicographically smallest value, so the result does not depend on how
 * rows were partitioned. The result is null only when the group has no input at all.
 */
object MostCommonValue extends Aggregator[String, FrequencyMap, String] {

  // Empty frequency map for a new group.
  def zero: FrequencyMap = FrequencyMap(Map[String, Long]())

  // Count one more occurrence of `input`. `frequencyMap` is an immutable val, so a new buffer is returned
  // (`buffer.frequencyMap += ...` would not compile).
  def reduce(buffer: FrequencyMap, input: String): FrequencyMap = {
    // Spark cannot encode a null map key, so a null input is counted under the key "null". This also lumps
    // nulls together with genuine "null" strings; adjust this if you don't want null to be the most frequent
    // candidate.
    val key = if (input == null) "null" else input
    val counts = buffer.frequencyMap
    FrequencyMap(counts + (key -> (counts.getOrElse(key, 0L) + 1L)))
  }

  // Union both maps, summing the counts of keys present in both.
  def merge(b1: FrequencyMap, b2: FrequencyMap): FrequencyMap =
    FrequencyMap(b2.frequencyMap.foldLeft(b1.frequencyMap) {
      case (acc, (key, count)) => acc + (key -> (acc.getOrElse(key, 0L) + count))
    })

  // Most frequent value (smallest value wins ties), or null when nothing was counted; maxBy/minBy would
  // throw on an empty map.
  def finish(buffer: FrequencyMap): String = {
    val counts = buffer.frequencyMap
    if (counts.isEmpty) null
    else counts.minBy { case (value, count) => (-count, value) }._1
  }

  def bufferEncoder: Encoder[FrequencyMap] = Encoders.product
  def outputEncoder: Encoder[String] = Encoders.STRING
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment