@sadikovi
Last active June 12, 2020 11:05
UDAF for generating a collection of values (up to a specified limit)
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, LongType, DataType, StructType, StructField}
class CollectionFunction(private val limit: Int) extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", LongType, false) :: Nil)

  def bufferSchema: StructType =
    StructType(StructField("list", ArrayType(LongType, true), true) :: Nil)

  override def dataType: DataType = ArrayType(LongType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Long]()
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A null buffer value marks a group that has already exceeded the limit; skip it
    if (buffer(0) != null) {
      val seq = buffer(0).asInstanceOf[IndexedSeq[Long]]
      if (seq.length < limit) {
        buffer(0) = input.getAs[Long](0) +: seq
      } else {
        buffer(0) = null
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer1(0) != null && buffer2(0) != null) {
      val seq1 = buffer1(0).asInstanceOf[IndexedSeq[Long]]
      val seq2 = buffer2(0).asInstanceOf[IndexedSeq[Long]]
      if (seq1.length + seq2.length <= limit) {
        buffer1(0) = seq1 ++ seq2
      } else {
        buffer1(0) = null
      }
    } else {
      // Either side has already exceeded the limit, so the merged result is reset as well
      buffer1(0) = null
    }
  }

  def evaluate(buffer: Row): Any = {
    if (buffer(0) == null) {
      IndexedSeq[Long]()
    } else {
      buffer(0).asInstanceOf[IndexedSeq[Long]]
    }
  }
}
import sqlContext.implicits._
val a = sc.parallelize(0L to 20L).map(x => (x, x % 4)).toDF("value", "group")
val cl = new CollectionFunction(5)
val df = a.groupBy("group").agg(cl($"value").as("list")).cache()
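// A quick sanity check (a sketch, not part of the original gist): each group keeps at most
// `limit` values, and a group that exceeds the limit is returned as an empty array, as
// implemented in `evaluate` above.
df.printSchema()
df.show(false)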
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, LongType, DataType, StructType, StructField}
/**
 * UDAF for collecting timestamps. Takes `tslimit`, which restricts the maximum number of
 * timestamps per group. If a group accumulates more timestamps than `tslimit`, the whole
 * collection is reset to empty. We no longer apply the old logic of resetting the sequence
 * when one of the merged parts is empty; instead the non-empty part is kept as is.
 * Performance update: the buffer is nullified as soon as it is no longer needed, the separate
 * count column has been removed, and the values are stored through the `BufferType` UDT
 * (a mutable `ArrayBuffer[Long]`, defined further down in this gist).
 * Note: `BufferType` must already be in scope when this class is defined.
 */
class TimeCollectionFunction(
    private val tslimit: Int) extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    StructType(StructField("value", LongType, false) :: Nil)

  def bufferSchema: StructType =
    StructType(StructField("list", BufferType, true) :: Nil)

  override def dataType: DataType = BufferType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = new ArrayBuffer[Long]()
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A null buffer value marks a group that has already exceeded `tslimit`; skip it
    if (buffer(0) != null && !input.isNullAt(0)) {
      var buf = buffer(0).asInstanceOf[ArrayBuffer[Long]]
      if (buf.length < tslimit) {
        buf.append(input.getAs[Long](0))
        buffer(0) = buf
      } else {
        // Limit exceeded: nullify the buffer so it can be garbage collected early
        buffer(0) = null
        buf = null
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer1(0) != null && buffer2(0) != null) {
      var buf1 = buffer1(0).asInstanceOf[ArrayBuffer[Long]]
      var buf2 = buffer2(0).asInstanceOf[ArrayBuffer[Long]]
      if (buf1.length + buf2.length <= tslimit) {
        buf1.appendAll(buf2)
        buffer1(0) = buf1
      } else {
        buffer1(0) = null
        buf1 = null
        buf2 = null
      }
    }
  }

  def evaluate(buffer: Row): Any = {
    if (buffer(0) == null) {
      new ArrayBuffer[Long]()
    } else {
      buffer(0).asInstanceOf[ArrayBuffer[Long]]
    }
  }
}
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.types.{UserDefinedType, ArrayType, SQLUserDefinedType}
type Buffer = ArrayBuffer[Long]
class BufferType extends UserDefinedType[Buffer] {
  def sqlType: DataType = ArrayType(LongType, false)

  /** Convert the user type to a SQL datum */
  def serialize(obj: Any): Any = obj match {
    case c: Buffer => c.toSeq
    case other => throw new UnsupportedOperationException(s"Cannot serialize object ${other}")
  }

  /** Convert a SQL datum to the user type */
  def deserialize(datum: Any): Buffer = datum match {
    case a: Seq[_] => a.toBuffer.asInstanceOf[Buffer]
    case other => throw new UnsupportedOperationException(s"Cannot deserialize object ${other}")
  }

  def userClass: Class[Buffer] = classOf[Buffer]

  /** Rough default size estimate (in bytes) used by the query planner */
  override def defaultSize: Int = 1500
}
case object BufferType extends BufferType
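// A quick round-trip sketch of the UDT (illustrative only, not part of the original gist):
// `serialize` turns the mutable buffer into a Seq for storage, `deserialize` restores it.
val sample = new Buffer()
sample.appendAll(Seq(10L, 20L, 30L))
val stored = BufferType.serialize(sample)      // Seq(10, 20, 30)
val restored = BufferType.deserialize(stored)  // ArrayBuffer(10, 20, 30)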
val schema = StructType(StructField("list", BufferType, false) :: Nil)
import org.apache.spark.sql.Row
val rdd = sc.parallelize(0 to 10).map(x => {
  val seq = new Buffer()
  seq.appendAll(Seq(1L, 2L, 3L, 4L, 5L))
  Row(seq)
})
val df = sqlContext.createDataFrame(rdd, schema)
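// A minimal usage sketch for TimeCollectionFunction (not part of the original gist); it assumes
// the DataFrame `a` from the first example is still in scope, and the names `tcl` / `grouped`
// are illustrative only.
val tcl = new TimeCollectionFunction(5)
val grouped = a.groupBy("group").agg(tcl($"value").as("timestamps"))
grouped.show(false)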