Skip to content

Instantly share code, notes, and snippets.

@YuvalItzchakov
Created August 23, 2021 11:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YuvalItzchakov/5cc7b076d31d73e5c1f9b7b72b3c624b to your computer and use it in GitHub Desktop.
Save YuvalItzchakov/5cc7b076d31d73e5c1f9b7b72b3c624b to your computer and use it in GitHub Desktop.
An attempt to create a generic ARRAY_AGG function for Flink
import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.api.dataview.ListView
import org.apache.flink.table.catalog.DataTypeFactory
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.types.inference.{ InputTypeStrategies, TypeInference }
import scala.collection.JavaConverters._
import scala.compat.java8.OptionConverters.RichOptionForJava8
import scala.reflect.ClassTag
/** Accumulator for [[ArrayAgg]]: wraps the `ListView` that stores the values collected so far.
  *
  * `ArrayAgg.getTypeInference` declares this class as a structured type with a single field
  * named "listView", so the member name has to stay in sync with that declaration.
  *
  * @param listView view holding the accumulated values
  * @tparam T type of the accumulated values
  */
class ArrayAggAccumulator[T](var listView: ListView[T])
/** Generic ARRAY_AGG-style aggregate: collects every accumulated value into an `Array[T]`.
  *
  * Values are buffered in a `ListView` held by [[ArrayAggAccumulator]]. The `ClassTag` context
  * bound is needed to build the result `Array[T]` in [[getValue]].
  *
  * @tparam T element type of the aggregated values
  */
class ArrayAgg[T: ClassTag] extends AggregateFunction[Array[T], ArrayAggAccumulator[T]] {

  /** Derives input, accumulator and output types from the data type of the first argument.
    *
    *  - accumulator: a structured type over [[ArrayAggAccumulator]] whose single field
    *    "listView" is a list view of the argument type;
    *  - output: an array of the argument type.
    *
    * Both strategies return an empty `Optional` when the call has no arguments.
    */
  override def getTypeInference(typeFactory: DataTypeFactory): TypeInference =
    TypeInference
      .newBuilder()
      // NOTE(review): SPECIFIC_FOR_ARRAY appears to be the strategy of the built-in ARRAY
      // constructor and may admit more than one argument, while `accumulate` takes exactly
      // one value — confirm this is the intended input strategy.
      .inputTypeStrategy(InputTypeStrategies.SPECIFIC_FOR_ARRAY)
      .accumulatorTypeStrategy { ctxt =>
        ctxt.getArgumentDataTypes.asScala.headOption.map { argType =>
          DataTypes.STRUCTURED(
            classOf[ArrayAggAccumulator[T]],
            // The field name must match the `listView` member of ArrayAggAccumulator.
            DataTypes.FIELD("listView", ListView.newListViewDataType(argType))
          )
        }.asJava
      }
      .outputTypeStrategy { ctxt =>
        ctxt.getArgumentDataTypes.asScala.headOption
          .map(argType => DataTypes.ARRAY(argType))
          .asJava
      }
      .build()

  /** Copies the values currently held by the accumulator into a new array. */
  override def getValue(accumulator: ArrayAggAccumulator[T]): Array[T] =
    accumulator.listView.getList.asScala.toArray

  /** Creates an accumulator backed by a fresh `ListView`. */
  override def createAccumulator(): ArrayAggAccumulator[T] =
    new ArrayAggAccumulator[T](new ListView[T]())

  /** Clears all values held by the accumulator. */
  def resetAccumulator(acc: ArrayAggAccumulator[T]): Unit = acc.listView.clear()

  /** Adds `value` to the accumulator. */
  def accumulate(acc: ArrayAggAccumulator[T], value: T): Unit =
    acc.listView.add(value)

  /** Removes `value` from the accumulator; the result of the removal is discarded. */
  def retract(acc: ArrayAggAccumulator[T], value: T): Unit = {
    acc.listView.remove(value)
    ()
  }

  /** Merges the contents of every accumulator in `it` into `acc`.
    *
    * This is the signature Flink looks up: its aggregate-function contract specifies
    * `merge(ACC, java.lang.Iterable<ACC>)`. A parameter typed as Scala's `Iterable` is a
    * different JVM type and is not matched by the runtime.
    *
    * @param acc accumulator that receives the merged values
    * @param it  accumulators whose values are appended to `acc`
    */
  def merge(acc: ArrayAggAccumulator[T], it: java.lang.Iterable[ArrayAggAccumulator[T]]): Unit =
    it.asScala.foreach(other => acc.listView.addAll(other.listView.getList))

  /** Scala-collection convenience overload, kept so existing Scala callers keep compiling.
    *
    * Delegates to the `java.lang.Iterable` variant above.
    */
  def merge(acc: ArrayAggAccumulator[T], it: Iterable[ArrayAggAccumulator[T]]): Unit =
    merge(acc, it.asJava)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment