@andrearota
Created October 18, 2016 08:40
Creating Spark UDF with extra parameters via currying
// Problem: creating a Spark UDF that takes extra parameters at invocation time.
// Solution: use currying.
// http://stackoverflow.com/questions/35546576/how-can-i-pass-extra-parameters-to-udfs-in-sparksql
// We want to create hideTabooValues, a Spark UDF that sets to -1 every field that matches any of the given taboo values.
// E.g. forbiddenValues = [1, 2, 3]
// dataframe = [1, 2, 3, 4, 5, 6]
// dataframe.select(hideTabooValues(forbiddenValues)(dataframe("value"))) :> [-1, -1, -1, 4, 5, 6]
//
// Implementing this in Spark, we find two major issues:
// 1) Spark UDF factories do not support parameter types other than Columns.
// 2) While we can define the UDF behaviour, we cannot fix the taboo list content before the actual invocation.
//
// To overcome these limitations, we need to exploit Scala functional programming capabilities, using currying.
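// Before wiring this into Spark, the currying trick can be seen in plain Scala.
// A minimal, Spark-free sketch (hideTaboo and the sample values are illustrative,
// not part of the gist): fixing the taboo list returns a specialised Int => Int function.

```scala
// A factory: given the taboo list, return a function specialised on it.
def hideTaboo(taboo: List[Int]): Int => Int =
  (n: Int) => if (taboo.contains(n)) -1 else n

// Partially apply once, reuse the resulting function many times.
val hide123 = hideTaboo(List(1, 2, 3))
val result = List(1, 2, 3, 4, 5, 6).map(hide123)
// result: List(-1, -1, -1, 4, 5, 6)
```

// The Spark version below follows exactly this shape, except the inner
// function is wrapped in udf(...) so Spark can apply it to a Column.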
import org.apache.spark.sql._
import org.apache.spark.sql.types._
// Just create a simple dataframe with integers from 0 to 999.
val rowRDD = sc.parallelize(0 to 999).map(Row(_))
val schema = StructType(StructField("value", IntegerType, true) :: Nil)
val rowDF = sqlContext.createDataFrame(rowRDD, schema)
// Here we use currying: hideTabooValues is a function of type List[Int] => UserDefinedFunction.
def hideTabooValues(taboo: List[Int]) = udf((n: Int) => if (taboo.contains(n)) -1 else n)
// Put simply, you can see hideTabooValues as a UDF factory that specialises the generic UDF definition at invocation time.
// This shows that, before being given a taboo list, hideTabooValues is just a function:
hideTabooValues _
// res7: List[Int] => org.apache.spark.sql.UserDefinedFunction = <function1>
// It's time to try our UDF! Let's define the taboo list
val forbiddenValues = List(0, 1, 2)
// And then use Spark SQL to apply the UDF. You can see two invocations here: the first creates the specific UDF
// with the given taboo list, and the second applies that UDF in a classic select instruction.
rowDF.select(hideTabooValues(forbiddenValues)(rowDF("value"))).show(6)
// +----------+
// |UDF(value)|
// +----------+
// | -1|
// | -1|
// | -1|
// | 3|
// | 4|
// | 5|
// +----------+
// only showing top 6 rows
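// The gist targets the Spark 1.x API (sc, sqlContext). Assuming Spark 2.x or
// later, the same currying pattern works unchanged with SparkSession; only the
// DataFrame setup differs (a sketch, untested here — appName/master are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder.appName("taboo-udf").master("local[*]").getOrCreate()
import spark.implicits._

// Same UDF factory as above: currying fixes the taboo list at creation time.
def hideTabooValues(taboo: List[Int]) = udf((n: Int) => if (taboo.contains(n)) -1 else n)

val rowDF = (0 to 999).toDF("value")
rowDF.select(hideTabooValues(List(0, 1, 2))(rowDF("value"))).show(6)
```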