Skip to content

Instantly share code, notes, and snippets.

@cloud-fan
Last active February 17, 2021 09:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cloud-fan/f88baf770fa0c6f9ad312e8c92ff6c21 to your computer and use it in GitHub Desktop.
Save cloud-fan/f88baf770fa0c6f9ad312e8c92ff6c21 to your computer and use it in GitHub Desktop.
/**
 * Benchmarks several ways of computing `id + id` over `spark.range`:
 *  - Spark's native `+` on columns,
 *  - a Scala `udf`,
 *  - `Invoke` of `call` on a [[NewUDF]] instance embedded as an object `Literal`,
 *  - `Invoke` of `call` on a [[NewRowUDF]] instance (which routes its inputs through a row),
 *  - `StaticInvoke` of `call` on the [[StaticUDF]] object.
 *
 * Each case writes its projection with the `noop` format in append mode.
 * NOTE(review): `noop` presumably discards rows, so timings mostly reflect expression
 * evaluation — confirm against the data source in use.
 */
object NewUDFBenchmark extends SqlBasedBenchmark {
  import spark.implicits._

  /** Selects `expr` over `card` rows of `spark.range(card)` and saves it via the `noop` format. */
  private def writeProjection(card: Long, expr: Column): Unit =
    spark.range(card).select(expr).write.format("noop").mode("append").save()

  /** Two freshly built unresolved references to `id`; the analyzer resolves them per query. */
  private def idArgs = Seq(UnresolvedAttribute("id"), UnresolvedAttribute("id"))

  /**
   * Wraps `target` in an object-typed `Literal` (typed by its runtime class) and builds a
   * non-null-propagating, non-nullable `Invoke` of `target.call(id, id)` returning a long.
   */
  private def invokeCall(target: AnyRef): Column = {
    val targetLiteral = Literal(target, ObjectType(target.getClass))
    Column(Invoke(targetLiteral, "call", LongType, idArgs,
      propagateNull = false, returnNullable = false))
  }

  /** Baseline: Catalyst's built-in addition on the `id` column. */
  private def nativeAdd(card: Long): Unit =
    writeProjection(card, $"id" + $"id")

  /** Addition through a user-defined Scala function registered with `udf`. */
  private def udfAdd(card: Long): Unit = {
    val addUdf = udf { (left: Long, right: Long) => left + right }
    writeProjection(card, addUdf($"id", $"id"))
  }

  /** Addition through `Invoke` on a fresh [[NewUDF]] instance. */
  private def newUdfAdd(card: Long): Unit =
    writeProjection(card, invokeCall(new NewUDF))

  /** Addition through `Invoke` on a fresh [[NewRowUDF]] instance. */
  private def newRowUdfAdd(card: Long): Unit =
    writeProjection(card, invokeCall(new NewRowUDF))

  /** Addition through `StaticInvoke` of `call` on the [[StaticUDF]] object. */
  private def staticMethod(card: Long): Unit = {
    val staticCall = StaticInvoke(StaticUDF.getClass, LongType, "call", idArgs,
      propagateNull = false, returnNullable = false)
    writeProjection(card, Column(staticCall))
  }

  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    val numRows = 1000L * 1000 * 1000
    val benchmark = new Benchmark("UDF perf", numRows, output = output)
    // Cases are registered in this order; each runs 3 iterations over `numRows` rows.
    val cases: Seq[(String, Long => Unit)] = Seq(
      ("native add", nativeAdd _),
      ("udf add", udfAdd _),
      ("new udf add", newUdfAdd _),
      ("new row udf add", newRowUdfAdd _),
      ("static udf add", staticMethod _))
    cases.foreach { case (caseName, runCase) =>
      benchmark.addCase(caseName, 3) { _ => runCase(numRows) }
    }
    benchmark.run()
  }
}
/** Stateless adder invoked by reflection-free codegen: returns the wrapping sum of two longs. */
class NewUDF extends Serializable {
  /** Returns `input1 + input2` using ordinary (overflow-wrapping) `Long` arithmetic. */
  def call(input1: Long, input2: Long): Long = {
    val total = input1 + input2
    total
  }
}
/**
 * Adder that first stores both operands into a reused two-long [[SpecificInternalRow]] and
 * then reads them back before summing.
 * NOTE(review): presumably this models the cost of handing UDF inputs over as a row.
 * The scratch row is mutable per-instance state — confirm each task gets its own instance.
 */
class NewRowUDF extends Serializable {
  // Scratch row with two LongType slots, allocated once per instance and overwritten per call.
  private val scratch = new SpecificInternalRow(Seq(LongType, LongType))

  /** Writes `input1`/`input2` into slots 0/1, reads them back and returns their sum. */
  def call(input1: Long, input2: Long): Long = {
    scratch.setLong(0, input1)
    scratch.setLong(1, input2)
    val left = scratch.getLong(0)
    val right = scratch.getLong(1)
    left + right
  }
}
/** Singleton adder targeted by `StaticInvoke`; `call` returns the wrapping sum of two longs. */
object StaticUDF extends Serializable {
  /** Returns `input1 + input2` using ordinary (overflow-wrapping) `Long` arithmetic. */
  def call(input1: Long, input2: Long): Long = {
    val total = input1 + input2
    total
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment