Created
June 25, 2019 19:00
-
-
Save piyushnarang/fe562060789ffeb01d59dcc3da375849 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.{Double => JDouble, Long => JLong} | |
import org.apache.flink.api.common.typeinfo.TypeInformation | |
import org.apache.flink.api.scala._ | |
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} | |
import org.apache.flink.api.java.typeutils.TupleTypeInfo | |
import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} | |
import org.apache.flink.table.api.{TableEnvironment, Types} | |
import org.apache.flink.table.functions.{AggregateFunction, FunctionContext} | |
import org.apache.flink.api.scala._ | |
import org.apache.flink.table.api.scala._ | |
import org.apache.flink.types.Row | |
/**
 * Mutable accumulator used by [[EventRate]]: a Flink Java tuple of two boxed longs.
 *
 * `f0` counts occurrences of the numerator event and `f1` counts occurrences of the
 * denominator event. Both counters start at zero.
 */
class EventRateAccumulator extends JTuple2[JLong, JLong] {
  // Start both counters at zero, boxing explicitly rather than through an implicit conversion.
  this.f0 = JLong.valueOf(0L)
  this.f1 = JLong.valueOf(0L)
}
/**
 * Input record of the example job, modelled as a Flink Java tuple of (Int, String).
 *
 * @param t0 first tuple field; the example's `main` registers it under the name `partnerId`
 * @param t1 second tuple field; the example's `main` registers it under the name `eventType`
 */
class Event(t0: Int, t1: String)
  extends JTuple2[Int, String](t0, t1)
/**
 * Aggregate function that computes an event ratio such as click-through-rate
 * (clicks / displays).
 *
 * @param numeratorEvent   event type whose occurrences are counted in the numerator
 * @param denominatorEvent event type whose occurrences are counted in the denominator
 */
class EventRate(val numeratorEvent: String, val denominatorEvent: String)
  extends AggregateFunction[JDouble, EventRateAccumulator] {

  /**
   * Delegates to the parent implementation and then prints a marker line.
   * The original author's note says this method does not get invoked.
   */
  override def open(context: FunctionContext): Unit = {
    super.open(context)
    println("Open called!!") // this doesn't get invoked
  }

  /** Creates a fresh accumulator. */
  override def createAccumulator(): EventRateAccumulator = {
    new EventRateAccumulator
  }

  /**
   * Returns numerator count divided by denominator count, or `-1.0` when the
   * denominator count is not positive.
   */
  override def getValue(accumulator: EventRateAccumulator): JDouble = {
    // Check the denominator first so the numerator is only read when a ratio is produced.
    val denominatorCount: Long = accumulator.f1
    if (denominatorCount <= 0) {
      JDouble.valueOf(-1.0)
    } else {
      val numeratorCount: Long = accumulator.f0
      JDouble.valueOf(numeratorCount.toDouble / denominatorCount.toDouble)
    }
  }

  /**
   * Counts one event into the accumulator. The numerator is checked first, so an
   * event type equal to both configured names increments only `f0`. Event types
   * matching neither name are ignored.
   *
   * Not an override of the base class; presumably discovered by name by Flink's
   * runtime (verify against the Flink documentation).
   */
  def accumulate(accumulator: EventRateAccumulator, eventType: String): Unit = {
    if (eventType == numeratorEvent) {
      accumulator.f0 = accumulator.f0 + 1
    } else if (eventType == denominatorEvent) {
      accumulator.f1 = accumulator.f1 + 1
    }
  }

  /** Sets both counters of the given accumulator back to zero. */
  def resetAccumulator(accumulator: EventRateAccumulator): Unit = {
    accumulator.f0 = 0L
    accumulator.f1 = 0L
  }

  /** Declares the result type of this aggregate as a double. */
  override def getResultType: TypeInformation[JDouble] = {
    Types.DOUBLE
  }
}
/**
 * Small batch job that registers [[EventRate]] as the SQL function `CTR`, runs it
 * over three in-memory events grouped by partner id, and prints the result rows.
 */
object TestAggregateOpenFn {

  /**
   * Entry point: builds the input data set, registers the table and the function,
   * runs the grouped query and prints its output.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val batchEnv = ExecutionEnvironment.getExecutionEnvironment
    val tableEnv = TableEnvironment.getTableEnvironment(batchEnv)

    // Explicit type information for Event, picked up implicitly by fromElements/registerDataSet.
    implicit val typeInfo: TupleTypeInfo[Event] =
      new TupleTypeInfo[Event](Types.INT, Types.STRING)

    val events: DataSet[Event] = batchEnv.fromElements[Event](
      new Event(123, "Click"),
      new Event(123, "Display"),
      new Event(456, "Click")
    )

    // Expose the tuple fields under readable column names.
    tableEnv.registerDataSet[Event]("advertiser_event", events, 'partnerId, 'eventType)
    tableEnv.registerFunction("CTR", new EventRate("Click", "Display"))

    val query = "SELECT partnerId, CTR(eventType) FROM advertiser_event GROUP BY partnerId"
    val result = tableEnv.sqlQuery(query)

    // Convert back to a DataSet of rows using the table's own row type, then print it.
    val rowType = result.getSchema.toRowType
    result.toDataSet[Row](rowType).print()
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment