Skip to content

Instantly share code, notes, and snippets.

@LMnet
Last active September 29, 2022 14:07
Show Gist options
  • Save LMnet/7c8adc5ad8b5dbf787c60f42ef16ac5b to your computer and use it in GitHub Desktop.
Save LMnet/7c8adc5ad8b5dbf787c60f42ef16ac5b to your computer and use it in GitHub Desktop.
StreamUtils with parJoinPrioritized
package fs2
import cats.Applicative
import cats.effect.kernel.{Fiber, Outcome}
import cats.effect.syntax.all._
import cats.effect.{Concurrent, Deferred}
import cats.syntax.all._
import fs2.concurrent.{Channel, SignallingRef}
import fs2.internal.Scope
import tmp.Queue
object StreamUtils {

  /** Nondeterministically merges several streams into a single stream, preferring earlier streams.
    *
    * The order of the arguments matters: streams further to the left have higher priority.
    * Whenever the leftmost stream has buffered data, the resulting stream emits that data first.
    * When the leftmost stream has nothing available, the next stream is consulted, and so on.
    *
    * Every input stream is paired with the capacity of its own buffer (counted in chunks), so that
    * data can be prefetched from all inputs before the prioritized selection takes place.
    *
    * The resulting stream completes once every input stream has completed and all buffered chunks
    * have been emitted; with no inputs it is empty. If an input stream fails, all inputs are
    * interrupted and the failure is raised by the resulting stream.
    *
    * @param streamsWithBuffers input streams, highest priority first, each paired with the
    *                           capacity (in chunks) of the buffer that prefetches from it
    * @return a stream emitting the elements of all inputs, preferring higher-priority inputs
    */
  def parJoinPrioritized[F[_]: Concurrent, A](
      streamsWithBuffers: (fs2.Stream[F, A], Int)*
  ): fs2.Stream[F, A] = {
    val F = Concurrent[F]
    Stream.force(for {
      // None while running; Some(None) once stopped without error; Some(Some(err)) on failure
      streamsDone <- SignallingRef(None: Option[Option[Throwable]])
      outerDone <- Deferred[F, Either[Throwable, Unit]]
      // Starts at 1: a "launcher" token held until every inner stream has been started, so an
      // inner stream that finishes early cannot drop the counter to 0 and end the join too soon.
      running <- SignallingRef(1)
      outcomes <- Channel.unbounded[F, F[Unit]]
      output <- Channel.synchronous[F, Chunk[A]]
    } yield {
      // Stops the join: every inner stream is interrupted (they all watch `streamsDone`).
      // If `rslt` carries an error, it is combined with any error already recorded.
      def stop(rslt: Option[Throwable]): F[Unit] =
        streamsDone.update {
          case rslt0 @ Some(Some(err0)) =>
            rslt.fold[Option[Option[Throwable]]](rslt0) { err =>
              Some(Some(CompositeFailure(err0, err)))
            }
          case _ => Some(rslt)
        }

      val incrementRunning: F[Unit] = running.update(_ + 1)

      // When nothing is running any more, closing `outcomes` lets `outcomeJoiner` finish.
      def decrementRunning: F[Unit] =
        running.updateAndGet(_ - 1).flatMap { now =>
          if (now == 0) outcomes.close.void else F.unit
        }

      // Records how a fiber ended, together with the result of releasing its scope lease:
      // successful results are queued for `outcomeJoiner`, failures stop the join.
      def onOutcome(
          oc: Outcome[F, Throwable, Unit],
          cancelResult: Either[Throwable, Unit]
      ): F[Unit] =
        oc match {
          case Outcome.Succeeded(fu) =>
            cancelResult.fold(t => stop(Some(t)), _ => outcomes.send(fu).void)
          case Outcome.Errored(t) =>
            CompositeFailure
              .fromResults(Left(t), cancelResult)
              .fold(t => stop(Some(t)), _ => F.unit)
          case Outcome.Canceled() =>
            cancelResult.fold(t => stop(Some(t)), _ => F.unit)
        }

      // Evaluates queued outcomes until `outcomes` is closed, then stops the join
      // (recording an error if evaluating an outcome failed).
      def outcomeJoiner: F[Unit] =
        outcomes.stream
          .evalMap(identity)
          .compile
          .drain
          .guaranteeCase {
            case Outcome.Succeeded(_) => stop(None)
            case Outcome.Errored(t)   => stop(Some(t))
            case Outcome.Canceled()   => stop(None)
          }
          .handleError(_ => ())

      // Starts `inner` on its own fiber, pushing each of its chunks into `buffer`.
      // The fiber is interrupted once the join is stopped, and its outcome goes through
      // `onOutcome`. The outer scope is leased before forking so its resources stay open while
      // the inner stream runs on another fiber; the lease is released when the inner stream ends.
      def runInner(
          inner: Stream[F, A],
          outerScope: Scope[F],
          buffer: Queue[F, Chunk[A]]
      ): F[Unit] =
        F.uncancelable { _ =>
          outerScope.lease
            .flatTap(_ => incrementRunning)
            .flatMap { lease =>
              F.start {
                inner.chunks
                  .evalMap(buffer.offer)
                  // must come after `evalMap`: otherwise an `offer` blocked on a full buffer
                  // could hang while the stream is being interrupted
                  .interruptWhen(streamsDone.map(_.nonEmpty))
                  .compile
                  .drain
                  .guaranteeCase(oc =>
                    lease.cancel.flatMap(onOutcome(oc, _)) >> decrementRunning
                  )
                  .handleError(_ => ())
              }.void
            }
        }

      // Called once the join has stopped: raises the recorded error, if any,
      // otherwise waits for `fiber` to finish.
      def signalResult(fiber: Fiber[F, Throwable, Unit]): F[Unit] =
        streamsDone.get.flatMap { res =>
          res.flatten.fold[F[Unit]](fiber.joinWithNever)(F.raiseError)
        }

      // Creates one bounded buffer per input stream and starts the stream that feeds it.
      // The launcher token is released only after all streams were started (or starting one
      // failed), so `running` can reach 0 only after that point.
      def buffers(outerScope: Scope[F]): Stream[F, List[Queue[F, Chunk[A]]]] =
        Stream.eval {
          streamsWithBuffers.toList
            .traverse { case (stream, bufferSize) =>
              Queue.bounded[F, Chunk[A]](bufferSize).flatMap { streamBuffer =>
                runInner(stream, outerScope, streamBuffer).as(streamBuffer)
              }
            }
            .guarantee(decrementRunning)
        }

      // Emits buffered chunks in priority order until the join has stopped and every buffer
      // has been drained.
      def loop(outerScope: Scope[F]): Stream[F, Chunk[A]] =
        buffers(outerScope).flatMap { streamBuffers =>
          // Takes a chunk from the highest-priority non-empty buffer, without waiting.
          def getNext: F[Option[Chunk[A]]] =
            streamBuffers.collectFirstSomeM(_.tryTake)

          // Completes as soon as any buffer holds data. It only peeks, so cancelling it never
          // loses a chunk; the peek fibers are cancelled however this effect ends.
          def awaitData: F[Unit] =
            Deferred[F, Unit].flatMap { dataAvailable =>
              F.bracket(
                streamBuffers.traverse { buffer =>
                  F.start(buffer.peek >> dataAvailable.complete(()).void)
                }
              )(_ => dataAvailable.get)(_.traverse_(_.cancel))
            }

          // Produces the next chunk, or None once the join has stopped and all buffers are empty.
          // The stop flag is read *before* polling, so an empty poll after a stop really means
          // "drained". Chunks are taken only here, never inside the race, so a cancelled race
          // loser cannot drop one.
          def step: F[Option[Chunk[A]]] = {
            def attempt: F[Option[Chunk[A]]] =
              streamsDone.get.flatMap { finished =>
                getNext.flatMap {
                  case None if finished.isEmpty =>
                    F.race(streamsDone.waitUntil(_.nonEmpty), awaitData) >> attempt
                  case next =>
                    Applicative[F].pure(next)
                }
              }
            F.uncancelable(_ => attempt)
          }

          Stream.unfoldEval(())(_ => step.map(_.map(chunk => (chunk, ()))))
        }

      // Runs the prioritized loop and forwards every chunk to the synchronous `output` channel.
      // It is interrupted only via `outerDone`; `output` is closed however it ends, which
      // terminates the resulting stream.
      val runOuter: F[Unit] =
        F.uncancelable { _ =>
          Pull
            .getScope[F]
            .flatMap(Pull.output1)
            .stream
            .flatMap(outerScope => loop(outerScope))
            .evalMap(chunk => output.send(chunk).void)
            .interruptWhen(outerDone)
            .compile
            .drain
            .guaranteeCase(oc => onOutcome(oc, Right(())) >> output.close.void)
            .handleError(_ => ())
        }

      Stream
        .bracket {
          for {
            fiberOuter <- F.start(runOuter)
            fiberOutcome <- F.start(outcomeJoiner)
          } yield (fiberOuter, fiberOutcome)
        } { case (fiberOuter, fiberOutcome) =>
          F.uncancelable { _ =>
            for {
              _ <- stop(None)
              // In case of short-circuiting, `outcomeJoiner` would not have had a chance
              // to wait until all fibers have been joined, so wait on the counter manually.
              _ <- running.waitUntil(_ == 0)
              // Release `runOuter` even when a recorded error is raised, so it cannot stay
              // blocked on `output.send` after the consumer has gone away.
              _ <- signalResult(fiberOutcome).guarantee(
                     outerDone.complete(Right(())) >> fiberOuter.joinWithNever
                   )
            } yield ()
          }
        }
        .flatMap(_ => output.stream.flatMap(Stream.chunk(_).covary[F]))
    })
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment