Implementation of session window with event time and watermark via flatMapGroupsWithState, and SPARK-10816
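
The first test below builds session windows by hand on top of flatMapGroupsWithState with an event-time timeout; the second produces the same result with the built-in session_window function from SPARK-10816 (available since Spark 3.2.0).
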
case class SessionInfo(sessionStartTimestampMs: Long,
                       sessionEndTimestampMs: Long,
                       numEvents: Int) {

  /** Duration of the session, between the first and last events plus the session gap. */
  def durationMs: Long = sessionEndTimestampMs - sessionStartTimestampMs
}
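
// for instance, with the 10-second session gap used in the tests below, a single event
// at timestamp 25s maps to SessionInfo(25000, 35000, 1), so durationMs is 10000:
// the window spans the event itself plus the session gap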
case class SessionUpdate(id: String,
                         sessionStartTimestampSecs: Long,
                         sessionEndTimestampSecs: Long,
                         durationSecs: Long,
                         numEvents: Int)
test("session window - flatMapGroupsWithState") {
import java.sql.Timestamp
val inputData = MemoryStream[(String, Long)]
val events = inputData.toDF()
.select($"_1".as("value"), $"_2".as("timestamp"))
.as[(String, Timestamp)]
.flatMap { case (v, timestamp) =>
v.split(" ").map { word => (word, timestamp) }
}
.as[(String, Timestamp)]
.withWatermark("_2", "30 seconds")
val outputMode = OutputMode.Append() // below stateFunc also supports OutputMode.Update
val sessionGapMills = 10 * 1000
  val stateFunc: (String, Iterator[(String, Timestamp)], GroupState[List[SessionInfo]])
      => Iterator[SessionUpdate] =
    (sessionId: String, events: Iterator[(String, Timestamp)],
     state: GroupState[List[SessionInfo]]) => {
      def handleEvict(sessionId: String, state: GroupState[List[SessionInfo]])
          : Iterator[SessionUpdate] = {
        state.getOption match {
          case Some(lst) =>
            // sessions are sorted by session start timestamp, and since they don't
            // overlap, every session closed before the watermark comes first
            val (evicted, kept) = lst.span {
              s => s.sessionEndTimestampMs < state.getCurrentWatermarkMs()
            }
            if (kept.isEmpty) {
              state.remove()
            } else {
              state.update(kept)
              state.setTimeoutTimestamp(kept.head.sessionEndTimestampMs)
            }
            outputMode match {
              case s if s == OutputMode.Append() =>
                evicted.iterator.map(si => SessionUpdate(sessionId,
                  si.sessionStartTimestampMs / 1000,
                  si.sessionEndTimestampMs / 1000,
                  si.durationMs / 1000, si.numEvents))
              case s if s == OutputMode.Update() => Seq.empty[SessionUpdate].iterator
              case s => throw new UnsupportedOperationException(s"Unsupported output mode $s")
            }
          case None =>
            state.remove()
            Seq.empty[SessionUpdate].iterator
        }
      }
      def mergeSession(session1: SessionInfo, session2: SessionInfo): SessionInfo = {
        SessionInfo(
          sessionStartTimestampMs = Math.min(session1.sessionStartTimestampMs,
            session2.sessionStartTimestampMs),
          sessionEndTimestampMs = Math.max(session1.sessionEndTimestampMs,
            session2.sessionEndTimestampMs),
          numEvents = session1.numEvents + session2.numEvents)
      }
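      // e.g. mergeSession(SessionInfo(25000, 40000, 2), SessionInfo(40000, 50000, 1))
      // yields SessionInfo(25000, 50000, 3), which matches the "spark" merge shown in
      // the commented verification scenario below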
      def handleEvents(sessionId: String, events: Iterator[(String, Timestamp)],
                       state: GroupState[List[SessionInfo]]): Iterator[SessionUpdate] = {
        import java.{util => ju}
        import scala.collection.mutable
        import scala.collection.JavaConverters._

        // we assume only previous sessions are sorted: events are not guaranteed to be
        // sorted. we also assume the number of sessions per key is small, which holds
        // unless end users set a huge watermark delay along with a small session gap.
        val newSessions: ju.LinkedList[SessionInfo] = state.getOption match {
          case Some(lst) => new ju.LinkedList[SessionInfo](lst.asJava)
          case None => new ju.LinkedList[SessionInfo]()
        }

        // this buffer tracks changed sessions for update mode. if you define "update" as
        // returning all sessions for the given key, you can drop it and the tracking logic
        val updatedSessions = new mutable.ListBuffer[SessionInfo]()
        while (events.hasNext) {
          val ev = events.next()
          // convert each event into a single-event session window
          val event = SessionInfo(ev._2.getTime, ev._2.getTime + sessionGapMillis, 1)

          // find a matching session
          var index = 0
          var updated = false
          while (!updated && index < newSessions.size()) {
            val session = newSessions.get(index)
            if (event.sessionEndTimestampMs < session.sessionStartTimestampMs) {
              // no matching session, and the following sessions cannot match either
              newSessions.add(index, event)
              updated = true
              updatedSessions += event
            } else if (event.sessionStartTimestampMs > session.sessionEndTimestampMs) {
              // continue to the next session
              index += 1
            } else {
              // matched: update the session
              var newSession = session.copy(
                sessionStartTimestampMs = Math.min(session.sessionStartTimestampMs,
                  event.sessionStartTimestampMs),
                sessionEndTimestampMs = Math.max(session.sessionEndTimestampMs,
                  event.sessionEndTimestampMs),
                numEvents = session.numEvents + event.numEvents)

              // the previous session is being replaced with the new session, so it must be
              // removed from the updated sessions; the same applies to the merges below
              updatedSessions -= session
              // check whether the new session can be merged with the next session
              if (index + 1 < newSessions.size()) {
                val nextSession = newSessions.get(index + 1)
                if (newSession.sessionEndTimestampMs >= nextSession.sessionStartTimestampMs) {
                  newSession = mergeSession(newSession, nextSession)
                  updatedSessions -= nextSession
                  newSessions.remove(index + 1)
                }
              }

              // check whether the new session can be merged with the previous session
              if (index - 1 >= 0) {
                val prevSession = newSessions.get(index - 1)
                if (newSession.sessionStartTimestampMs <= prevSession.sessionEndTimestampMs) {
                  newSession = mergeSession(newSession, prevSession)
                  updatedSessions -= prevSession
                  newSessions.remove(index - 1)
                  index -= 1
                }
              }
              newSessions.set(index, newSession)
              updatedSessions += newSession
              updated = true
            }
          }

          if (!updated) {
            // nothing matched so far: append the event as the last session
            newSessions.addLast(event)
            updatedSessions += event
          }
        }
        val newSessionsForScala = newSessions.asScala.toList
        state.update(newSessionsForScala)

        // there is always at least one session available at this point.
        // set the timeout to the earliest session's end timestamp: when it fires,
        // handleEvict will traverse the sessions and evict the closed ones
        state.setTimeoutTimestamp(newSessionsForScala.head.sessionEndTimestampMs)

        outputMode match {
          case s if s == OutputMode.Update() =>
            updatedSessions.iterator.map(si =>
              SessionUpdate(sessionId, si.sessionStartTimestampMs / 1000,
                si.sessionEndTimestampMs / 1000, si.durationMs / 1000, si.numEvents))
          case s if s == OutputMode.Append() => Seq.empty[SessionUpdate].iterator
          case s => throw new UnsupportedOperationException(s"Unsupported output mode $s")
        }
      }
      if (state.hasTimedOut) {
        handleEvict(sessionId, state)
      } else {
        handleEvents(sessionId, events, state)
      }
    }

  val sessionUpdates = events
    .groupByKey(event => event._1)
    .flatMapGroupsWithState[List[SessionInfo], SessionUpdate](
      outputMode, timeoutConf = GroupStateTimeout.EventTimeTimeout())(stateFunc)

  // code for verifying the output goes here
}
// the test below provides the same result as the one above
test("session window - session_window (SPARK-10816)") {
  val inputData = MemoryStream[(String, Long)]

  // split the lines into words and treat each word as the sessionId of an event
  val events = inputData.toDF()
    .select($"_1".as("value"), $"_2".as("timestamp"))
    .withColumn("eventTime", $"timestamp".cast("timestamp"))
    .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
    .withWatermark("eventTime", "30 seconds")

  val sessionUpdates = events
    .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
    .agg(count("*").as("numEvents"))
    .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
      "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationSecs",
      "numEvents")

  // code for verifying the output goes here
}
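
Both queries can be exercised with testStream from Spark's StreamTest harness; the commented-out blocks below walk through the expected answers for append mode and update mode.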
/*
// for verifying append mode
testStream(sessionUpdates, outputMode)(
AddData(inputData,
("hello world spark streaming", 40L),
("world hello structured streaming", 41L)
),
// watermark: 11 (max event time 41 - watermark delay 30 seconds)
// current sessions
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// placing new sessions "before" previous sessions
AddData(inputData, ("spark streaming", 25L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// a late event whose session end (10) is earlier than the watermark (11): it should be dropped
AddData(inputData, ("spark streaming", 0L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// concatenating multiple previous sessions into one
AddData(inputData, ("spark streaming", 30L)),
// watermark: 11
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// placing new sessions after previous sessions
AddData(inputData, ("hello apache spark", 60L)),
// watermark: 30
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1),
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1)
CheckNewAnswer(
),
AddData(inputData, ("structured streaming", 90L)),
// watermark: 60
// current sessions
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1),
// ("structured", 90, 100, 10, 1),
// ("streaming", 90, 100, 10, 1)
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4),
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("structured", 41, 51, 10, 1)
)
)
*/
/*
// for verifying update mode
testStream(sessionUpdates, outputMode)(
AddData(inputData,
("hello world spark streaming", 40L),
("world hello structured streaming", 41L)
),
// watermark: 11
// current sessions
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("streaming", 40, 51, 11, 2),
("spark", 40, 50, 10, 1),
("structured", 41, 51, 10, 1)
),
// placing new sessions "before" previous sessions
AddData(inputData, ("spark streaming", 25L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("spark", 25, 35, 10, 1),
("streaming", 25, 35, 10, 1)
),
// a late event whose session end (10) is earlier than the watermark (11): it should be dropped
AddData(inputData, ("spark streaming", 0L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// concatenating multiple previous sessions into one
AddData(inputData, ("spark streaming", 30L)),
// watermark: 11
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4)
),
// placing new sessions after previous sessions
AddData(inputData, ("hello apache spark", 60L)),
// watermark: 30
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1),
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1)
CheckNewAnswer(
("hello", 60, 70, 10, 1),
("apache", 60, 70, 10, 1),
("spark", 60, 70, 10, 1)
),
AddData(inputData, ("structured streaming", 90L)),
// watermark: 60
// current sessions
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1),
// ("structured", 90, 100, 10, 1),
// ("streaming", 90, 100, 10, 1)
// evicted
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("structured", 90, 100, 10, 1),
("streaming", 90, 100, 10, 1)
)
)
*/
@TimidLion

Hello, and thanks for the great code.
I'm a student who has only recently started learning Spark.
One question: does the testStream in the commented-out code at line 223 come from Spark's StreamTest?
I'm asking because I'm curious how you managed to pull it in.

In my case, even after trying
import org.apache.spark.sql.streaming.StreamTest
class StreamTestClass extends StreamTest
the import could not be resolved.
I'd be very grateful for any help, however small.

@HeartSaVioR
Author

Hello! :)
You can add the test-jar of the spark-sql artifact as a dependency:

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <type>test-jar</type>
      <scope>test</scope>
    </dependency>
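
If you build with sbt instead of Maven, the equivalent should be the `tests` classifier (with `sparkVersion` standing in for your Spark version):

    // sbt: pull in the spark-sql test-jar, which contains StreamTest
    libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % Test classifier "tests"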

@TimidLion

Thanks for the quick reply!
Have a happy Chuseok :)
