Skip to content

Instantly share code, notes, and snippets.

@lokkju
Created May 16, 2017 12:06
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lokkju/06323e88746c85b2ce4de3ea9cdef9bc to your computer and use it in GitHub Desktop.
Save lokkju/06323e88746c85b2ce4de3ea9cdef9bc to your computer and use it in GitHub Desktop.
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._
import scala.collection.mutable
/**
 * Aggregate that collects the distinct, non-null values of `child` into an array.
 * It keeps at most `limit` of them.
 *
 * Once the set holds `limit` values, later rows are dropped. Null inputs are skipped, in line
 * with Spark's built-in `collect_set`, so they never consume a slot. A `limit` <= 0 yields an
 * empty result rather than failing.
 *
 * @param child                  expression evaluated for each input row
 * @param limit                  maximum number of distinct values kept per group
 * @param mutableAggBufferOffset mutable aggregation-buffer offset (see ImperativeAggregate)
 * @param inputAggBufferOffset   input aggregation-buffer offset (see ImperativeAggregate)
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a set of unique elements with a limit on the number of elements.")
case class CollectSetLimit(
    child: Expression,
    limit: Int,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  /** Number of `update` calls for the current group, counting nulls and dropped rows. */
  var attemptedUpdateCount = 0

  def this(child: Expression, limit: Int) = this(child, limit, 0, 0)

  // Reject any input containing a MapType, with the same check Spark's collect_set uses.
  // NOTE(review): presumably because map values cannot be de-duplicated; confirm.
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("collect_set_limit() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_set_limit"

  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty

  // Let Collect reset its per-group state, then restart the attempt counter. Without this the
  // counter keeps growing across every group processed by this instance.
  override def initialize(b: InternalRow): Unit = {
    super.initialize(b)
    attemptedUpdateCount = 0
  }

  override def update(b: InternalRow, input: InternalRow): Unit = {
    attemptedUpdateCount += 1
    if (buffer.size < limit) {
      // Only evaluate the child while there is room; nulls are ignored like in collect_set.
      val value = child.eval(input)
      if (value != null) {
        buffer += value
      }
    }
    // Rows arriving once the set is full are dropped. If a periodic hook is added here using
    // `attemptedUpdateCount % limit`, guard it with `limit > 0` to avoid division by zero.
  }
}
/**
 * Collect a list of elements, keeping at most `limit` non-null values per group.
 *
 * Values are appended in arrival order, and duplicates are kept. Null inputs are skipped, in
 * line with Spark's built-in `collect_list`, so they never consume a slot. Once the list is
 * full, later rows are dropped, and a warning is logged on every `limit`-th attempted update.
 * A `limit` <= 0 yields an empty result without logging rather than failing.
 *
 * @param child                  expression evaluated for each input row
 * @param limit                  maximum number of values kept per group
 * @param mutableAggBufferOffset mutable aggregation-buffer offset (see ImperativeAggregate)
 * @param inputAggBufferOffset   input aggregation-buffer offset (see ImperativeAggregate)
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements with a limit on the number of elements.")
case class CollectListLimit(
    child: Expression,
    limit: Int,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect with Logging {

  /** Number of `update` calls for the current group, counting nulls and dropped rows. */
  var attemptedUpdateCount = 0

  def this(child: Expression, limit: Int) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_list_limit"

  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  // Let Collect reset its per-group state, then restart the attempt counter. Without this the
  // counter used in the warning keeps growing across every group processed by this instance.
  override def initialize(b: InternalRow): Unit = {
    super.initialize(b)
    attemptedUpdateCount = 0
  }

  override def update(b: InternalRow, input: InternalRow): Unit = {
    attemptedUpdateCount += 1
    if (buffer.size < limit) {
      // Only evaluate the child while there is room; nulls are ignored like in collect_list.
      val value = child.eval(input)
      if (value != null) {
        buffer += value
      }
    } else if (limit > 0 && attemptedUpdateCount % limit == 0) {
      // Warn only every `limit`-th attempt so an oversized group does not flood the log.
      // The `limit > 0` guard prevents an ArithmeticException when limit == 0.
      logWarning(s"Reached max buffer size: $attemptedUpdateCount/$limit records [${input.toString}]")
    }
  }
}
/** DataFrame-API entry points for the `CollectSetLimit` and `CollectListLimit` aggregates. */
object collect_limit {

  /**
   * Wraps `func` in a (non-distinct by default) aggregate expression and exposes it as a Column.
   *
   * @param func       the aggregate function to wrap
   * @param isDistinct whether the aggregate expression is DISTINCT
   */
  private def withAggregateFunction(
      func: AggregateFunction,
      isDistinct: Boolean = false): Column = {
    val aggregate = func.toAggregateExpression(isDistinct)
    Column(aggregate)
  }

  /** Aggregate column backed by `CollectSetLimit` over `e` with the given `limit`. */
  def collect_set_limit(e: Column, limit: Int): Column =
    withAggregateFunction(CollectSetLimit(e.expr, limit))

  /** Aggregate column backed by `CollectListLimit` over `e` with the given `limit`. */
  def collect_list_limit(e: Column, limit: Int): Column =
    withAggregateFunction(CollectListLimit(e.expr, limit))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment