Skip to content

Instantly share code, notes, and snippets.

@jmcardon
Last active August 3, 2018 21:08
Show Gist options
  • Save jmcardon/a69d3966d0b4b96e3a9b9c4bb40d8480 to your computer and use it in GitHub Desktop.
Save jmcardon/a69d3966d0b4b96e3a9b9c4bb40d8480 to your computer and use it in GitHub Desktop.
import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import cats.effect.{Concurrent, Sync}
import cats.syntax.all._
import fs2.async.mutable.Queue
import fs2.{Sink, Stream}
import org.http4s.websocket.WebsocketBits._
import WebsocketMsg._
import cats.Monad
/** Simplified websocket message ADT.
  *
  * Every case can render itself as a [[WebSocketFrame]] through [[toFrame]].
  */
sealed trait WebsocketMsg {

  /** The frame representation of this message. */
  def toFrame: WebSocketFrame
}
/** A text message carrying `content` as a string. */
final case class TextMsg(content: String) extends WebsocketMsg {

  /** Renders this message by wrapping `content` with `Text(...)`. */
  override def toFrame: WebSocketFrame =
    Text(content)
}
/** A binary message carrying `content` as raw bytes.
  *
  * NOTE(review): `content` is an `Array`, so the generated case-class
  * `equals`/`hashCode` compare the array by reference, not by contents.
  */
final case class BinaryMsg(content: Array[Byte]) extends WebsocketMsg {

  /** Renders this message by wrapping `content` with `Binary(...)`. */
  override def toFrame: WebSocketFrame =
    Binary(content)
}
object WebsocketMsg {

  /** Buffering state of the fragment-reassembly state machine. */
  sealed trait State

  /** Fragments are being collected for a message that started as text. */
  case object BufferingText extends State

  /** Fragments are being collected for a message that started as binary. */
  case object BufferingBinary extends State

  /** No fragments are currently buffered. */
  case object Empty extends State
}
/** Operations needed to reassemble fragmented websocket messages and to
  * publish outgoing frames, expressed in an effect `F`.
  */
abstract class FSMAlgebra[F[_]] {

  /** Reads the current buffering state. */
  def getState: F[State]

  /** Resets the buffering state. */
  def clearState(): F[Unit]

  /** Completes a buffered message with its final chunk, yielding it as text. */
  def lastText(content: Array[Byte]): F[WebsocketMsg]

  /** Completes a buffered message with its final chunk, yielding it as binary. */
  def lastBinary(content: Array[Byte]): F[WebsocketMsg]

  /** Records a non-final binary fragment. */
  def fragmentedBinary(content: Array[Byte]): F[Unit]

  /** Records a non-final text fragment. */
  def fragmentedText(content: String): F[Unit]

  /** Converts every message of `w` to a frame and queues it for sending. */
  def enqueueAll(w: Stream[F, WebsocketMsg]): F[Unit]

  /** The stream of frames that have been queued for sending. */
  def out: Stream[F, WebSocketFrame]
}
object FSMAlgebra {

  /** Allocates a fresh [[FSMAlgebra]] backed by an unbounded frame queue,
    * an atomic state cell starting at `Empty`, and a zeroed length counter.
    */
  def apply[F[_]](implicit F: Concurrent[F]): F[FSMAlgebra[F]] =
    for {
      msgQueue <- Queue.unbounded[F, WebSocketFrame]
      stateRef <- F.delay(new AtomicReference[State](Empty))
      len <- F.delay(new AtomicInteger(0))
      i <- F.delay(new Impl[F](msgQueue, stateRef, len))
    } yield i

  /** Default interpreter.
    *
    * Buffered fragments are kept newest-first in `internalList`; `msgLen`
    * tracks the total number of buffered bytes so the final array can be
    * allocated in one go.
    */
  private class Impl[F[_]](
      msgQueue: Queue[F, WebSocketFrame],
      state: AtomicReference[State],
      msgLen: AtomicInteger
  )(implicit F: Sync[F])
      extends FSMAlgebra[F] {

    // Fragments received so far, most recent first (they are prepended).
    @volatile private[this] var internalList: List[Array[Byte]] = Nil

    /** Concatenates all buffered fragments followed by `lastBytes`.
      *
      * The buffer and the length counter are both reset before any copying
      * happens, so a failure while assembling cannot leave stale fragments or
      * a stale length behind for the next message.
      */
    private[this] def foldBytesToArray(lastBytes: Array[Byte]): F[Array[Byte]] =
      F.delay {
        val fragments = internalList
        internalList = Nil
        val buffered = msgLen.getAndSet(0)

        val aggregator = new ReverseByteArrayAggregator(buffered + lastBytes.length)
        // Fill back-to-front: the final chunk first, then the newest-first
        // fragments, which lands every chunk in arrival order.
        aggregator.aggregate(lastBytes)
        fragments.foreach(chunk => aggregator.aggregate(chunk))
        aggregator.emit
      }

    private[this] def compareAndSetState(old: State, ns: State): F[Boolean] =
      F.delay(state.compareAndSet(old, ns))

    private[this] def incrementMsgLen(i: Int): F[Unit] =
      F.delay { msgLen.getAndAdd(i); () }

    def getState: F[State] = F.delay(state.get())

    def clearState(): F[Unit] = F.delay(state.set(Empty))

    /** Assembles the buffered bytes plus `content` and decodes them as UTF-8. */
    def lastText(content: Array[Byte]): F[WebsocketMsg] =
      foldBytesToArray(content).map(bytes => TextMsg(new String(bytes, UTF_8)))

    /** Assembles the buffered bytes plus `content` into a binary message. */
    def lastBinary(content: Array[Byte]): F[WebsocketMsg] =
      foldBytesToArray(content).map(bytes => BinaryMsg(bytes))

    /** Buffers a non-final binary fragment. */
    def fragmentedBinary(content: Array[Byte]): F[Unit] =
      for {
        // Only the first fragment moves the state away from Empty; the
        // result is deliberately ignored for later fragments.
        _ <- compareAndSetState(Empty, BufferingBinary)
        _ <- incrementMsgLen(content.length)
        _ <- F.delay { internalList = content :: internalList }
      } yield ()

    /** Buffers a non-final text fragment as its UTF-8 bytes. */
    def fragmentedText(content: String): F[Unit] = {
      val bytes = content.getBytes(UTF_8)
      for {
        // Same as above: a failed compare-and-set just means we were already buffering.
        _ <- compareAndSetState(Empty, BufferingText)
        _ <- incrementMsgLen(bytes.length)
        _ <- F.delay { internalList = bytes :: internalList }
      } yield ()
    }

    def enqueueAll(w: Stream[F, WebsocketMsg]): F[Unit] =
      w.map(_.toFrame).through(msgQueue.enqueue).compile.drain

    def out: Stream[F, WebSocketFrame] = msgQueue.dequeue
  }

  /** Fills a fixed-size byte array from the end towards the start.
    *
    * A size of zero is accepted: a fragmented message whose fragments are all
    * empty is legal and must assemble to an empty payload rather than fail.
    *
    * @param size total number of bytes that will be aggregated
    * @throws IllegalArgumentException if `size` is negative
    */
  private class ReverseByteArrayAggregator(size: Int) {
    require(size >= 0, "size must not be negative")

    private[this] val internal = new Array[Byte](size)
    private[this] var nextIx: Int = size

    /** Copies `arr` immediately in front of the previously aggregated bytes.
      *
      * @throws ArrayIndexOutOfBoundsException if `arr` does not fit in the
      *         remaining space; the aggregator is left unchanged in that case
      */
    def aggregate(arr: Array[Byte]): ReverseByteArrayAggregator = {
      val start = nextIx - arr.length
      // Validate before mutating so a rejected chunk does not corrupt the cursor.
      if (start < 0)
        throw new ArrayIndexOutOfBoundsException("Size will exceed append size")
      System.arraycopy(arr, 0, internal, start, arr.length)
      nextIx = start
      this
    }

    /** Returns the backing array (not a copy). */
    def emit: Array[Byte] = internal
  }
}
/** Feeds incoming websocket frames through an [[FSMAlgebra]]: non-final
  * frames are buffered, and each completed message is passed to `f`, whose
  * resulting messages are queued for sending.
  *
  * @param f   handler producing the responses for one complete message
  * @param alg algebra used for buffering, state tracking and output
  */
final class WSFSM[F[_]](f: WebsocketMsg => Stream[F, WebsocketMsg],
                        alg: FSMAlgebra[F])(implicit F: Monad[F]) {

  /** Shared handling of a final frame.
    *
    * Reads and clears the state, then either completes the buffered message
    * with `trailing` or, when nothing was buffered, uses `standalone` as the
    * message. Both arguments are by-name so they are only evaluated in the
    * branch that needs them.
    */
  private[this] def finish(trailing: => Array[Byte], standalone: => WebsocketMsg): F[Unit] =
    alg.getState.flatMap { buffered =>
      alg.clearState().flatMap { _ =>
        val assembled: F[WebsocketMsg] = buffered match {
          case Empty           => F.pure[WebsocketMsg](standalone)
          case BufferingText   => alg.lastText(trailing)
          case BufferingBinary => alg.lastBinary(trailing)
        }
        assembled.flatMap(msg => alg.enqueueAll(f(msg)))
      }
    }

  /** Handles a text frame: buffers it unless `last`, otherwise completes the message. */
  def handleText(content: String, last: Boolean): F[Unit] =
    if (!last) alg.fragmentedText(content)
    else finish(content.getBytes(UTF_8), TextMsg(content))

  /** Handles a binary frame: buffers it unless `last`, otherwise completes the message. */
  def handleBinary(content: Array[Byte], last: Boolean): F[Unit] =
    if (!last) alg.fragmentedBinary(content)
    else finish(content, BinaryMsg(content))

  /** Frames queued for sending. */
  def send: Stream[F, WebSocketFrame] = alg.out

  /** Sink consuming incoming frames; continuation frames are treated like binary ones. */
  def recv: Sink[F, WebSocketFrame] = { frames =>
    frames.evalMap {
      case Text(content, last)         => handleText(content, last)
      case Binary(content, last)       => handleBinary(content, last)
      case Continuation(content, last) => handleBinary(content, last)
      case _                           => F.unit //Do not worry about handling other messages
    }
  }
}
object WSFSM {

  /** Creates a [[WSFSM]] over a newly allocated [[FSMAlgebra]].
    *
    * @param f handler producing the responses for one complete message
    */
  def apply[F[_]: Concurrent](f: WebsocketMsg => Stream[F, WebsocketMsg]): F[WSFSM[F]] =
    FSMAlgebra[F].map(algebra => new WSFSM[F](f, algebra))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment