@skp33
Forked from sadikovi/PointType.scala
Created August 30, 2018 20:24
Spark UDT and UDAF with custom buffer type
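Two files follow, both declared in package org.apache.spark (so that the private[spark] BufferType compiles). The first defines a Buffer UDT over ArrayBuffer[Point] together with a UDAF (SimpleAggregate) that collects input rows into that buffer; the second is a standalone PointType UDT for single Point values. A usage sketch is included after each file.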
package org.apache.spark
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData => GenericData}
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
package object aggregate {

  // Plain data class: the element type of the custom aggregation buffer.
  case class Point(mac: String, start: Long, end: Long) {
    override def hashCode(): Int = {
      31 * (31 * mac.hashCode) + start.hashCode
    }

    override def equals(other: Any): Boolean = other match {
      case that: Point =>
        this.mac == that.mac && this.start == that.start && this.end == that.end
      case _ => false
    }

    override def toString(): String = {
      s"${getClass.getSimpleName}($mac, start=$start, end=$end)"
    }
  }

  // Buffer is an alias for the mutable collection the UDAF aggregates into;
  // the annotation ties it to the UDT below.
  @SQLUserDefinedType(udt = classOf[BufferType])
  type Buffer = ArrayBuffer[Point]

  // UDT describing how Catalyst stores ArrayBuffer[Point]: an array of
  // (mac, start, end) structs.
  private[spark] class BufferType extends UserDefinedType[Buffer] {
    def sqlType: DataType = ArrayType(StructType(
      StructField("mac", StringType, false) ::
      StructField("start", LongType, false) ::
      StructField("end", LongType, false) :: Nil))

    def serialize(obj: Any): Any = obj match {
      case buffer: ArrayBuffer[_] =>
        val data = buffer.asInstanceOf[Buffer].map { point =>
          val arr = new Array[Any](3)
          // StringType values are stored internally as UTF8String
          arr(0) = UTF8String.fromString(point.mac)
          arr(1) = point.start
          arr(2) = point.end
          new GenericInternalRow(arr)
        }
        new GenericData(data)
      case other => sys.error(s"Failed to serialize: $other")
    }

    def deserialize(datum: Any): Buffer = datum match {
      case data: ArrayData =>
        val buf = new Buffer()
        // Iterate via numElements/getStruct rather than the backing array,
        // which UnsafeArrayData does not expose
        for (i <- 0 until data.numElements()) {
          val next = data.getStruct(i, 3)
          buf.append(Point(next.getString(0), next.getLong(1), next.getLong(2)))
        }
        buf
      case other => sys.error(s"Failed to deserialize: $other")
    }

    def userClass: Class[Buffer] = classOf[Buffer]
  }

  case object BufferType extends BufferType

  // == UDAF ==
  // Collects (mac, start, end) input rows into an ArrayBuffer[Point],
  // using BufferType both for the intermediate buffer and the result.
  class SimpleAggregate extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(
      StructField("mac", StringType, true) ::
      StructField("start", LongType, true) ::
      StructField("end", LongType, true) :: Nil)

    override def bufferSchema: StructType = StructType(
      StructField("buffer", BufferType, true) :: Nil)

    override def dataType: DataType = BufferType

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = new Buffer()
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val buf = buffer(0).asInstanceOf[Buffer]
      buf.append(Point(input.getString(0), input.getLong(1), input.getLong(2)))
      // Write back so the mutation is propagated through the buffer
      buffer(0) = buf
    }

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val buf1 = buffer1(0).asInstanceOf[Buffer]
      val buf2 = buffer2(0).asInstanceOf[Buffer]
      buf1.appendAll(buf2)
      buffer1(0) = buf1
    }

    override def evaluate(buffer: Row): Any = {
      buffer(0).asInstanceOf[Buffer]
    }
  }

  // Orders points by start, breaking ties by end
  implicit val ordering: Ordering[Point] = new Ordering[Point] {
    override def compare(x: Point, y: Point): Int = {
      if (x.start == y.start) {
        if (x.end == y.end) 0 else if (x.end < y.end) -1 else 1
      } else {
        if (x.start < y.start) -1 else 1
      }
    }
  }
}
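The following driver is a minimal usage sketch, not part of the original gist: it assumes Spark 2.x (UserDefinedAggregateFunction is deprecated since 3.0), a local SparkSession, and made-up sample data; the object name SimpleAggregateDemo is hypothetical.

// == Usage sketch (hypothetical, not part of the gist) ==
import org.apache.spark.aggregate._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object SimpleAggregateDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udaf-demo").getOrCreate()
    import spark.implicits._

    // Made-up rows matching the UDAF's input schema (mac, start, end)
    val df = Seq(
      ("aa:bb", 1L, 5L),
      ("aa:bb", 3L, 7L),
      ("cc:dd", 2L, 4L)).toDF("mac", "start", "end")

    val collectPoints = new SimpleAggregate

    // Each group's result column carries an ArrayBuffer[Point] via BufferType
    df.groupBy(col("mac"))
      .agg(collectPoints(col("mac"), col("start"), col("end")).as("points"))
      .show(truncate = false)

    spark.stop()
  }
}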
// == PointType.scala ==
package org.apache.spark

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@SQLUserDefinedType(udt = classOf[PointType])
case class Point(mac: String, start: Long, end: Long) {
  override def hashCode(): Int = {
    31 * (31 * mac.hashCode) + start.hashCode
  }

  override def equals(other: Any): Boolean = other match {
    case that: Point =>
      this.mac == that.mac && this.start == that.start && this.end == that.end
    case _ => false
  }

  override def toString(): String = {
    s"${getClass.getSimpleName}($mac, start=$start, end=$end)"
  }
}
// UDT mapping a single Point to a (mac, start, end) struct.
class PointType extends UserDefinedType[Point] {
  def sqlType: DataType = StructType(
    StructField("mac", StringType, false) ::
    StructField("start", LongType, false) ::
    StructField("end", LongType, false) :: Nil)

  def serialize(obj: Any): Any = obj match {
    case c @ Point(mac, start, end) =>
      println(s"Serialize: $c")
      val arr = new Array[Any](3)
      // A StructType-backed UDT must produce an InternalRow, with StringType
      // fields stored as UTF8String rather than java.lang.String
      arr(0) = UTF8String.fromString(mac)
      arr(1) = start
      arr(2) = end
      new GenericInternalRow(arr)
    case other => sys.error(s"Failed to serialize: $other")
  }

  def deserialize(datum: Any): Point = datum match {
    case row: InternalRow =>
      println(s"Deserialize: $datum")
      Point(row.getString(0), row.getLong(1), row.getLong(2))
    case other => sys.error(s"Failed to deserialize: $other")
  }

  def userClass: Class[Point] = classOf[Point]
}

case object PointType extends PointType
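A similar sketch for the standalone UDT: since Point carries the @SQLUserDefinedType annotation, Spark's reflection-based schema inference encodes it through PointType when building a DataFrame from an RDD of case classes. Names and data are again hypothetical, assuming Spark 2.x.

// == Usage sketch (hypothetical, not part of the gist) ==
import org.apache.spark.Point
import org.apache.spark.sql.SparkSession

object PointTypeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udt-demo").getOrCreate()

    // Schema inference picks up PointType from the annotation on Point
    val rdd = spark.sparkContext.parallelize(Seq(
      Point("aa:bb", 1L, 5L),
      Point("cc:dd", 2L, 4L)))
    val df = spark.createDataFrame(rdd.map(Tuple1(_))).toDF("point")

    df.printSchema()
    df.show(truncate = false)
    spark.stop()
  }
}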