Created
November 29, 2017 21:39
-
-
Save mrchristine/4f77885dab668a39c063b44b4ae71582 to your computer and use it in GitHub Desktop.
Spark UDAF to sum vectors for common keys
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.databricks.example.pivot | |
/**
This code allows a user to add vectors together for common keys.
The code in the comments shows you how to register the Scala UDAF so it can be called from PySpark.
The UDAF can only be called from a SQL expression (i.e. spark.sql() or df.selectExpr()).
**/
/**
# Python code to register the Scala UDAF from PySpark
scala_sql_context = sqlContext._ssql_ctx
scala_spark_context = sqlContext._sc
scala_spark_context._jvm.com.databricks.example.pivot.VectorSumUDAF.registerUdf(scala_sql_context)
pivot_df = spark.sql("SELECT id, VECTORSUM(a_vector) AS cv FROM test_table GROUP BY id")
display(pivot_df)
**/
import org.apache.spark.sql.expressions.MutableAggregationBuffer | |
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes} | |
import org.apache.spark.ml.linalg.SparseVector | |
import org.apache.spark.sql.SQLContext | |
// VectorSum is used to aggregate the vectors to create the pivot table.
object VectorSumUDAF {

  /**
   * UDAF that element-wise sums `ml.linalg.Vector` values sharing a group key.
   *
   * The aggregation buffer is a plain `Array[Double]` of partial sums which is
   * grown on demand, so input vectors of different lengths within one group are
   * tolerated (shorter vectors are treated as zero-padded at the tail).
   */
  private class VectorSum extends UserDefinedAggregateFunction {

    /** In-place element-wise addition: agg(i) += arr(i) for every index of arr.
     *  Caller must guarantee agg.length >= arr.length (see ensureArraySize).
     *  while-loop kept deliberately: this is the per-row hot path. */
    private def addArray(agg: Array[Double], arr: Array[Double]): Unit = {
      var i = 0
      while (i < arr.length) {
        agg(i) += arr(i)
        i += 1
      }
    }

    /** Ensures the buffer array is large enough to hold `size` elements.
     *  If not, allocates a larger array, copies the existing partial sums into
     *  its front, and returns the copy; otherwise returns `agg` unchanged. */
    private def ensureArraySize(agg: Array[Double], size: Int): Array[Double] =
      if (size > agg.length) {
        val newAgg = new Array[Double](size)
        Array.copy(agg, 0, newAgg, 0, agg.length)
        newAgg
      } else {
        agg
      }

    // Input: a single ml Vector column.
    override def inputSchema: StructType = new StructType().add("vec", SQLDataTypes.VectorType)
    // Aggregation buffer: a non-nullable array of doubles holding partial sums.
    override def bufferSchema: StructType = new StructType().add("arr", ArrayType(DoubleType, false))
    // Output: an ml Vector (dense or sparse — whichever is smaller, see evaluate).
    override def dataType: DataType = SQLDataTypes.VectorType
    // Same input rows always produce the same output.
    override def deterministic: Boolean = true
    // Zero value: an empty array; ensureArraySize grows it on the first record.
    override def initialize(buffer: MutableAggregationBuffer): Unit =
      buffer.update(0, Array.empty[Double])

    /** Folds one input row into the buffer. Null vectors are skipped so they
     *  neither reset nor corrupt the running sum. */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        val arr: Array[Double] = input.getAs[Vector](0).toArray
        val agg: Array[Double] = ensureArraySize(buffer.getSeq[Double](0).toArray, arr.length)
        addArray(agg, arr)
        buffer.update(0, agg.toSeq)
      }
    }

    /** Merges two partial aggregates produced on different partitions into buffer1. */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val agg2: Array[Double] = buffer2.getSeq[Double](0).toArray
      val agg1: Array[Double] = ensureArraySize(buffer1.getSeq[Double](0).toArray, agg2.length)
      addArray(agg1, agg2)
      buffer1.update(0, agg1.toSeq)
    }

    /** Final result: the summed array as a Vector; `compressed` picks the
     *  dense or sparse representation, whichever uses less storage. */
    override def evaluate(buffer: Row): Vector =
      Vectors.dense(buffer.getSeq[Double](0).toArray).compressed
  }

  /** Registers the UDAF under the SQL name "VECTORSUM" so it can be invoked
   *  from SQL expressions (including from PySpark via the JVM gateway). */
  def registerUdf(sqlCtx: SQLContext): Unit = {
    sqlCtx.udf.register("VECTORSUM", new VectorSum)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment