Last active
February 17, 2021 09:50
-
-
Save cloud-fan/f88baf770fa0c6f9ad312e8c92ff6c21 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Benchmark suite comparing five ways of adding the `id` column of `spark.range` to itself:
 * a native column expression, a Scala `udf`, an `Invoke` on a [[NewUDF]] instance, an `Invoke`
 * on a [[NewRowUDF]] instance, and a `StaticInvoke` on [[StaticUDF]].
 *
 * Every case selects a single column over `spark.range(card)` and writes the result with the
 * "noop" format in "append" mode.
 */
object NewUDFBenchmark extends SqlBasedBenchmark {
  import spark.implicits._

  /** Selects `expr` over `spark.range(card)` and writes it with the "noop" format. */
  private def writeToNoop(card: Long, expr: Column): Unit = {
    val projected = spark.range(card).select(expr)
    projected.write.format("noop").mode("append").save()
  }

  /** Argument list `(id, id)` shared by all reflective-call based cases. */
  private def idPair = Seq(UnresolvedAttribute("id"), UnresolvedAttribute("id"))

  /** Wraps an `Invoke` of the `call` method on `target`, passing `(id, id)`, in a Column. */
  private def invokeCall(target: Literal): Column = {
    val invocation = Invoke(target, "call", LongType, idPair,
      propagateNull = false, returnNullable = false)
    Column(invocation)
  }

  /** Case 1: the built-in `+` operator on columns. */
  private def nativeAdd(card: Long): Unit = {
    writeToNoop(card, $"id" + $"id")
  }

  /** Case 2: a Scala lambda wrapped with `udf`. */
  private def udfAdd(card: Long): Unit = {
    val addUdf = udf { (input1: Long, input2: Long) => input1 + input2 }
    writeToNoop(card, addUdf($"id", $"id"))
  }

  /** Case 3: `Invoke` of `call` on a [[NewUDF]] instance carried as an object literal. */
  private def newUdfAdd(card: Long): Unit = {
    val target = Literal(new NewUDF, ObjectType(classOf[NewUDF]))
    writeToNoop(card, invokeCall(target))
  }

  /** Case 4: `Invoke` of `call` on a [[NewRowUDF]] instance carried as an object literal. */
  private def newRowUdfAdd(card: Long): Unit = {
    val target = Literal(new NewRowUDF, ObjectType(classOf[NewRowUDF]))
    writeToNoop(card, invokeCall(target))
  }

  /** Case 5: `StaticInvoke` of `call` on the [[StaticUDF]] object. */
  private def staticMethod(card: Long): Unit = {
    val invocation = StaticInvoke(StaticUDF.getClass, LongType, "call", idPair,
      propagateNull = false, returnNullable = false)
    writeToNoop(card, Column(invocation))
  }

  /**
   * Registers the five cases, in a fixed order and with 3 iterations each, then runs them.
   *
   * @param mainArgs command-line arguments; not read by this suite
   */
  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    val rowCount = 1000L * 1000 * 1000
    val benchmark = new Benchmark("UDF perf", rowCount, output = output)

    // Ordered (label, runner) table; registration order is preserved by the Seq.
    val cases: Seq[(String, Long => Unit)] = Seq(
      ("native add", nativeAdd _),
      ("udf add", udfAdd _),
      ("new udf add", newUdfAdd _),
      ("new row udf add", newRowUdfAdd _),
      ("static udf add", staticMethod _)
    )

    cases.foreach { case (label, runCase) =>
      benchmark.addCase(label, 3) { _ =>
        runCase(rowCount)
      }
    }

    benchmark.run()
  }
}
/**
 * Serializable holder of a two-argument `call` method that adds two longs.
 * `NewUDFBenchmark.newUdfAdd` calls it by name through an `Invoke` expression.
 */
class NewUDF extends Serializable {

  /**
   * @param input1 first addend
   * @param input2 second addend
   * @return `input1 + input2` using plain `Long` addition
   */
  def call(input1: Long, input2: Long): Long = {
    val sum = input1 + input2
    sum
  }
}
/**
 * Serializable adder that routes both arguments through a two-column
 * `SpecificInternalRow` before summing them.
 */
class NewRowUDF extends Serializable {

  // One row instance per NewRowUDF; it is overwritten on every call.
  private val inputRow = new SpecificInternalRow(Seq(LongType, LongType))

  /**
   * Stores the arguments in slots 0 and 1 of the row, reads them back and adds them.
   *
   * @param input1 value written to slot 0
   * @param input2 value written to slot 1
   * @return the sum of the two values read back from the row
   */
  def call(input1: Long, input2: Long): Long = {
    inputRow.setLong(0, input1)
    inputRow.setLong(1, input2)
    val first = inputRow.getLong(0)
    val second = inputRow.getLong(1)
    first + second
  }
}
/**
 * Singleton holder of a two-argument `call` method that adds two longs.
 * `NewUDFBenchmark.staticMethod` calls it by name through a `StaticInvoke` expression.
 */
object StaticUDF extends Serializable {

  /**
   * @param input1 first addend
   * @param input2 second addend
   * @return `input1 + input2` using plain `Long` addition
   */
  def call(input1: Long, input2: Long): Long = {
    val sum = input1 + input2
    sum
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment