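This gist demonstrates how often Spark SQL serializes and deserializes the aggregation buffer of a UserDefinedAggregateFunction (UDAF). The trick is a UserDefinedType whose serialize and deserialize methods each increment a counter carried in the payload itself, so the aggregate's result reports how many conversions occurred. The code lives under the org.apache.spark package so it can override the private[spark] member asNullable.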
package org.apache.spark.countSerDe

import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction

// Payload that records how many times it has been serialized and deserialized.
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int)

class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]

  override def typeName: String = "count-ser-de"

  private[spark] override def asNullable: CountSerDeUDT = this

  def sqlType: DataType = StructType(
    StructField("nSer", IntegerType, false) ::
    StructField("nDeSer", IntegerType, false) ::
    Nil)

  // Every serialization increments the nSer counter.
  def serialize(sql: CountSerDeSQL): Any = {
    val row = new GenericInternalRow(2)
    row.setInt(0, 1 + sql.nSer)
    row.setInt(1, sql.nDeSer)
    row
  }

  // Every deserialization increments the nDeSer counter.
  def deserialize(any: Any): CountSerDeSQL = any match {
    case row: InternalRow if (row.numFields == 2) =>
      CountSerDeSQL(row.getInt(0), 1 + row.getInt(1))
    case u => throw new Exception(s"failed to deserialize: $u")
  }

  override def equals(obj: Any): Boolean = {
    obj match {
      case _: CountSerDeUDT => true
      case _ => false
    }
  }

  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
}

case object CountSerDeUDT extends CountSerDeUDT

case object CountSerDeUDAF extends UserDefinedAggregateFunction {
  def deterministic: Boolean = true

  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)

  // The aggregation buffer holds the counting UDT, so every buffer
  // access Spark performs shows up in the counters.
  def bufferSchema: StructType = StructType(StructField("count-ser-de", CountSerDeUDT) :: Nil)

  def dataType: DataType = CountSerDeUDT

  def initialize(buf: MutableAggregationBuffer): Unit = {
    buf(0) = CountSerDeSQL(0, 0)
  }

  // update does no arithmetic itself: reading the buffer (getAs) costs one
  // deserialization and writing it back costs one serialization, so the
  // counters tally one ser/de pair for every input row.
  def update(buf: MutableAggregationBuffer, input: Row): Unit = {
    val sql = buf.getAs[CountSerDeSQL](0)
    buf(0) = sql
  }

  // Combining two partial buffers sums their counters.
  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val sql1 = buf1.getAs[CountSerDeSQL](0)
    val sql2 = buf2.getAs[CountSerDeSQL](0)
    buf1(0) = CountSerDeSQL(sql1.nSer + sql2.nSer, sql1.nDeSer + sql2.nDeSer)
  }

  def evaluate(buf: Row): Any = buf.getAs[CountSerDeSQL](0)
}
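Below is a spark-shell session exercising the UDAF over 1000 rows of random data. Reproducing it assumes the class above is compiled and on the shell's classpath (for example, launched with spark-shell --jars <path-to-jar>).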
scala> import scala.util.Random.nextGaussian, org.apache.spark.countSerDe._
import scala.util.Random.nextGaussian
import org.apache.spark.countSerDe._

scala> val data = sc.parallelize(Vector.fill(1000){(nextGaussian, nextGaussian)}).toDF.as[(Double, Double)]
data: org.apache.spark.sql.Dataset[(Double, Double)] = [_1: double, _2: double]

scala> val udaf = CountSerDeUDAF
udaf: org.apache.spark.countSerDe.CountSerDeUDAF.type = CountSerDeUDAF

scala> val agg = data.agg(udaf($"_1"))
agg: org.apache.spark.sql.DataFrame = [countserdeudaf$(_1): count-ser-de]

scala> agg.first.getAs[CountSerDeSQL](0)
res4: org.apache.spark.countSerDe.CountSerDeSQL = CountSerDeSQL(1006,1006)

scala> spark.udf.register("countserde", udaf)
res1: org.apache.spark.sql.expressions.UserDefinedAggregateFunction = CountSerDeUDAF

scala> val agg = data.agg(expr("countserde(_1)"))
agg: org.apache.spark.sql.DataFrame = [countserde(_1): count-ser-de]

scala> agg.first.getAs[CountSerDeSQL](0)
res2: org.apache.spark.countSerDe.CountSerDeSQL = CountSerDeSQL(1006,1006)
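Both counters land slightly above 1000: update deserializes and re-serializes the buffer once per input row, and the extra few conversions come from buffer initialization, the per-partition merges, and the final evaluate. This per-row round-trip is the overhead the demo is designed to expose. For comparison, a minimal sketch of how the same aggregation might look on Spark 3.x's typed Aggregator API, where the buffer is a plain case class held as a JVM object between rows instead of round-tripping through a UDT (names here are illustrative, not part of the gist):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Hypothetical buffer: a plain case class with no UDT annotation, encoded
// by Spark only at partition boundaries rather than on every update.
case class SerDeCount(nSer: Int, nDeSer: Int)

object CountAgg extends Aggregator[Double, SerDeCount, SerDeCount] {
  def zero: SerDeCount = SerDeCount(0, 0)
  def reduce(b: SerDeCount, x: Double): SerDeCount = b  // no per-row conversion to count
  def merge(b1: SerDeCount, b2: SerDeCount): SerDeCount =
    SerDeCount(b1.nSer + b2.nSer, b1.nDeSer + b2.nDeSer)
  def finish(b: SerDeCount): SerDeCount = b
  def bufferEncoder: Encoder[SerDeCount] = Encoders.product[SerDeCount]
  def outputEncoder: Encoder[SerDeCount] = Encoders.product[SerDeCount]
}

// Registered the same way as the UDAF above:
// spark.udf.register("countagg", udaf(CountAgg))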