Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Examples of a mean UDAF implemented two ways: with the legacy `UserDefinedAggregateFunction` and with the type-safe `Aggregator`
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedFunction}
import org.apache.spark.sql.functions._
// Intermediate aggregation buffer shared by the mean implementations below:
// the running total and the number of values folded in so far.
// Marked `final`: case classes should not be extended.
final case class AggregatorState(sum: Long, count: Long)
// Aggregator[IN, BUF, OUT]
// Aggregator[IN, BUF, OUT]: the type-safe, encoder-based way to define a UDAF
// (register with functions.udaf for DataFrame use).
val meanAggregator = new Aggregator[Long, AggregatorState, Double]() {
  // Initial (zero) buffer: nothing summed, nothing counted.
  def zero: AggregatorState = AggregatorState(0L, 0L)

  // Fold one input value into the buffer.
  def reduce(b: AggregatorState, a: Long): AggregatorState =
    AggregatorState(b.sum + a, b.count + 1)

  // Combine two partial buffers (produced on different partitions).
  def merge(b1: AggregatorState, b2: AggregatorState): AggregatorState =
    AggregatorState(b1.sum + b2.sum, b1.count + b2.count)

  // Produce the final mean. Widen to Double BEFORE dividing: the original
  // Long/Long division truncated (e.g. sum=5, count=2 yielded 2.0, not 2.5);
  // the demo's even/odd means of 0..99 happen to be exact integers, which
  // masked the bug. Double division also yields NaN instead of throwing
  // on an empty buffer (count == 0).
  def finish(reduction: AggregatorState): Double =
    reduction.sum.toDouble / reduction.count

  // Encoder for the intermediate buffer. The original wrapped the explicit
  // encoder in `implicitly(...)`, which is a no-op when given an explicit
  // argument — call the factory directly.
  def bufferEncoder: Encoder[AggregatorState] = ExpressionEncoder[AggregatorState]()
  // Encoder for the Double output.
  def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
}
// Wrap the Aggregator as an untyped UDF usable in DataFrame expressions.
// NOTE(review): `udaf` needs an implicit Encoder[Long] — presumably this is
// run in spark-shell where `spark.implicits._` is in scope; verify otherwise.
val meanUdaf: UserDefinedFunction = udaf(meanAggregator)
// Demo: ids 0..99 split into even/odd groups, then averaged per group.
spark.range(100).withColumn("group", col("id")%2).groupBy("group").agg(meanUdaf(col("id")).as("mean")).show
/**
+-----+----+
|group|mean|
+-----+----+
| 0|49.0|
| 1|50.0|
+-----+----+
*/
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
// Legacy row-based UDAF equivalent of meanAggregator above.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in
// favor of Aggregator + functions.udaf; kept here for comparison.
object MeanUdaf extends UserDefinedAggregateFunction {
  // Input schema: a single Long column.
  def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("id", LongType) :: Nil)

  // Internal aggregation buffer: running sum and running count.
  def bufferSchema: StructType = StructType(
    StructField("sum", LongType) ::
    StructField("count", LongType) :: Nil
  )

  // Output type of the aggregation.
  def dataType: DataType = DoubleType

  // Same inputs always yield the same result.
  def deterministic: Boolean = true

  // Zero the buffer before any rows are processed.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  // Merge two partial buffers (produced on different partitions).
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Produce the final mean. Widen to Double BEFORE dividing: the original
  // (sum / count).toDouble performed truncating Long division first (e.g.
  // sum=5, count=2 yielded 2.0, not 2.5); the demo's exact-integer means
  // masked the bug. Double division also yields NaN instead of throwing
  // when count == 0.
  def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}
// Demo: same query as above using the legacy UDAF; output must match.
// NOTE(review): `$"id"` requires `spark.implicits._` in scope — presumably
// run in spark-shell; verify otherwise.
spark.range(100).withColumn("group", $"id"%2).groupBy("group").agg(MeanUdaf($"id").as("mean")).show
/**
+-----+----+
|group|mean|
+-----+----+
| 0|49.0|
| 1|50.0|
+-----+----+
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment