Skip to content

Instantly share code, notes, and snippets.

@joao-parana
Created October 3, 2018 13:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joao-parana/30fa32d916720d1f10577ac53fa097d8 to your computer and use it in GitHub Desktop.
Save joao-parana/30fa32d916720d1f10577ac53fa097d8 to your computer and use it in GitHub Desktop.
// UserDefinedAggregateFunction is the contract to define
// user-defined aggregate functions (UDAFs)
class MyCountUDAF extends UserDefinedAggregateFunction {
// Este método abaixo define pode ser invocado apenas assim: inputSchema(0)
// Isto é feito via inversão de dependência pelo Spark
// o retorno é um objeto StructField assim:
// StructField("id", LongType, true, {})
// o objeto StructField é do pacote org.apache.spark.sql.types
override def inputSchema: StructType = {
new StructType().add("id", LongType, nullable = true)
}
// O buffer para resultado temporário possui um único atributo
// no caso da funcionalidade de contagem.
// Este método abaixo define pode ser invocado apenas assim: bufferSchema(0)
// Isto é feito via inversão de dependência pelo Spark
// o retorno é um objeto StructField assim:
// StructField("count", LongType, true, {})
override def bufferSchema: StructType = {
new StructType().add("count", LongType, nullable = true)
}
// O método abaixo deve ser invocado sem parênteses em Scala.
// refere-se ao tipo do atributo de saida
override def dataType: DataType = LongType
override def deterministic: Boolean = true
// O método abaixo inicializa o buffer.
// Isto é feito via inversão de dependência pelo Spark
// Observe que a única coisa a ser feita é inicializar o contador com Zero.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
println(s">>> initialize (buffer: $buffer)")
// NOTE: Scala's update used under the covers
buffer(0) = 0L
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
println(s">>> update (buffer: $buffer -> input: $input)")
buffer(0) = buffer.getLong(0) + 1
}
override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
println(s">>> merge (buffer: $buffer -> row: $row)")
buffer(0) = buffer.getLong(0) + row.getLong(0)
}
override def evaluate(buffer: Row): Any = {
println(s">>> evaluate (buffer: $buffer)")
buffer.getLong(0)
}
}
// Declarando a UFAF para ser usada com a API de Dataset/DataFrame
// val myCountUDAF = new MyCountUDAF
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment