-
-
Save reynoldsm88/64d327473eb0207db99e61c9d4bed57c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.countSerDe | |
import org.apache.spark.sql.catalyst.util._ | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow | |
import org.apache.spark.sql.expressions.MutableAggregationBuffer | |
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction | |
/**
 * User-facing value stored through the [[CountSerDeUDT]] user-defined type.
 *
 * @param nSer   counter that [[CountSerDeUDT]] increments each time it serializes this value
 * @param nDeSer counter that [[CountSerDeUDT]] increments each time it deserializes this value
 */
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(
    nSer: Int,
    nDeSer: Int
)
/**
 * Spark user-defined type for [[CountSerDeSQL]] that instruments its own conversions.
 *
 * Every call to `serialize` stores `nSer + 1` and every call to `deserialize`
 * returns `nDeSer + 1`. A value that has gone through Catalyst therefore carries a
 * tally of how many times it was converted in each direction.
 *
 * The Catalyst representation is a two-field struct of non-nullable ints,
 * `(nSer, nDeSer)`.
 */
class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {

  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]

  override def typeName: String = "count-ser-de"

  // The nullable view of this type is this same instance.
  private[spark] override def asNullable: CountSerDeUDT = this

  def sqlType: DataType = StructType(
    StructField("nSer", IntegerType, false) ::
      StructField("nDeSer", IntegerType, false) ::
      Nil)

  /**
   * Converts a [[CountSerDeSQL]] into its Catalyst row form.
   *
   * @param sql value to convert
   * @return a two-field `GenericInternalRow` whose serialization counter is one
   *         higher than `sql.nSer`
   */
  def serialize(sql: CountSerDeSQL): Any = {
    val row = new GenericInternalRow(2)
    row.setInt(0, 1 + sql.nSer) // intentional: record this serialization
    row.setInt(1, sql.nDeSer)
    row
  }

  /**
   * Converts a Catalyst row back into a [[CountSerDeSQL]].
   *
   * @param any expected to be a two-field `InternalRow` as produced by `serialize`
   * @return the value, with its deserialization counter one higher than stored
   * @throws IllegalArgumentException if `any` is not a two-field `InternalRow`
   */
  def deserialize(any: Any): CountSerDeSQL = any match {
    case row: InternalRow if row.numFields == 2 =>
      CountSerDeSQL(row.getInt(0), 1 + row.getInt(1)) // intentional: record this deserialization
    case u =>
      // IllegalArgumentException (a subtype of Exception) signals bad input
      // while remaining catchable by existing `catch Exception` handlers.
      throw new IllegalArgumentException(s"failed to deserialize: $u")
  }

  // All instances are interchangeable: equality is decided by class alone,
  // and hashCode is derived from the class name to stay consistent with it.
  override def equals(obj: Any): Boolean = obj match {
    case _: CountSerDeUDT => true
    case _                => false
  }

  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
}

/** Singleton instance, usable wherever a Spark `DataType` is expected. */
case object CountSerDeUDT extends CountSerDeUDT
/**
 * Aggregate function that measures how often its buffer is converted by
 * [[CountSerDeUDT]].
 *
 * The input column is accepted but never read. Each `update` reads the buffer
 * value and writes it straight back, which forces one deserialize/serialize
 * round trip per input row. `merge` sums both counters. The result is the final
 * [[CountSerDeSQL]] holding the accumulated conversion counts.
 */
case object CountSerDeUDAF extends UserDefinedAggregateFunction {

  def deterministic: Boolean = true

  def inputSchema: StructType =
    StructType(Seq(StructField("x", DoubleType)))

  def bufferSchema: StructType =
    StructType(Seq(StructField("count-ser-de", CountSerDeUDT)))

  def dataType: DataType = CountSerDeUDT

  /** Seeds the buffer with zeroed counters. */
  def initialize(buf: MutableAggregationBuffer): Unit =
    buf.update(0, CountSerDeSQL(nSer = 0, nDeSer = 0))

  /** Round-trips the buffer value unchanged; `input` is deliberately ignored. */
  def update(buf: MutableAggregationBuffer, input: Row): Unit =
    buf.update(0, buf.getAs[CountSerDeSQL](0))

  /** Adds the counters of `buf2` into `buf1`. */
  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val left = buf1.getAs[CountSerDeSQL](0)
    val right = buf2.getAs[CountSerDeSQL](0)
    val combined = CountSerDeSQL(
      nSer = left.nSer + right.nSer,
      nDeSer = left.nDeSer + right.nDeSer)
    buf1.update(0, combined)
  }

  /** Returns the accumulated [[CountSerDeSQL]] from the buffer. */
  def evaluate(buf: Row): Any = {
    val result = buf.getAs[CountSerDeSQL](0)
    result
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
scala> import scala.util.Random.nextGaussian, org.apache.spark.countSerDe._ | |
import scala.util.Random.nextGaussian | |
import org.apache.spark.countSerDe._ | |
scala> val data = sc.parallelize(Vector.fill(1000){(nextGaussian, nextGaussian)}).toDF.as[(Double, Double)] | |
data: org.apache.spark.sql.Dataset[(Double, Double)] = [_1: double, _2: double] | |
scala> val udaf = CountSerDeUDAF | |
udaf: org.apache.spark.countSerDe.CountSerDeUDAF.type = CountSerDeUDAF | |
scala> val agg = data.agg(udaf($"_1")) | |
agg: org.apache.spark.sql.DataFrame = [countserdeudaf$(_1): count-ser-de] | |
scala> agg.first.getAs[CountSerDeSQL](0) | |
res4: org.apache.spark.countSerDe.CountSerDeSQL = CountSerDeSQL(1006,1006) | |
scala> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment