Skip to content

Instantly share code, notes, and snippets.

@lovasoa
Created February 15, 2018 13:43
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save lovasoa/0c3a180b15d169cf3d2d4bccacbdc620 to your computer and use it in GitHub Desktop.
Save lovasoa/0c3a180b15d169cf3d2d4bccacbdc620 to your computer and use it in GitHub Desktop.
Spark UDAF for computing the mean of vectors
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * Spark SQL user-defined aggregate function that computes the element-wise
 * mean of a column of ML vectors.
 *
 * The aggregation buffer holds the number of vectors seen so far and their
 * element-wise sum. An empty (size 0) vector is used as the "nothing summed
 * yet" sentinel, because the dimension of the input is not known until the
 * first vector arrives.
 *
 * Null input values are ignored. If no non-null vector was aggregated, the
 * result is an empty (size 0) vector.
 */
class VectorMean extends UserDefinedAggregateFunction {

  // Input fields of the aggregate function: a single vector column.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", VectorType) :: Nil)

  // Internal fields kept while aggregating: running count and running sum.
  override def bufferSchema: StructType = StructType(Seq(
    StructField("count", LongType),
    StructField("sum", VectorType)
  ))

  // Output type of the aggregation function.
  override def dataType: DataType = VectorType

  // The same input always yields the same output.
  override def deterministic: Boolean = true

  // Initial buffer value: zero vectors seen, empty-vector sentinel as the sum.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = Vectors.zeros(0)
  }

  // Folds one input row into the buffer. Null vectors are skipped so that they
  // neither crash the aggregation nor count towards the mean.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val seenSoFar = buffer.getLong(0)
      val runningSum = buffer.getAs[Vector](1)
      val inputVec = input.getAs[Vector](0)
      buffer(0) = seenSoFar + 1L
      buffer(1) = vectorSum(runningSum, inputVec)
    }
  }

  /**
   * Element-wise sum of two vectors, returned as a dense vector.
   *
   * A size-0 vector on either side is treated as a zero vector of the other
   * side's dimension, so the empty sentinel used by [[initialize]] acts as the
   * identity element.
   *
   * @param a left operand
   * @param b right operand
   * @return dense vector holding `a(i) + b(i)` for every index `i`
   * @throws IllegalArgumentException if both vectors are non-empty and their
   *                                  sizes differ
   */
  def vectorSum(a: Vector, b: Vector): Vector = {
    val left = if (a.size > 0) a else Vectors.zeros(b.size)
    val right = if (b.size > 0) b else Vectors.zeros(a.size)
    require(
      left.size == right.size,
      s"Cannot sum vectors of different sizes: ${left.size} and ${right.size}"
    )
    val leftValues = left.toArray
    // Use the padded right operand: zipping against the raw `b` would truncate
    // the result to length 0 whenever `b` is the empty sentinel.
    val rightValues = right.toArray
    Vectors.dense(Array.tabulate(leftValues.length)(i => leftValues(i) + rightValues(i)))
  }

  // Merges two partial aggregation buffers: counts add up, sums add element-wise.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = vectorSum(buffer1.getAs[Vector](1), buffer2.getAs[Vector](1))
  }

  // Produces the final value: the accumulated sum divided by the count.
  // With nothing aggregated the sum is the empty sentinel, so the result is an
  // empty vector and no division takes place.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    val sum = buffer.getAs[Vector](1)
    Vectors.dense(sum.toArray.map(_ / count))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment