Skip to content

Instantly share code, notes, and snippets.

@sarkologist
Created December 5, 2019 02:55
Show Gist options
  • Save sarkologist/a77a114f527be899e0fd6fdec56c3516 to your computer and use it in GitHub Desktop.
Save sarkologist/a77a114f527be899e0fd6fdec56c3516 to your computer and use it in GitHub Desktop.
Reordering
package gng.box.beam.pipeline.reordering
import gng.box.beam.pipeline.generic.influx.Metrics
import org.joda.time.Instant
/** Shared definitions used by the reordering pipeline. */
object Common {
/** Mixed into components that report metrics. */
trait HasMetrics {
// Metrics sink; declared implicit so it can satisfy implicit `Metrics` parameters.
implicit val metrics: Metrics
// Prefix prepended to every metric name emitted by the implementing component.
val measurementPrefix: String
}
/**
 * A message carrying a sequence number.
 *
 * @param seq       sequence number of the message
 * @param eventTime NOTE(review): presumably the event time of the message; it is not read in the code visible here
 * @param reset     flag marking the message as a sequence-number reset
 */
case class OrderedMessage(seq: Long, eventTime: Instant, reset: Boolean)
}
package gng.box.beam.pipeline.reordering
import java.lang.{Iterable => JIterable}
import java.util
import com.spotify.scio.ScioMetrics
import gng.box.beam.pipeline.generic.influx.Aggregate.{DoubleValue, LongValue}
import gng.box.beam.pipeline.generic.influx.{Aggregate, Metrics}
import gng.box.beam.pipeline.reordering.Common.HasMetrics
import gng.box.beam.pipeline.reordering.ShardedOrderedMessage.HasOrderedMessage
import gng.box.beam.utils.CoderUtil
import org.apache.beam.sdk.metrics.Counter
import org.apache.beam.sdk.state._
import org.apache.beam.sdk.transforms.DoFn
import org.apache.beam.sdk.transforms.DoFn.TimerId
import org.joda.time.{Duration, Instant}
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.JavaConverters._
import scala.collection.mutable
trait Reordering[T] {
self: HasOrderedMessage[T] with HasMetrics =>
// Added to the earliest buffered message timestamp to obtain the firing time of the event-time input timer.
val allowedLatenessEvent: Duration
// Relative offset used when arming the processing-time timers (input-stale and processing-stale).
val allowedLatenessProcessing: Duration
// Identifiers linking the state/timer specs declared in this trait to annotated DoFn method parameters.
final val EXPECTED_SEQ = "expectedSeq"
// Id of the buffered-messages state (the literal is "deltaBag").
final val MESSAGES = "deltaBag"
final val GAPS = "gaps"
final val INPUT_TIMER = "inputTimer"
final val INPUT_STALE_TIMER = "inputStaleTimer"
final val PROCESSING_STALE_TIMER = "processingStaleTimer"
final val PROCESSING_STALE_STATE = "processingStaleState"
// Detected gaps: the sequence number that was expected when the gap was seen,
// mapped to the timestamp of the message that revealed the gap.
@DoFn.StateId(GAPS)
val gapsState: StateSpec[ValueState[Map[Long, Instant]]] =
StateSpecs.value(CoderUtil.beamCoderFor[Map[Long, Instant]])
// Buffer of (timestamp, message) pairs that have been received but not yet processed.
@DoFn.StateId(MESSAGES)
val messageBagState
: StateSpec[ValueState[Option[util.ArrayList[(Instant, T)]]]] =
StateSpecs.value(
CoderUtil
.beamCoderFor[Option[util.ArrayList[(Instant, T)]]])
// Event-time timer; set to the earliest buffered timestamp plus `allowedLatenessEvent`.
@TimerId(INPUT_TIMER)
val inputTimerSpec: TimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME)
// Processing-time timer; re-armed relative to now whenever buffered input exists on arrival.
@TimerId(INPUT_STALE_TIMER)
val inputStaleTimerSpec: TimerSpec =
TimerSpecs.timer(TimeDomain.PROCESSING_TIME)
// Processing-time timer armed by `setProcessingStaleTimer`.
@TimerId(PROCESSING_STALE_TIMER)
val processingStaleTimerSpec: TimerSpec =
TimerSpecs.timer(TimeDomain.PROCESSING_TIME)
// Next sequence number expected; an absent value is read as 0L, which is treated as "none yet".
@DoFn.StateId(EXPECTED_SEQ)
val expectedSeqSpec: StateSpec[ValueState[Long]] =
StateSpecs.value(CoderUtil.beamCoderFor[Long])
// Flag recording that the processing-stale timer has been armed.
@DoFn.StateId(PROCESSING_STALE_STATE)
val processingStaleState: StateSpec[ValueState[ProcessingIsStale]] =
StateSpecs.value(CoderUtil.beamCoderFor[ProcessingIsStale])
/**
  * Buffers the messages of one incoming element and (re)arms the input timers.
  *
  * Each message is paired with the element timestamp. When at least one
  * timestamp is known (already buffered or newly arrived), the event-time
  * input timer is set to the earliest timestamp plus `allowedLatenessEvent`
  * and the input-stale timer is re-armed `allowedLatenessProcessing` from now.
  * Finally the new messages are appended to the buffered state.
  *
  * @param c               process context of the element being handled
  * @param messages        messages carried by the element
  * @param deltasState     state holding the buffered (timestamp, message) pairs
  * @param inputTimer      event-time timer that triggers processing
  * @param inputStaleTimer processing-time timer that triggers processing
  */
def scheduleMessageProcessing[I, O](
    c: DoFn[I, O]#ProcessContext,
    messages: JIterable[T],
    deltasState: ValueState[Option[util.ArrayList[(Instant, T)]]],
    inputTimer: Timer,
    inputStaleTimer: Timer): Unit = {
  // Whole seconds elapsed from `since` until the current wall-clock time.
  def secondsSince(since: Instant): Long =
    (Instant.now.getMillis - since.getMillis) / 1000

  // Every message of this element is stamped with the element's timestamp.
  val incoming: Iterable[(Instant, T)] =
    messages.asScala.map(message => (c.timestamp, message))

  val alreadyBuffered = Option(deltasState.read()).flatten
    .getOrElse(new util.ArrayList())
    .asScala
  val knownTimes = alreadyBuffered.map(_._1) ++ incoming.map(_._1)

  if (incoming.nonEmpty) {
    metrics.aggregate(c,
                      Aggregate.Max,
                      measurementPrefix + "latency.pane",
                      LongValue(secondsSince(c.timestamp)))
  }

  // Earliest known timestamp, if any; ties keep the first occurrence.
  val earliestKnown = knownTimes.reduceLeftOption { (soFar, candidate) =>
    if (candidate.getMillis < soFar.getMillis) candidate else soFar
  }
  earliestKnown.foreach { earliest =>
    metrics.aggregate(c,
                      Aggregate.Max,
                      measurementPrefix + "latency.input",
                      LongValue(secondsSince(earliest)))
    inputTimer.set(earliest.plus(allowedLatenessEvent))
    inputStaleTimer.offset(allowedLatenessProcessing).setRelative()
  }

  addMessagesToState(deltasState, incoming)
}
/**
  * Timer callback body: applies the buffered messages that are ready, in
  * sequence-number order, and writes the updated state back.
  *
  * For each ready message one of three things happens:
  *  - no gap (nothing expected yet, a sequence reset, or the expected sequence
  *    number): `process` is called and the expected sequence number advances;
  *  - gap (sequence number ahead of the expected one): the gap is recorded and
  *    `handleGap` decides the next expected sequence number;
  *  - already processed (sequence number behind the expected one): the message
  *    is ignored and the expected sequence number is left untouched.
  *
  * Messages that are not ready yet are written back to `messagesState` and the
  * input timer is re-armed for the earliest of them.
  *
  * @param c                    timer context
  * @param whichTimer           id of the timer that fired, used to tag metrics
  * @param expectedSeqState     state holding the next expected sequence number
  * @param gapsState            state holding the detected gaps
  * @param messagesState        state holding the buffered messages
  * @param inputTimer           event-time timer, re-armed when messages are deferred
  * @param inputStaleTimer      not used by this method
  * @param processingStaleTimer passed on to `whatToProcess`
  * @param processingStaleState passed on to `whatToProcess`
  * @param process              applies a message that arrived in order
  * @param handleGap            called with the message that revealed a gap and the
  *                             expected sequence number; returns the new expected one
  */
def processMessages[I, O](
    c: DoFn[I, O]#OnTimerContext,
    whichTimer: String,
    expectedSeqState: ValueState[Long],
    gapsState: ValueState[Map[Long, Instant]],
    messagesState: ValueState[Option[util.ArrayList[(Instant, T)]]],
    inputTimer: Timer,
    inputStaleTimer: Timer,
    processingStaleTimer: Timer,
    processingStaleState: ValueState[ProcessingIsStale],
    process: T => Unit,
    handleGap: (T, Long) => Long
): Unit = {
  val (toProcess, toDefer) =
    whatToProcess(c,
                  whichTimer,
                  messagesState,
                  processingStaleTimer,
                  processingStaleState)

  // assume 0L means no expected seq, i.e. before we get the first snapshot
  var expectedSeq = Option(expectedSeqState.read()).getOrElse(0L)
  var gaps = Option(gapsState.read()).getOrElse(Map.empty[Long, Instant])

  while (toProcess.nonEmpty) {
    val (ts, message) = toProcess.dequeue()
    metrics.aggregate(c,
                      Aggregate.Max,
                      measurementPrefix + "latency.source",
                      LongValue(new Duration(ts, Instant.now()).getStandardSeconds))

    val actualSeq = sequenceNumber(message)

    if (expectedSeq == 0L || isSequenceNumberReset(message) || expectedSeq == actualSeq) { // no gap
      process(message)
      expectedSeq = nextSequenceNumber(message)
    } else if (actualSeq > expectedSeq) { // gap
      // Remember when the gap was seen so a late arrival of the missing
      // message can be measured by checkGapLate.
      gaps += (expectedSeq -> ts)
      metrics.write(c,
                    measurementPrefix + "gap.size",
                    _.addField("value", actualSeq - expectedSeq))
      gapCounter.inc()
      LOG.warn("sequence number gap: expected {} got {}",
               expectedSeq,
               actualSeq)
      expectedSeq = handleGap(message, expectedSeq)
    } else { // received earlier frames
      gaps = checkGapLate(c, whichTimer, c.timestamp, actualSeq, gaps)
      metrics.aggregate(c,
                        Aggregate.Max,
                        measurementPrefix + "duplicate",
                        LongValue(expectedSeq - actualSeq))
      LOG.warn(
        "received already processed frame, ignoring: expected {} got {}",
        expectedSeq,
        actualSeq)
      // The frame is ignored, so the expected sequence number must not move:
      // rewinding it to this frame's successor would make the next in-order
      // message look like a gap.
    }

    // Emitted exactly once per dequeued message, together with the Beam counter.
    metrics.count(c, measurementPrefix + "applied")
    messagesAppliedCounter.inc()
  }

  // write state
  gapsState.write(gaps)
  expectedSeqState.write(expectedSeq)
  messagesState.clear()
  addMessagesToState(messagesState, toDefer)
  if (toDefer.nonEmpty) {
    val minTime = toDefer.map(_._1).minBy(_.getMillis)
    inputTimer.set(minTime.plus(allowedLatenessEvent))
  }
  // end write state
}
/**
  * Splits the buffered messages into those to process now and those to defer.
  *
  * When fired by an event-time timer, a message is processed if its timestamp
  * is strictly before the timer timestamp minus the measured reordering (see
  * `reorderingOf`); when fired by any other timer, every buffered message is
  * processed. Also emits timer/latency/reordering metrics and updates the
  * processing-stale timer via `setProcessingStaleTimer`.
  *
  * @param c                    timer context
  * @param whichTimer           id of the timer that fired, used in metric names and tags
  * @param messagesState        state holding the buffered (timestamp, message) pairs
  * @param processingStaleTimer timer handed to `setProcessingStaleTimer`
  * @param processingStaleState state handed to `setProcessingStaleTimer`
  * @return the messages to process, as a queue yielding the lowest sequence
  *         number first, and the messages to keep buffered
  */
def whatToProcess[I, O](
c: DoFn[I, O]#OnTimerContext,
whichTimer: String,
messagesState: ValueState[Option[util.ArrayList[(Instant, T)]]],
processingStaleTimer: Timer,
processingStaleState: ValueState[ProcessingIsStale])
: (mutable.PriorityQueue[(Instant, T)], Iterable[(Instant, T)]) = {
if (c.timeDomain == TimeDomain.EVENT_TIME) {
metrics.aggregate(
c,
Aggregate.Max,
measurementPrefix + "latency.watermark",
LongValue((Instant.now.getMillis - c.timestamp.getMillis) / 1000))
}
// Buffered messages in the order they were appended to the state (empty if no state yet).
val unordered: Iterable[(Instant, T)] =
Option(messagesState.read()).flatten
.getOrElse(new util.ArrayList())
.asScala
// Ordering is by sequence number, reversed, so the queue yields the lowest sequence number first.
val messages: mutable.PriorityQueue[(Instant, T)] =
new mutable.PriorityQueue()(
scala.math.Ordering
.by[(Instant, T), Long](pair => sequenceNumber(pair._2))
.reverse)
messages.enqueue(unordered.toSeq: _*)
val seqAndTime: ((Instant, T)) => (Long, Long) = pair =>
(sequenceNumber(pair._2), pair._1.getMillis)
// .clone.dequeueAll is necessary to preserve seqnum order (and leaves `messages` itself intact)
val reordering = reorderingOf(messages.clone.dequeueAll.map(seqAndTime),
unordered.map(seqAndTime))
def predicate: ((Instant, T)) => Boolean =
if (c.timeDomain == TimeDomain.EVENT_TIME) {
(_: (Instant, T))._1
.isBefore(c.timestamp.getMillis - reordering)
} else {
Function.const(true)
} // if it is the stale timer, process everything
metrics.count(c, measurementPrefix + "timer." + whichTimer)
val toProcess = messages.filter(predicate)
val toDefer = unordered.filter(!predicate(_))
if (messages.nonEmpty) {
// Fraction of the buffered messages selected for processing in this firing.
metrics.aggregate(
c,
aggregation = Aggregate.Min,
measurementPrefix + "process_fraction",
DoubleValue(toProcess.size.toDouble / messages.size.toDouble),
Map("timer" -> whichTimer)
)
}
setProcessingStaleTimer(whichTimer,
processingStaleTimer,
processingStaleState,
toProcess)
metrics.aggregate(c,
Aggregate.Max,
measurementPrefix + "reordering",
LongValue(reordering))
(toProcess, toDefer)
}
/**
  * Arms the processing-stale timer when progress was made or when it has not
  * been armed yet.
  *
  * If the processing-stale timer itself fired, the "armed" flag is cleared
  * first, so the timer is armed again by this call.
  *
  * @param whichTimer           id of the timer that fired
  * @param processingStaleTimer processing-time timer to arm
  * @param processingStaleState flag recording that the timer is armed
  * @param toProcess            messages selected for processing in this firing
  */
def setProcessingStaleTimer(
    whichTimer: String,
    processingStaleTimer: Timer,
    processingStaleState: ValueState[ProcessingIsStale],
    toProcess: mutable.PriorityQueue[(Instant, T)]): Unit = {
  // Evaluated on demand so the state is only read when no progress was made.
  def alreadyArmed: Boolean =
    Option(processingStaleState.read()).fold(false)(_.value)

  // The timer has fired, so it is no longer armed.
  if (whichTimer == PROCESSING_STALE_TIMER) processingStaleState.clear()

  val madeProgress = toProcess.nonEmpty
  if (madeProgress || !alreadyArmed) {
    processingStaleTimer.offset(allowedLatenessProcessing).setRelative()
    processingStaleState.write(ProcessingIsStale(true))
  }
}
/**
  * Reports how late a message arrived if its sequence number is the start of a
  * recorded gap, and drops that gap from the map.
  *
  * @param c          context used to emit the metric
  * @param whichTimer id of the timer that fired, added as the "timer" tag
  * @param trigger    instant against which the lateness is measured
  * @param actualSeq  sequence number of the message being examined
  * @param gaps       recorded gaps keyed by sequence number
  * @return `gaps` without the entry for `actualSeq`
  */
def checkGapLate[I, O](c: DoFn[I, O]#WindowedContext,
                       whichTimer: String,
                       trigger: Instant,
                       actualSeq: Long,
                       gaps: Map[Long, Instant])(
    implicit metrics: Metrics): Map[Long, Instant] = {
  gaps.get(actualSeq) match {
    case Some(gapSeenAt) =>
      gapLateCounter.inc()
      // Lateness in whole seconds between the gap being seen and the trigger.
      val latenessSeconds = (trigger.getMillis - gapSeenAt.getMillis) / 1000
      metrics.write(
        c,
        measurementPrefix + "gap.lateness",
        point =>
          point
            .addField("value", latenessSeconds)
            .addTag("timer", whichTimer)
      )
    case None => ()
  }
  gaps - actualSeq
}
/**
  * Appends `toAdd` to the buffered messages and writes the buffer back.
  *
  * A missing or empty state is treated as an empty buffer, so the state always
  * holds `Some(list)` afterwards, even when `toAdd` is empty.
  *
  * @param messagesState state holding the buffered (timestamp, message) pairs
  * @param toAdd         pairs to append, in iteration order
  */
def addMessagesToState(
    messagesState: ValueState[Option[util.ArrayList[(Instant, T)]]],
    toAdd: Iterable[(Instant, T)]): Unit = {
  val buffer: util.ArrayList[(Instant, T)] =
    Option(messagesState.read()).flatten match {
      case Some(existing) => existing
      case None           => new util.ArrayList[(Instant, T)]()
    }
  buffer.addAll(toAdd.asJavaCollection)
  messagesState.write(Some(buffer))
}
/**
  * Measures how far two orderings of the same (sequence number, timestamp)
  * pairs disagree.
  *
  * The inputs are compared position by position (the longer one is truncated).
  * Where the sequence numbers at a position differ, the absolute difference of
  * the two timestamps is taken; where they match, the contribution is 0.
  *
  * @param as pairs in one order
  * @param bs pairs in the other order
  * @return the largest contribution, or 0 when there is nothing to compare
  */
def reorderingOf(as: Iterable[(Long, Long)],
                 bs: Iterable[(Long, Long)]): Long = {
  val displacements: List[Long] =
    as.iterator
      .zip(bs.iterator)
      .map { pair =>
        val ((leftSeq, leftTs), (rightSeq, rightTs)) = pair
        if (leftSeq == rightSeq) 0L else Math.abs(rightTs - leftTs)
      }
      .toList

  if (displacements.isEmpty) 0L else displacements.max
}
// Value stored in the PROCESSING_STALE_STATE state; written as `true` when the processing-stale timer is armed.
case class ProcessingIsStale(value: Boolean)
val LOG: Logger = LoggerFactory.getLogger(classOf[Reordering[_]])
// Incremented once per detected sequence-number gap.
val gapCounter: Counter =
ScioMetrics.counter[Reordering[_]]("gap")
// Incremented when a message arrives whose sequence number matches a recorded gap.
val gapLateCounter: Counter =
ScioMetrics.counter[Reordering[_]]("gap_late")
// Incremented once per message dequeued in `processMessages`.
val messagesAppliedCounter: Counter =
ScioMetrics.counter[Reordering[_]]("messages_applied")
}
package gng.box.beam.pipeline.reordering
import java.lang.{Iterable => JIterable}
import java.util
import com.spotify.scio.coders.Coder
import gng.box.beam.pipeline.generic.influx.Metrics
import gng.box.beam.pipeline.reordering.Common.{HasMetrics, OrderedMessage}
import gng.box.beam.pipeline.reordering.ShardedOrderedMessage.HasOrderedMessage
import org.apache.beam.sdk.state._
import org.apache.beam.sdk.transforms.DoFn
import org.apache.beam.sdk.transforms.DoFn.{
OnTimer,
ProcessElement,
StateId,
TimerId
}
import org.apache.beam.sdk.values.KV
import org.joda.time.{Duration, Instant}
import org.slf4j.{Logger, LoggerFactory}
object ReorderingTransform {
class ReorderingDoFn[K: Coder](implicit m: Metrics)
extends DoFn[KV[K, JIterable[OrderedMessage]], Unit]
with Reordering[OrderedMessage]
with HasOrderedMessage[OrderedMessage]
with HasMetrics {
// Exposes the implicit constructor argument as the `HasMetrics` metrics sink.
override val metrics: Metrics = m
override val measurementPrefix: String = "beam.reordering."
// The sequence number is the message's `seq` field.
override def sequenceNumber(o: OrderedMessage): Long = o.seq
// Consecutive messages are expected to differ by exactly one.
override def nextSequenceNumber(o: OrderedMessage): Long = o.seq + 1
// A message marks a sequence reset through its `reset` flag.
override def isSequenceNumberReset(o: OrderedMessage): Boolean = o.reset
// Event-time allowance: 5 minutes.
override val allowedLatenessEvent: Duration = Duration.standardMinutes(5)
// Processing-time allowance: 10 minutes.
override val allowedLatenessProcessing: Duration =
Duration.standardMinutes(10)
/**
  * Buffers the messages of the incoming element and arms the input timers by
  * delegating to `scheduleMessageProcessing`.
  *
  * @param c               process context; the element value holds the messages
  * @param messagesState   state holding the buffered (timestamp, message) pairs
  * @param inputTimer      event-time input timer
  * @param inputStaleTimer processing-time input-stale timer
  */
@ProcessElement
def processElement(
    c: DoFn[KV[K, JIterable[OrderedMessage]], Unit]#ProcessContext,
    @StateId(MESSAGES) messagesState: ValueState[
      Option[util.ArrayList[(Instant, OrderedMessage)]]],
    @TimerId(INPUT_TIMER) inputTimer: Timer,
    @TimerId(INPUT_STALE_TIMER) inputStaleTimer: Timer,
): Unit = {
  // Explicit `: Unit =` replaces the deprecated procedure syntax and matches
  // the timer callbacks of this class.
  scheduleMessageProcessing(c,
                            c.element.getValue,
                            messagesState,
                            inputTimer,
                            inputStaleTimer)
}
/**
  * Callback of the event-time input timer: processes the buffered messages,
  * reporting `INPUT_TIMER` as the timer that fired.
  */
@OnTimer(INPUT_TIMER)
def inputCallback(
    c: OnTimerContext,
    @StateId(EXPECTED_SEQ) lastAppliedSeqState: ValueState[Long],
    @StateId(GAPS) gapsState: ValueState[Map[Long, Instant]],
    @StateId(MESSAGES) deltasState: ValueState[
      Option[util.ArrayList[(Instant, OrderedMessage)]]],
    @TimerId(INPUT_TIMER) inputTimer: Timer,
    @TimerId(INPUT_STALE_TIMER) inputStaleTimer: Timer,
    @TimerId(PROCESSING_STALE_TIMER) processingStaleTimer: Timer,
    @StateId(PROCESSING_STALE_STATE) processingStaleState: ValueState[
      ProcessingIsStale],
): Unit =
  processMessages(
    c = c,
    whichTimer = INPUT_TIMER,
    expectedSeqState = lastAppliedSeqState,
    gapsState = gapsState,
    messagesState = deltasState,
    inputTimer = inputTimer,
    inputStaleTimer = inputStaleTimer,
    processingStaleTimer = processingStaleTimer,
    processingStaleState = processingStaleState,
    process = message => process(message),
    handleGap = (message, expected) => handleGap(message, expected)
  )
/**
  * Callback of the processing-time input-stale timer: processes the buffered
  * messages, reporting `INPUT_STALE_TIMER` as the timer that fired.
  */
@OnTimer(INPUT_STALE_TIMER)
def inputStaleCallback(
    c: DoFn[KV[K, JIterable[OrderedMessage]], Unit]#OnTimerContext,
    @StateId(EXPECTED_SEQ) lastAppliedSeqState: ValueState[Long],
    @StateId(GAPS) gapsState: ValueState[Map[Long, Instant]],
    @StateId(MESSAGES) deltasState: ValueState[
      Option[util.ArrayList[(Instant, OrderedMessage)]]],
    @TimerId(INPUT_TIMER) inputTimer: Timer,
    @TimerId(INPUT_STALE_TIMER) inputStaleTimer: Timer,
    @TimerId(PROCESSING_STALE_TIMER) processingStaleTimer: Timer,
    @StateId(PROCESSING_STALE_STATE) processingStaleState: ValueState[
      ProcessingIsStale],
): Unit =
  processMessages(
    c = c,
    whichTimer = INPUT_STALE_TIMER,
    expectedSeqState = lastAppliedSeqState,
    gapsState = gapsState,
    messagesState = deltasState,
    inputTimer = inputTimer,
    inputStaleTimer = inputStaleTimer,
    processingStaleTimer = processingStaleTimer,
    processingStaleState = processingStaleState,
    process = message => process(message),
    handleGap = (message, expected) => handleGap(message, expected)
  )
/**
  * Callback of the processing-time processing-stale timer: processes the
  * buffered messages, reporting `PROCESSING_STALE_TIMER` as the timer that fired.
  */
@OnTimer(PROCESSING_STALE_TIMER)
def processingStaleCallback(
    c: OnTimerContext,
    @StateId(EXPECTED_SEQ) lastAppliedSeqState: ValueState[Long],
    @StateId(GAPS) gapsState: ValueState[Map[Long, Instant]],
    @StateId(MESSAGES) deltasState: ValueState[
      Option[util.ArrayList[(Instant, OrderedMessage)]]],
    @TimerId(INPUT_TIMER) inputTimer: Timer,
    @TimerId(INPUT_STALE_TIMER) inputStaleTimer: Timer,
    @TimerId(PROCESSING_STALE_TIMER) processingStaleTimer: Timer,
    @StateId(PROCESSING_STALE_STATE) processingStaleState: ValueState[
      ProcessingIsStale],
): Unit =
  processMessages(
    c = c,
    whichTimer = PROCESSING_STALE_TIMER,
    expectedSeqState = lastAppliedSeqState,
    gapsState = gapsState,
    messagesState = deltasState,
    inputTimer = inputTimer,
    inputStaleTimer = inputStaleTimer,
    processingStaleTimer = processingStaleTimer,
    processingStaleState = processingStaleState,
    process = message => process(message),
    handleGap = (message, expected) => handleGap(message, expected)
  )
/** Applies an in-order message; this implementation only prints it to standard output. */
def process(message: OrderedMessage): Unit = {
  val line = "applying: " + message
  println(line)
}
/**
  * Reacts to a sequence-number gap by printing it and skipping past the
  * message that revealed it.
  *
  * @param message  message whose sequence number is ahead of the expected one
  * @param expected sequence number that was expected
  * @return the sequence number following `message`
  */
def handleGap(message: OrderedMessage, expected: Long): Long = {
  val seen = message.seq
  println("gap! seen " + seen + " expected " + expected)
  seen + 1
}
}
// Logger named after ReorderingDoFn. NOTE(review): nothing in the visible code references it by its qualified name — confirm it is still needed.
val LOG: Logger = LoggerFactory.getLogger(classOf[ReorderingDoFn[_]])
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment