Skip to content

Instantly share code, notes, and snippets.

@travishegner
Created April 22, 2019 19:16
Show Gist options
  • Save travishegner/33b5af41371eb1adf6f78556aaa48e3b to your computer and use it in GitHub Desktop.
Save travishegner/33b5af41371eb1adf6f78556aaa48e3b to your computer and use it in GitHub Desktop.
User Defined Aggregate Function: Vector Sum
package com.trilliumstaffing.hadoop.tools.udaf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import breeze.linalg.{Vector => BV}
import org.apache.spark.sql.Row
/**
 * User-defined aggregate function that computes the element-wise sum of a column of
 * Spark ML vectors.
 *
 * Null inputs are skipped. The aggregation buffer starts out as `null`, so the result is
 * `null` when no non-null vector was seen. When only a single non-null vector was seen it is
 * returned unchanged; as soon as two vectors are added the result is a dense vector.
 *
 * All non-null vectors in a group must have the same size; otherwise an
 * `IllegalArgumentException` is thrown while aggregating.
 */
class VectorSum extends UserDefinedAggregateFunction {

  /** The aggregate produces a single ML vector. */
  def dataType: DataType = VectorType

  /** The same set of inputs always yields the same sum. */
  def deterministic: Boolean = true

  /** One input column, named "value", holding the vectors to add up. */
  def inputSchema: StructType = StructType(Array(StructField("value", VectorType)))

  /** One buffer slot, named "sum", holding the running total (or null). */
  def bufferSchema: StructType = StructType(Array(StructField("sum", VectorType)))

  /** Starts the running total as null, meaning "no vector seen yet". */
  def initialize(b: MutableAggregationBuffer): Unit = {
    b(0) = null
  }

  /**
   * Folds one input row into the running total.
   *
   * @param b aggregation buffer whose slot 0 holds the running sum (possibly null)
   * @param r input row whose column 0 holds the next vector (possibly null)
   */
  def update(b: MutableAggregationBuffer, r: Row): Unit = {
    b(0) = addNullable(b.getAs[Vector](0), r.getAs[Vector](0))
  }

  /**
   * Combines two partial totals, storing the result in the first buffer.
   *
   * @param b1 buffer that receives the merged sum
   * @param b2 other partial aggregation buffer (read only)
   */
  def merge(b1: MutableAggregationBuffer, b2: Row): Unit = {
    b1(0) = addNullable(b1.getAs[Vector](0), b2.getAs[Vector](0))
  }

  /** Returns the final sum stored in slot 0 of the buffer, or null if nothing was summed. */
  def evaluate(b: Row): Vector = b.getAs[Vector](0)

  /**
   * Null-aware element-wise addition shared by `update` and `merge`.
   *
   * @param left  first operand, may be null
   * @param right second operand, may be null
   * @return null if both operands are null, the non-null operand if exactly one is null,
   *         otherwise a new dense vector holding the element-wise sum
   * @throws IllegalArgumentException if both operands are non-null and differ in size
   */
  private def addNullable(left: Vector, right: Vector): Vector =
    (Option(left), Option(right)) match {
      case (Some(l), Some(r)) =>
        // Fail early with a message that names both sizes instead of an opaque error
        // from inside the linear-algebra library.
        require(
          l.size == r.size,
          s"VectorSum: cannot add vectors of different sizes (${l.size} and ${r.size})"
        )
        Vectors.dense((BV(l.toArray) + BV(r.toArray)).toArray)
      case (l, r) =>
        // At most one side is defined here; pass it through untouched, or null if neither.
        l.orElse(r).orNull
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment