Skip to content

Instantly share code, notes, and snippets.

@EntilZha
Created July 9, 2016 13:07
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save EntilZha/3951769a011389fef25e930258c20a2a to your computer and use it in GitHub Desktop.
Save EntilZha/3951769a011389fef25e930258c20a2a to your computer and use it in GitHub Desktop.
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import sqlContext.implicits._
import org.apache.spark.sql.types.{StructType, StructField, DataType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType, BinaryType, BooleanType, TimestampType, DateType, ArrayType}
/**
 * Spark SQL aggregate returning the `value` column of the row whose `minCol` is
 * smallest within each group (an arg-min, like SQL `min_by`).
 *
 * Rows whose `minCol` is null cannot be ranked and are skipped. If every row in a
 * group has a null `minCol`, the result is null. On ties the row/buffer already
 * held wins, so across partitions the surviving tied row depends on the order in
 * which Spark merges partial buffers.
 *
 * NOTE: `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `Aggregator` + `functions.udaf`; consider migrating if targeting Spark 3.
 *
 * @param valueType Spark type of the returned column (also the aggregate's result type)
 * @param minType   Spark type of the ordering column. Supported: ByteType, ShortType,
 *                  IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType,
 *                  BooleanType, DateType, TimestampType. Other types fail at comparison time.
 */
class MinBy(valueType: DataType, minType: DataType) extends UserDefinedAggregateFunction {

  /** Input columns: (value to return, key to minimize). */
  override def inputSchema: StructType =
    StructType(StructField("value", valueType) :: StructField("minCol", minType) :: Nil)

  /** Slot 0 = value of the best row so far, slot 1 = its key (null means "nothing kept yet"). */
  override def bufferSchema: StructType =
    StructType(StructField("value", valueType) :: StructField("minCol", minType) :: Nil)

  override def dataType: DataType = valueType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    keepIfSmaller(buffer, input.get(0), input.get(1))

  override def merge(leftBuffer: MutableAggregationBuffer, rightBuffer: Row): Unit =
    keepIfSmaller(leftBuffer, rightBuffer.get(0), rightBuffer.get(1))

  /** Returns the stored value untyped so any `valueType` works, not only StringType. */
  override def evaluate(buffer: Row): Any = buffer.get(0)

  /**
   * Stores (`value`, `key`) in `buffer` when `key` is non-null and strictly smaller
   * than the key currently held (or when nothing is held yet).
   */
  private def keepIfSmaller(buffer: MutableAggregationBuffer, value: Any, key: Any): Unit = {
    // A null key cannot be ranked; skipping it also makes merging an empty partial buffer a no-op.
    if (key != null) {
      val current = buffer.get(1)
      if (current == null || lessThan(key, current)) {
        buffer(0) = value
        buffer(1) = key
      }
    }
  }

  /**
   * Strict "a sorts before b" for the external (Row) representations of `minType`.
   * Floats/doubles use Float.compare/Double.compare so NaN sorts after every other
   * value instead of making comparisons always false. Strings use String.compareTo
   * (UTF-16 code units), which can differ from Spark SQL's UTF-8 byte order for some
   * characters outside the BMP.
   *
   * @throws UnsupportedOperationException if the values are not of a supported orderable type
   */
  private def lessThan(a: Any, b: Any): Boolean = (a, b) match {
    case (x: Byte, y: Byte)                                 => x < y
    case (x: Short, y: Short)                               => x < y
    case (x: Int, y: Int)                                   => x < y
    case (x: Long, y: Long)                                 => x < y
    case (x: Float, y: Float)                               => java.lang.Float.compare(x, y) < 0
    case (x: Double, y: Double)                             => java.lang.Double.compare(x, y) < 0
    case (x: java.math.BigDecimal, y: java.math.BigDecimal) => x.compareTo(y) < 0
    case (x: BigDecimal, y: BigDecimal)                     => x < y
    case (x: String, y: String)                             => x.compareTo(y) < 0
    case (x: Boolean, y: Boolean)                           => !x && y
    case (x: java.sql.Timestamp, y: java.sql.Timestamp)     => x.compareTo(y) < 0
    case (x: java.sql.Date, y: java.sql.Date)               => x.compareTo(y) < 0
    case (x: java.time.Instant, y: java.time.Instant)       => x.compareTo(y) < 0
    case (x: java.time.LocalDate, y: java.time.LocalDate)   => x.compareTo(y) < 0
    case _ =>
      throw new UnsupportedOperationException(
        s"MinBy cannot order values of type $minType (got ${a.getClass.getName})")
  }
}
/** A single person record for the MinBy demo; `age` is the column the aggregate minimizes. */
case class Person(family: String, name: String, age: Long)

// Sample data: Brown and Rodriguez each have two members, Chumich has one.
val df = Seq(
  Person(family = "Chumich", name = "Andy", age = 32L),
  Person(family = "Rodriguez", name = "Pedro", age = 25L),
  Person(family = "Brown", name = "Grace", age = 17L),
  Person(family = "Rodriguez", name = "Fritz", age = 23L),
  Person(family = "Brown", name = "Tyler", age = 15L)
).toDF()

// Aggregate that returns the String `name` taken from the row with the smallest Long `age`.
val mb = new MinBy(valueType = StringType, minType = LongType)

// For each family, print the name of its youngest member.
df.groupBy($"family")
  .agg(mb($"name", $"age"))
  .show()
// Expected output (row order is not guaranteed by groupBy):
// +---------+---------------+
// | family|MinBy(name,age)|
// +---------+---------------+
// | Chumich| Andy|
// | Brown| Tyler|
// |Rodriguez| Fritz|
// +---------+---------------+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment