Skip to content

Instantly share code, notes, and snippets.

@hamnis
Forked from izeigerman/SseClient.scala
Last active August 26, 2020 08:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hamnis/6b134112e083ff89c0ce44662e831863 to your computer and use it in GitHub Desktop.
Save hamnis/6b134112e083ff89c0ce44662e831863 to your computer and use it in GitHub Desktop.
Complete SSE client implementation for Scala using http4s.
import cats.effect.Timer
import fs2.{Pull, RaiseThrowable, Stream}
import org.http4s._
import org.http4s.ServerSentEvent.EventId
import org.http4s.client.Client
import org.http4s.headers.{Accept, `Cache-Control`}
import scala.concurrent.duration._
import SseClient._
/** Server-Sent Events client that keeps a logical event stream alive across connections.
  *
  * Whenever a connection ends or fails, a new one is opened after a delay: the last `retry`
  * value sent by the server, or `initialRetryInterval` if none was sent yet. The most recent
  * event id (if any) is sent back in the `Last-Event-ID` header so the server can resume.
  *
  * @param httpClient           client used to open each connection
  * @param maxRetries           number of consecutive reconnections allowed without receiving an
  *                             event; exceeding it fails the stream with `MaxRetriesReached`
  * @param initialRetryInterval reconnection delay used until the server sends a `retry` field
  */
final class SseClient[F[_]] private (
    httpClient: Client[F],
    maxRetries: Int,
    initialRetryInterval: FiniteDuration
)(implicit timer: Timer[F], rt: RaiseThrowable[F]) {

  /** Streams the events published at `uri`, reconnecting transparently.
    *
    * The stream fails with `MaxRetriesReached` (carrying the last connection error, if any)
    * once more than `maxRetries` consecutive reconnections happen without an event arriving.
    */
  def stream(uri: Uri): Stream[F, ServerSentEvent] = autoReconnectStream(uri)

  private def autoReconnectStream(uri: Uri): Stream[F, ServerSentEvent] = {

    // Folds the id/retry fields of a received event into the connection state. Receiving an
    // event proves the connection works, so the consecutive-reconnect counter is reset.
    def nextMetadata(metadata: SseMetadata, event: ServerSentEvent): SseMetadata =
      SseMetadata(
        // An empty id (EventId.reset) clears the last event id.
        eventId = event.id.fold(metadata.eventId) { id =>
          if (id == EventId.reset) None else Some(id.value)
        },
        retry = event.retry.fold(metadata.retry)(r => Some(r.millis)),
        attempts = 0
      )

    // Emits events from one connection. Only the `uncons1` step is guarded (via `attempt`), so
    // a failure becomes a reconnect without stacking a new `handleErrorWith` scope per event;
    // nested scopes stay pending for the whole life of the stream and keep growing.
    def go(s: Stream[F, ServerSentEvent], metadata: SseMetadata): Pull[F, ServerSentEvent, Unit] =
      s.pull.uncons1.attempt.flatMap {
        case Right(Some((event, tail))) =>
          Pull.output1(event) >> go(tail, nextMetadata(metadata, event))
        case Right(None) => reconnect(metadata, None)
        case Left(err)   => reconnect(metadata, Some(err))
      }

    // Opens a new connection after the retry delay, or fails once the retry budget is spent.
    // `MaxRetriesReached` is raised outside any guarded step, so it terminates the stream.
    def reconnect(metadata: SseMetadata, err: Option[Throwable]): Pull[F, ServerSentEvent, Unit] = {
      val attempts = metadata.attempts + 1
      if (attempts > maxRetries) Pull.raiseError[F](MaxRetriesReached(err))
      else {
        val delay = metadata.retry.getOrElse(initialRetryInterval)
        go(newStream(uri, metadata.eventId).delayBy(delay), metadata.copy(attempts = attempts))
      }
    }

    // Counting starts at 0 so that exactly `maxRetries` reconnections are allowed.
    go(newStream(uri, None), SseMetadata(eventId = None, retry = None, attempts = 0)).stream
  }

  // Opens one SSE connection, resuming after `eventId` when one is known, and decodes its body.
  // A non-2xx response is raised as an error (so its cause is reported if retries run out)
  // instead of its body being decoded as event data.
  private def newStream(uri: Uri, eventId: Option[String]): Stream[F, ServerSentEvent] = {
    val lastEventId = Headers(eventId.map(Header("Last-Event-ID", _)).toList)
    httpClient
      .stream(Request(uri = uri, headers = BaseHeaders ++ lastEventId))
      .flatMap { response =>
        // todo: also check that the Content-Type is text/event-stream before decoding.
        if (response.status.isSuccess) response.body
        else
          Stream.raiseError[F](
            new IllegalStateException(s"Unexpected SSE response status: ${response.status}")
          )
      }
      .through(ServerSentEvent.decoder[F])
  }
}
/** Factory, defaults and supporting types for [[SseClient]]. */
object SseClient {

  /** Default for the `initialRetryInterval` argument of `apply`. */
  val DefaultRetryInterval: FiniteDuration = 5.seconds

  /** Default for the `maxRetries` argument of `apply`. */
  val DefaultMaxRetries: Int = 10

  /** Creates an [[SseClient]].
    *
    * @param httpClient           client used to open each SSE connection
    * @param maxRetries           limit on reconnection attempts before the event stream fails
    *                             with [[MaxRetriesReached]]
    * @param initialRetryInterval reconnection delay used until the server sends a `retry` field
    */
  def apply[F[_]](
      httpClient: Client[F],
      maxRetries: Int = DefaultMaxRetries,
      initialRetryInterval: FiniteDuration = DefaultRetryInterval
  )(implicit timer: Timer[F], rt: RaiseThrowable[F]): SseClient[F] =
    new SseClient(httpClient, maxRetries, initialRetryInterval)

  // State carried across reconnections: the last event id to resume from, the server-requested
  // retry delay (if any) and the reconnection-attempt counter.
  private final case class SseMetadata(
      eventId: Option[String],
      retry: Option[FiniteDuration],
      attempts: Int
  )

  /** Error that terminates an SSE stream once the retry budget is exhausted.
    *
    * Public so callers can match on it; `underlying` is the last connection error, if the final
    * disconnect was caused by one (it is also exposed as the exception's cause).
    */
  final case class MaxRetriesReached(underlying: Option[Throwable])
      extends Exception("Reached maximum number of retries", underlying.orNull)

  // Headers sent on every connection: disable caches and ask for an event stream.
  private val BaseHeaders: Headers = Headers.of(
    `Cache-Control`(CacheDirective.`no-cache`()),
    Accept(MediaType.`text/event-stream`)
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment