Skip to content

Instantly share code, notes, and snippets.

@josep2
Last active February 26, 2020 07:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save josep2/6f25d32546b2decea95af3d9e16c4c22 to your computer and use it in GitHub Desktop.
Save josep2/6f25d32546b2decea95af3d9e16c4c22 to your computer and use it in GitHub Desktop.
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
/**
 * Spark SQL user-defined aggregate function that computes the harmonic mean
 * of a column of doubles: `n / (1/x_1 + ... + 1/x_n)`.
 *
 * The aggregation buffer keeps two running values: the number of non-null
 * inputs seen so far and the sum of their reciprocals. NULL inputs are
 * skipped, and a group with no non-null input evaluates to NULL.
 *
 * A zero input contributes an infinite reciprocal, which makes the final
 * result 0.0 under IEEE-754 arithmetic.
 */
class HarmonicMean extends UserDefinedAggregateFunction {

  // Define the schema of the input data: a single double column.
  override def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Define the layout of the intermediate aggregation buffer.
  // Both fields must live in ONE list; passing two separate lists to
  // StructType(...) does not match any constructor and fails to compile.
  // NOTE: the field named "product" actually stores the running SUM of
  // reciprocals (see `update`); the name is kept unchanged for compatibility.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
      StructField("product", DoubleType) :: Nil
  )

  // Define the return type of the aggregate.
  override def dataType: DataType = DoubleType

  // The function always returns the same value for the same input.
  override def deterministic: Boolean = true

  // Initial buffer values: no rows counted, reciprocal sum of zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the buffer. NULL values are ignored so they
  // neither inflate the count nor corrupt the reciprocal sum.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) + 1.0 / input.getDouble(0)
    }
  }

  // Merge two partial buffers by adding their counts and reciprocal sums.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // Produce the final result: count divided by the sum of reciprocals.
  // Returns null (SQL NULL) when no non-null value was aggregated, instead
  // of the NaN that 0 / 0.0 would otherwise yield.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    if (count == 0L) null
    else count / buffer.getDouble(1)
  }
}
@Nithanaroy
Copy link

`bufferSchema` should put both fields in a single list, i.e. it should be:

// Schema of the intermediate aggregation buffer: a long "count" field
// followed by a double "product" field, both in a single field sequence.
  override def bufferSchema: StructType = {
    val countField = StructField("count", LongType)
    val productField = StructField("product", DoubleType)
    StructType(Seq(countField, productField))
  }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment