Spark UDAF to sum vectors for common keys
package com.databricks.example.pivot
/**
This code allows a user to sum vectors together for common keys.
The code in the comments below shows how to register the Scala UDAF so it can be called from PySpark.
The UDAF can only be called from a SQL expression (i.e. spark.sql() or df.selectExpr()).
**/
/**
# Python code to register the Scala UDAF from PySpark
scala_sql_context = sqlContext._ssql_ctx
scala_spark_context = sqlContext._sc
scala_spark_context._jvm.com.databricks.example.pivot.VectorSumUDAF.registerUdf(scala_sql_context)
pivot_df = spark.sql("SELECT id, VECTORSUM(a_vector) AS cv FROM test_table GROUP BY id")
display(pivot_df)
**/
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.SQLContext
// VectorSum aggregates the vectors for a common key to create the pivot table.
object VectorSumUDAF {
  private class VectorSum extends UserDefinedAggregateFunction {
    // Element-wise, in-place addition: agg is the running total, arr is the incoming record.
    private def addArray(agg: Array[Double], arr: Array[Double]): Unit = {
      var i = 0
      while(i < arr.length) {
        agg(i) = agg(i) + arr(i)
        i += 1
      }
    }
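    // e.g. addArray(Array(1.0, 1.0, 1.0), Array(2.0, 2.0)) leaves the first array
    // as [3.0, 3.0, 1.0]; indices past arr.length are untouched.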
    // Ensure the aggregation array is large enough to hold the incoming record.
    // If not, resize it and return a new, zero-padded copy.
    private def ensureArraySize(agg: Array[Double], size: Int): Array[Double] = {
      if(size > agg.length) {
        val newAgg = new Array[Double](size)
        Array.copy(agg, 0, newAgg, 0, agg.length)
        newAgg
      } else {
        agg
      }
    }
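    // e.g. ensureArraySize(Array(1.0, 2.0), 4) returns [1.0, 2.0, 0.0, 0.0], while
    // ensureArraySize(Array(1.0, 2.0), 1) returns the original array unchanged.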
    // Schema of the input: a single ml Vector column.
    def inputSchema = new StructType().add("vec", SQLDataTypes.VectorType)
    // Schema of the intermediate aggregation buffer: an array of doubles.
    def bufferSchema = new StructType().add("arr", ArrayType(DoubleType, false))
    // Type of the returned value.
    def dataType = SQLDataTypes.VectorType
    // The same input always produces the same output.
    def deterministic = true
    // Zero value: an empty array that grows to the first record's vector length in update().
    def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, Array[Double]())
    // Cast the input row as a vector, add it to the current buffer value, then update the buffer.
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if(!input.isNullAt(0)) {
        val vec = input.getAs[Vector](0)
        val arr: Array[Double] = vec.toArray
        val agg: Array[Double] = ensureArraySize(buffer.getSeq[Double](0).toArray, arr.length)
        addArray(agg, arr)
        buffer.update(0, agg.toSeq)
      }
    }
    // Merge two partial aggregates together, then update the output buffer.
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      val agg2: Array[Double] = buffer2.getSeq[Double](0).toArray
      val agg1: Array[Double] = ensureArraySize(buffer1.getSeq[Double](0).toArray, agg2.length)
      addArray(agg1, agg2)
      buffer1.update(0, agg1.toSeq)
    }
    // Called on exit to produce the return value; .compressed yields a sparse or
    // dense vector, whichever uses less storage.
    def evaluate(buffer: Row) = Vectors.dense(buffer.getSeq[Double](0).toArray).compressed
  }
  // This function is called from PySpark (or Scala) to register the UDAF as VECTORSUM.
  def registerUdf(sqlCtx: SQLContext): Unit = {
    sqlCtx.udf.register("VECTORSUM", new VectorSum)
  }
}
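/**
// Example usage from Scala: a minimal sketch, not part of the gist itself. It
// assumes a SparkSession named `spark` with `import spark.implicits._` in scope;
// the table name `test_table` and column `a_vector` mirror the PySpark example above.
import org.apache.spark.ml.linalg.Vectors
val df = Seq(
  (1, Vectors.dense(1.0, 0.0, 2.0)),
  (1, Vectors.dense(0.0, 3.0, 0.0)),
  (2, Vectors.dense(5.0, 5.0, 5.0))
).toDF("id", "a_vector")
df.createOrReplaceTempView("test_table")
VectorSumUDAF.registerUdf(spark.sqlContext)
spark.sql("SELECT id, VECTORSUM(a_vector) AS cv FROM test_table GROUP BY id").show(false)
// Expected: id 1 -> [1.0, 3.0, 2.0], id 2 -> [5.0, 5.0, 5.0]
**/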