Experimenting with Spark SQL UDAFs: a HyperLogLog UDAF for distinct counts that stores the actual HLL in each result row to allow further aggregation
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class HyperLogLogStoreUDAF extends UserDefinedAggregateFunction {

  // A single input column holding the value in String (or other binary) form.
  override def inputSchema = new StructType()
    .add("stringInput", BinaryType)

  // The aggregation buffer stores the HyperLogLog object itself, via the UDT below.
  override def bufferSchema = new StructType().add("hll", MyHyperLogLogUDT)

  // The result is a struct: the estimated cardinality plus the HLL itself,
  // so the sketch can be merged again in a later aggregation.
  override def dataType = new StructType()
    .add("cardinality", LongType)
    .add("hll", MyHyperLogLogUDT)

  override def deterministic = true

  override def initialize(buffer: MutableAggregationBuffer) = {
    // The initial value of the HLL is null; it is created on the first non-null input.
    buffer.update(0, null)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row) = {
    // We only update the buffer when the input value is not null.
    if (!input.isNullAt(0)) {
      if (buffer.isNullAt(0)) {
        val newHLL = new HyperLogLog(0.05) // 5% relative standard deviation
        newHLL.offer(input.get(0))
        buffer.update(0, newHLL)
      } else {
        val updated = buffer.get(0).asInstanceOf[HyperLogLog]
        updated.offer(input.get(0))
        buffer.update(0, updated)
      }
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    // buffer1 and buffer2 have the same structure.
    // We only update buffer1 when buffer2's HLL is not null.
    if (!buffer2.isNullAt(0)) {
      if (buffer1.isNullAt(0)) {
        buffer1.update(0, buffer2.get(0).asInstanceOf[HyperLogLog])
      } else {
        val left = buffer1.get(0).asInstanceOf[HyperLogLog]
        val right = buffer2.get(0).asInstanceOf[HyperLogLog]
        left.addAll(right) // register-wise merge of the two sketches
        buffer1.update(0, left)
      }
    }
  }

  override def evaluate(buffer: Row) = {
    if (buffer.isNullAt(0)) {
      null
    } else {
      val hll = buffer.getAs[HyperLogLog](0)
      InternalRow(hll.cardinality(), hll.getBytes)
    }
  }
}
// A copy of Spark's internal HyperLogLogUDT, which is private[sql] and
// therefore cannot be used directly.
case object MyHyperLogLogUDT extends UserDefinedType[HyperLogLog] {

  override def sqlType: DataType = BinaryType

  /** Since we pass the HyperLogLog object around internally, this is rarely called. */
  override def serialize(obj: Any): Array[Byte] =
    obj.asInstanceOf[HyperLogLog].getBytes

  /** Since we pass the HyperLogLog object around internally, this is rarely called. */
  override def deserialize(datum: Any): HyperLogLog =
    HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])

  override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
}
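
// A quick sanity check, not part of the original gist: the UDT should
// round-trip a sketch through its binary sqlType representation without
// losing counts. The object name UdtRoundTripDemo is mine.
object UdtRoundTripDemo extends App {
  val hll = new HyperLogLog(0.05)
  Seq("x", "y", "z").foreach(hll.offer)
  // Serialize to Array[Byte] and rebuild; the estimate should survive the trip.
  val restored = MyHyperLogLogUDT.deserialize(MyHyperLogLogUDT.serialize(hll))
  println(restored.cardinality()) // expect an estimate close to 3
}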
object TestHLL extends App {
  val conf = new SparkConf()
    .setMaster("local[4]")
    .setAppName("test")
  val sc = new SparkContext(conf)
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  sqlContext.udf.register("hllcount", new HyperLogLogStoreUDAF)

  val data = sc.parallelize(Seq("a", "b", "c", "d", "a", "b"), numSlices = 2).toDF("col1")
  data.registerTempTable("test")

  // show() prints the result itself, so no println is needed around it.
  val res = sqlContext.sql("select hllcount(col1) from test")
  res.show()
}
/*
+--------------------+
| _c0|
+--------------------+
|[4,com.clearsprin...|
+--------------------+
*/
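
/*
 * A minimal stand-alone sketch, not part of the original gist, showing why
 * storing the HLL itself (rather than just the count) enables further
 * aggregation: sketches built on disjoint partitions merge losslessly, so
 * per-group sketches can be rolled up into a global distinct count later.
 * The object name MergeSketchDemo is mine.
 */
object MergeSketchDemo extends App {
  val left = new HyperLogLog(0.05)
  val right = new HyperLogLog(0.05)
  Seq("a", "b", "c").foreach(left.offer)
  Seq("c", "d").foreach(right.offer)
  left.addAll(right) // register-wise merge, the same operation merge() uses above
  println(left.cardinality()) // expect an estimate close to 4 ("a", "b", "c", "d")
}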