-
-
Save anish749/6a815ed281f538068a0d3a20ca9044fa to your computer and use it in GitHub Desktop.
package org.anish.spark.mostcommonvalue | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types._ | |
import scalaz.Scalaz._ | |
/** | |
* Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to | |
* Statistical Mode. When there are two random values, this function selects any one. When calculating mode, both | |
* these values together is considered as mode. | |
* | |
* Usage: | |
* | |
* DataFrame / DataSet DSL | |
* val mostCommonValue = new MostCommonValue | |
* df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city"))) | |
* | |
* Spark SQL: | |
* sqlContext.udf.register("mode", new MostCommonValue) | |
* %sql | |
* -- Use a group_by statement and call the UDAF. | |
* select group_id, mode(id) from table group by group_id | |
* | |
* Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html | |
* | |
* Created by anish on 26/05/17. | |
*/ | |
class MostCommonValue extends UserDefinedAggregateFunction { | |
// This is the input fields for your aggregate function. | |
// We use StringType, because Mode can also be meaningfully applied on nominal data | |
override def inputSchema: StructType = | |
StructType(StructField("value", StringType) :: Nil) | |
// This is the internal fields you keep for computing your aggregate. | |
// We store the frequency of all the distinct element we encounter for the given attribute in this HashMap | |
override def bufferSchema: StructType = StructType( | |
StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil | |
) | |
// This is the output type of your aggregation function. | |
override def dataType: DataType = StringType | |
override def deterministic: Boolean = true | |
// This is the initial value for the buffer schema. | |
override def initialize(buffer: MutableAggregationBuffer): Unit = { | |
buffer(0) = Map[String, Long]() | |
} | |
// This is how to update your buffer schema given an input. | |
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { | |
buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L) | |
} | |
// This is how you merge two objects with the bufferSchema type. | |
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { | |
buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0) | |
} | |
// This is where you output the final value, given the final value of your bufferSchema. | |
override def evaluate(buffer: Row): String = { | |
buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1 | |
} | |
} |
package org.anish.spark.mostcommonvalue | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types._ | |
/** | |
* Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to | |
* Statistical Mode. When there are two random values, this function selects any one. When calculating mode, both | |
* these values together is considered as mode. | |
* | |
* Usage: | |
* | |
* DataFrame / DataSet DSL | |
* val mostCommonValue = new MostCommonValue | |
* df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city"))) | |
* | |
* Spark SQL: | |
* sqlContext.udf.register("mode", new MostCommonValue) | |
* %sql | |
* -- Use a group_by statement and call the UDAF. | |
* select group_id, mode(id) from table group by group_id | |
* | |
* Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html | |
* | |
* This version doesn't use the ScalaZ library | |
* | |
* Created by anish on 26/05/17. | |
*/ | |
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction { | |
// This is the input fields for your aggregate function. | |
// We use StringType, because Mode can also be meaningfully applied on nominal data | |
override def inputSchema: StructType = | |
StructType(StructField("value", StringType) :: Nil) | |
// This is the internal fields you keep for computing your aggregate. | |
// We store the frequency of all the distinct element we encounter for the given attribute in this HashMap | |
override def bufferSchema: StructType = StructType( | |
StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil | |
) | |
// This is the output type of your aggregation function. | |
override def dataType: DataType = StringType | |
override def deterministic: Boolean = true | |
// This is the initial value for the buffer schema. | |
override def initialize(buffer: MutableAggregationBuffer): Unit = { | |
buffer(0) = Map[String, Long]() | |
} | |
// This is how to update your buffer schema given an input. | |
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { | |
val inpString = input.getAs[String](0) | |
val existingMap = buffer.getAs[Map[String, Long]](0) | |
buffer(0) = existingMap + (if (existingMap.contains(inpString)) inpString -> (existingMap(inpString) + 1) else inpString -> 1L) | |
} | |
// This is how you merge two objects with the bufferSchema type. | |
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { | |
val map1 = buffer1.getAs[Map[String, Long]](0) | |
val map2 = buffer2.getAs[Map[String, Long]](0) | |
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k,0L)) } | |
} | |
// This is where you output the final value, given the final value of your bufferSchema. | |
override def evaluate(buffer: Row): String = { | |
buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1 | |
} | |
} |
Hello Anish,
This UADF is amazing and I'm actually working on a project needing the implementation of the same.
My question is there anyway to do it with Spark 1.6 and without Scalaz?
I appreciate you prompt answer :)
I added a file with an implementation without using Scalaz. Hope that helps. I didn't test this with Spark 1.6, but I believe it will with Spark 1.6 as well. Let me know if you are facing problems.
Permalink to file: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa#file-mostcommonvalue_noscalaz-scala
good stuff Anish! Wish I had stumbled on this earlier, would have saved me quite some time to write my own UDAF. Did you try extending it to make it work for multiple types?
There is a NullPointException, which is fixed by:
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
- val inpString = input.getAs[String](0)
- val existingMap = buffer.getAs[Map[String, Long]](0)
- buffer(0) = existingMap + (if (existingMap.contains(inpString)) inpString -> (existingMap(inpString) + 1) else inpString -> 1L)
+ if (!input.isNullAt(0)) {
+ val inpString = input.getAs[String](0)
+ val existingMap = buffer.getAs[Map[String, Long]](0)
+ buffer(0) = existingMap + (if (existingMap.contains(inpString)) inpString -> (existingMap(inpString) + 1) else inpString -> 1L)
+ }
}
this is the new syntax, using Aggregator[-IN, BUF, OUT]
and spark.udf.register
docs: https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html
thanks
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions
case class FrenquencyMap(var frequencyMap: Map[String, Long])
object MostCommonValue extends Aggregator[String, FrenquencyMap, String] {
def zero: FrenquencyMap = FrenquencyMap(Map[String, Long]())
def reduce(buffer: FrenquencyMap, input: String): FrenquencyMap = {
buffer.frequencyMap += (
if (buffer.frequencyMap.contains(input))
input -> (buffer.frequencyMap(input) + 1)
else
input -> 1L
)
buffer
}
def merge(b1: FrenquencyMap, b2: FrenquencyMap): FrenquencyMap = {
b1.frequencyMap ++= b2.frequencyMap.map{ case (k,v) => k -> (v + b1.frequencyMap.getOrElse(k, 0L)) }
b1
}
def finish(buffer: FrenquencyMap): String = buffer.frequencyMap.maxBy(_._2)._1
def bufferEncoder: Encoder[FrenquencyMap] = Encoders.product
def outputEncoder: Encoder[String] = Encoders.STRING
}
spark.udf.register("mode", functions.udaf(MostCommonValue))
val result = spark.sql("SELECT foo, mode(bar) as most_common_value FROM baz group by foo")
this is much nicer with the new Aggregator
api, a couple of gotchas about the above implementation..
- using
var
might be dangerous for this use case. I am not sure if Spark provides guarantees around what the map will contain in case an executor fails and the task is retried. It is hence advisable to useval
and return a new object in thereduce
andmerge
fn. - The input and output types can be generic with the new api, not necessarily
String
. - An empty
String
shouldn't be filtered out, its a valid value ofString
- Mode isn't necessarily 1 value, but can be more than 1, if all have the same number of elements.
- If frequency map is empty the result shouldn't be an empty string, but the fact that mode will be undefined.
Thanks for the feedback. I'm not familiar with Scala and only wrote it this way in lieu of a PySpark API for UDAF.
This is the deprecation warning I got from your original implementation (in case others look it up):
warning: class UserDefinedAggregateFunction in package expressions is deprecated (since 3.0.0): Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
class MostCommonValue_NoScalaz extends UserDefinedAggregateFunction
I had to wrap Map[String, Long]
in an object for Encoders.product
to work.
I filtered empty string as I read that it was the safest way to filter out null
. I had an issue with maxBy
on null
in some cases.
I edited the code to rename "Mode" into something else and removed the filters, leaving out the minimal implementation.
I'm not familiar with generics in Scala.
Using new aggregator getting below exception for Spark3. Kindly help me to resolve the issue.
22/10/23 06:28:33 ERROR TaskSetManager: Task 3 in stage 0.0 failed 4 times; aborting job
Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 0.0 failed 4 times, most recent failure: Lost task 3.3 in stage 0.0 (TID 22, 10.233.101.132, executor 1): java.lang.RuntimeException: Error while encoding: java.lang.RuntimeException: Cannot use null as map key!
externalmaptocatalyst(lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, lambdavariable(ExternalMapToCatalyst_key, ObjectType(class java.lang.String), true, -1), true, false), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), lambdavariable(ExternalMapToCatalyst_value, LongType, false, -2), knownnotnull(assertnotnull(input[0, com.verizon.ZeroEsDataParser.utils.FrenquencyMap, true])).frequencyMap) AS frequencyMap#2320
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:215)
at org.apache.spark.sql.execution.aggregate.ScalaAggregator.serialize(udaf.scala:509)
at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.serializeAggregateBufferInPlace(interfaces.scala:591)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3(ObjectAggregationMap.scala:89)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.$anonfun$dumpToExternalSorter$3$adapted(ObjectAggregationMap.scala:87)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.dumpToExternalSorter(ObjectAggregationMap.scala:87)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:178)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2(ObjectHashAggregateExec.scala:129)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2$adapted(ObjectHashAggregateExec.scala:107)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:859)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:859)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
at org.apache.spark.scheduler.Task.run(Task.scala:127)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:444)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:447)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)
Thanks @ronenlh for that code. Following up on @Chandraprabu, depending on the use case, we can get rid of the null or NullPointerException by treating null as a string like
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
case class FrequencyMap(frequencyMap: Map[String, Long])
object MostCommonValue extends Aggregator[String, FrequencyMap, String] {
def zero: FrequencyMap = FrequencyMap(Map[String, Long]())
def reduce(buffer: FrequencyMap, input: String): FrequencyMap = {
input match {
// if the input is null, create a key "null". This will create a key "null". Adjust this part if you don't want null to be the most frequent candidate.
case null => buffer.frequencyMap += (
if (buffer.frequencyMap.contains("null"))
"null" -> (buffer.frequencyMap("null") + 1)
else
"null" -> 1L
)
case _ => buffer.frequencyMap += (
if (buffer.frequencyMap.contains(input))
input -> (buffer.frequencyMap(input) + 1)
else
input -> 1L
)
}
buffer
}
def merge(b1: FrequencyMap, b2: FrequencyMap): FrequencyMap = {
b1.frequencyMap ++= b2.frequencyMap.map{ case (k,v) => k -> (v + b1.frequencyMap.getOrElse(k, 0L)) }
b1
}
def finish(buffer: FrequencyMap): String = buffer.frequencyMap.maxBy(_._2)._1
def bufferEncoder: Encoder[FrequencyMap] = Encoders.product
def outputEncoder: Encoder[String] = Encoders.STRING
}
Is there a way to do this without Scalaz?