Last active
June 12, 2020 11:05
-
-
Save sadikovi/7608c8c7eb5d7fe69a1a to your computer and use it in GitHub Desktop.
UDAF for generating a collection of values (up to a specified limit)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types.{ArrayType, LongType, DataType, StructType, StructField} | |
/**
 * UDAF that collects the `LongType` values of a group into an array, provided the group has at
 * most `limit` values. As soon as a group exceeds `limit`, its buffer is set to null and the
 * final result for that group is an empty array.
 *
 * Result order is not guaranteed. Within a partition values are prepended, and partial buffers
 * are concatenated in whatever order Spark merges them.
 *
 * @param limit maximum number of values a group may contain and still produce a non-empty result
 */
class CollectionFunction(private val limit: Int) extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", LongType, false) :: Nil)

  // A null "list" value marks a buffer that has already exceeded `limit`.
  def bufferSchema: StructType =
    StructType(StructField("list", ArrayType(LongType, true), true) :: Nil)

  override def dataType: DataType = ArrayType(LongType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Long]()
  }

  /** Adds one non-null input value, or marks the buffer as overflowed once `limit` is exceeded. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // An overflowed (null) buffer stays null; reading it would throw an NPE.
    // Null inputs are skipped, otherwise `getAs[Long]` would turn them into 0.
    if (!buffer.isNullAt(0) && !input.isNullAt(0)) {
      val seq = buffer.getSeq[Long](0)
      buffer(0) = if (seq.length < limit) input.getAs[Long](0) +: seq else null
    }
  }

  /** Combines two partial buffers; the result overflows if either side did or the total exceeds `limit`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (!buffer1.isNullAt(0)) {
      if (buffer2.isNullAt(0)) {
        // The other partial aggregate already overflowed, so the whole group exceeds `limit`.
        buffer1(0) = null
      } else {
        val seq1 = buffer1.getSeq[Long](0)
        val seq2 = buffer2.getSeq[Long](0)
        buffer1(0) = if (seq1.length + seq2.length <= limit) seq1 ++ seq2 else null
      }
    }
  }

  /** Returns the collected values, or an empty array for a group that exceeded `limit`. */
  def evaluate(buffer: Row): Any = {
    if (buffer.isNullAt(0)) IndexedSeq[Long]() else buffer.getSeq[Long](0)
  }
}
import sqlContext.implicits._
// Sample data: values 0..20 keyed by value % 4.
// Group 0 has 6 values; groups 1-3 have 5 each.
val a = sc.parallelize(0L to 20L).map(v => (v, v % 4)).toDF("value", "group")
// With a limit of 5, group 0 overflows and yields an empty list.
val cl = new CollectionFunction(5)
val df = a
  .groupBy("group")
  .agg(cl($"value").as("list"))
  .cache()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types.{ArrayType, LongType, IntegerType, BinaryType, DataType, StructType, StructField} | |
/**
 * UDAF for timestamps collection. It collects the non-null `LongType` input values of a group
 * into an `ArrayBuffer[Long]`, stored through the `BufferType` user-defined type.
 *
 * The `tslimit` parameter bounds the number of timestamps. If a group has more than `tslimit`
 * timestamps, the overall collection is reset to empty.
 *
 * An overflowed buffer is represented by null; no separate count column is kept. Overflow is
 * propagated when merging: if either partial buffer overflowed, the merged buffer is null as
 * well, because the group as a whole then necessarily exceeds `tslimit`. Merging a zero-length
 * partial buffer does not reset anything.
 *
 * Performance notes: values are appended in place to an `ArrayBuffer`, and the buffer is set to
 * null as soon as the limit is exceeded, so it is no longer carried around.
 *
 * @param tslimit maximum number of timestamps a group may have and still produce a non-empty result
 */
class TimeCollectionFunction(
    private val tslimit: Int
) extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", LongType, false) :: Nil)

  // A null "list" value marks a buffer that has already exceeded `tslimit`.
  def bufferSchema: StructType =
    StructType(StructField("list", BufferType, true) :: Nil)

  override def dataType: DataType = BufferType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = new ArrayBuffer[Long]()
  }

  /** Appends one non-null input value, or marks the buffer as overflowed once `tslimit` is exceeded. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // An overflowed (null) buffer stays null; reading it would throw an NPE. Null inputs are ignored.
    if (!buffer.isNullAt(0) && !input.isNullAt(0)) {
      val buf = buffer(0).asInstanceOf[ArrayBuffer[Long]]
      if (buf.length < tslimit) {
        buf.append(input.getAs[Long](0))
        // Write back so the aggregation buffer stores the updated collection.
        buffer(0) = buf
      } else {
        buffer(0) = null
      }
    }
  }

  /** Combines two partial buffers; the result overflows if either side did or the total exceeds `tslimit`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (!buffer1.isNullAt(0)) {
      if (buffer2.isNullAt(0)) {
        // The other partial aggregate already overflowed, so the whole group exceeds `tslimit`.
        buffer1(0) = null
      } else {
        val buf1 = buffer1(0).asInstanceOf[ArrayBuffer[Long]]
        val buf2 = buffer2(0).asInstanceOf[ArrayBuffer[Long]]
        if (buf1.length + buf2.length <= tslimit) {
          buf1.appendAll(buf2)
          buffer1(0) = buf1
        } else {
          buffer1(0) = null
        }
      }
    }
  }

  /** Returns the collected timestamps, or an empty buffer for a group that exceeded `tslimit`. */
  def evaluate(buffer: Row): Any = {
    if (buffer.isNullAt(0)) new ArrayBuffer[Long]() else buffer(0).asInstanceOf[ArrayBuffer[Long]]
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.mutable.ArrayBuffer | |
import org.apache.spark.sql.types.{UserDefinedType, ArrayType, SQLUserDefinedType} | |
/** Mutable collection of longs used as the aggregation buffer value. */
type Buffer = ArrayBuffer[Long]

/**
 * Spark SQL user-defined type that maps a `Buffer` (`ArrayBuffer[Long]`) to an
 * `ArrayType(LongType, containsNull = false)` column.
 */
class BufferType extends UserDefinedType[Buffer] {
  def sqlType: DataType = ArrayType(LongType, false)

  /** Convert the user type to its SQL representation: the buffer itself, exposed as a `Seq`. */
  def serialize(obj: Any): Any = obj match {
    // Match on the erased class; `case c: Buffer` would only produce an unchecked warning.
    case c: ArrayBuffer[_] => c.toSeq
    case other => throw new UnsupportedOperationException(s"Cannot serialize object ${other}")
  }

  /** Convert a SQL datum to the user type by copying its elements into a fresh buffer. */
  def deserialize(datum: Any): Buffer = datum match {
    case a: Seq[_] =>
      // Build the ArrayBuffer explicitly instead of casting `toBuffer`, whose declared
      // return type is only `mutable.Buffer`.
      val out = new Buffer(a.length)
      a.foreach { x => out += x.asInstanceOf[Long] }
      out
    case other => throw new UnsupportedOperationException(s"Cannot deserialize object ${other}")
  }

  def userClass: Class[Buffer] = classOf[Buffer]

  // NOTE(review): size estimate reported to Spark for this type; 1500 looks like a heuristic — confirm.
  override def defaultSize: Int = 1500
}

/** Shared instance of the UDT, used in schemas such as `StructField("list", BufferType, ...)`. */
case object BufferType extends BufferType
// Single non-nullable column holding the custom `BufferType` UDT.
val schema = StructType(Seq(StructField("list", BufferType, false)))
import org.apache.spark.sql.Row
// 11 rows, each carrying its own buffer with the values 1..5.
val rdd = sc.parallelize(0 to 10).map { _ =>
  val values = new Buffer()
  values ++= Seq(1L, 2L, 3L, 4L, 5L)
  Row(values)
}
val df = sqlContext.createDataFrame(rdd, schema)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment