Last active
September 29, 2022 14:07
-
-
Save LMnet/7c8adc5ad8b5dbf787c60f42ef16ac5b to your computer and use it in GitHub Desktop.
StreamUtils with parJoinPrioritized
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package fs2 | |
import cats.Applicative | |
import cats.effect.kernel.{Fiber, Outcome} | |
import cats.effect.syntax.all._ | |
import cats.effect.{Concurrent, Deferred} | |
import cats.syntax.all._ | |
import fs2.concurrent.{Channel, SignallingRef} | |
import fs2.internal.Scope | |
import tmp.Queue | |
object StreamUtils {

  /**
   * Nondeterministically merges streams to a single stream.
   * Order of the streams is important: left streams will have more priority.
   * If there is some data in the leftmost stream, the resulting stream will produce this data.
   * When leftmost stream hangs, resulting stream will try to get data from the next stream and so on.
   *
   * Every stream has a buffer for messages chunks,
   * to prefetch some data from all input streams before prioritized getting data.
   *
   * If no streams are supplied, the result is the empty stream.
   *
   * @param streamsWithBuffers pairs of `(stream, bufferSize)`, highest priority first;
   *                           `bufferSize` is passed to `Queue.bounded` for that stream's chunk buffer
   * @tparam F effect type, must have a `Concurrent` instance
   * @tparam A element type of the input and output streams
   * @return a stream emitting the chunks of all inputs, preferring chunks of earlier inputs;
   *         a failure recorded for any input is raised when the resulting stream is finalized
   */
  def parJoinPrioritized[F[_]: Concurrent, A](
      streamsWithBuffers: (fs2.Stream[F, A], Int)*
  ): fs2.Stream[F, A] = {
    val F = Concurrent[F]

    if (streamsWithBuffers.isEmpty) {
      // With no inner streams `running` never leaves 0, so `outcomes` is never closed,
      // `outcomeJoiner` never completes and nothing ever signals `streamsDone`:
      // both the resulting stream and its finalizer would wait forever.
      Stream.empty
    } else {
      Stream.force(for {
        // None             - still running
        // Some(None)       - stopped without an error
        // Some(Some(err))  - stopped with an error
        streamsDone <- SignallingRef(None: Option[Option[Throwable]])
        // completed by the finalizer to interrupt the outer (consuming) loop
        outerDone <- Deferred[F, Either[Throwable, Unit]]
        // number of inner-stream fibers that have been started and not yet finished
        running <- SignallingRef(0)
        // results of finished inner streams, evaluated by `outcomeJoiner`
        outcomes <- Channel.unbounded[F, F[Unit]]
        // chunks handed over to the resulting stream
        output <- Channel.synchronous[F, Chunk[A]]
      } yield {
        // stops the join evaluation
        // all the streams will be terminated. If err is supplied, that will get attached to any error currently present
        def stop(rslt: Option[Throwable]): F[Unit] = {
          streamsDone.update {
            case rslt0 @ Some(Some(err0)) =>
              rslt.fold[Option[Option[Throwable]]](rslt0) { err =>
                Some(Some(CompositeFailure(err0, err)))
              }
            case _ => Some(rslt)
          }
        }

        val incrementRunning: F[Unit] = {
          running.update(_ + 1)
        }

        // closes `outcomes` once the last running inner stream has finished,
        // which in turn lets `outcomeJoiner` complete
        def decrementRunning: F[Unit] = {
          running
            .updateAndGet(_ - 1)
            .flatMap { now =>
              if (now == 0) outcomes.close.void else F.unit
            }
        }

        // records the outcome of a finished stream:
        // a success is forwarded to `outcomes`, any error (from the stream itself or
        // from `cancelResult`) stops the whole join
        def onOutcome(
            oc: Outcome[F, Throwable, Unit],
            cancelResult: Either[Throwable, Unit]
        ): F[Unit] = {
          oc match {
            case Outcome.Succeeded(fu) =>
              cancelResult.fold(t => stop(Some(t)), _ => outcomes.send(fu).void)
            case Outcome.Errored(t) =>
              CompositeFailure
                .fromResults(Left(t), cancelResult)
                .fold(t => stop(Some(t)), _ => F.unit)
            case Outcome.Canceled() =>
              cancelResult.fold(t => stop(Some(t)), _ => F.unit)
          }
        }

        // evaluates the collected outcomes until `outcomes` is closed, then signals `streamsDone`
        def outcomeJoiner: F[Unit] =
          outcomes.stream
            .evalMap(identity)
            .compile
            .drain
            .guaranteeCase {
              case Outcome.Succeeded(_) =>
                stop(None)
              case Outcome.Errored(t) =>
                stop(Some(t))
              case Outcome.Canceled() =>
                stop(None)
            }
            .handleError(_ => ())

        // runs inner stream
        // each stream is forked.
        // terminates when killSignal is true
        // if fails will enq in queue failure
        // note that supplied scope's resources must be leased before the inner stream forks the execution to another thread
        // and that it must be released once the inner stream terminates or fails.
        def runInner(inner: Stream[F, A], outerScope: Scope[F], buffer: Queue[F, Chunk[A]]): F[Unit] = {
          F.uncancelable { _ =>
            outerScope.lease
              .flatTap(_ => incrementRunning)
              .flatMap { lease =>
                F.start {
                  inner
                    .chunks
                    .evalMap(buffer.offer)
                    .interruptWhen(streamsDone.map(_.nonEmpty)) // must be AFTER enqueue to the sync queue, otherwise the process may hang to enq last item while being interrupted
                    .compile
                    .drain
                    .guaranteeCase(oc =>
                      lease.cancel
                        .flatMap(onOutcome(oc, _)) >> decrementRunning
                    )
                    .handleError(_ => ())
                }.void
              }
          }
        }

        // awaits when all streams (outer + inner) finished,
        // and then collects result of the stream (outer + inner) execution
        def signalResult(fiber: Fiber[F, Throwable, Unit]): F[Unit] = {
          streamsDone.get.flatMap { (res: Option[Option[Throwable]]) =>
            res.flatten.fold[F[Unit]](fiber.joinWithNever)(F.raiseError)
          }
        }

        // creating an Queue as a buffer for each input stream
        // and starting the fiber that fills it
        def buffers(outerScope: Scope[F]) = Stream.eval {
          streamsWithBuffers.toList.traverse { case (stream, bufferSize) =>
            Queue.bounded[F, Chunk[A]](bufferSize).flatMap { streamBuffer =>
              runInner(stream, outerScope, streamBuffer).as(streamBuffer)
            }
          }
        }

        //starting a loop for retrieving data
        def loop(outerScope: Scope[F]): Stream[F, Chunk[A]] = {
          buffers(outerScope).flatMap { (streamBuffers: List[Queue[F, Chunk[A]]]) =>
            // trying to get a next chunk from the buffers, starting with the topmost
            def getNext: F[Option[Chunk[A]]] = {
              streamBuffers.collectFirstSomeM { buffer =>
                buffer.tryTake
              }
            }

            // Looping getNext if there is enough data. If there is no data in buffers —
            // waiting for a new portion of data with `peek` and after that looping again.
            def waitForNext: F[Chunk[A]] = {
              getNext.flatMap {
                case Some(elem) =>
                  Applicative[F].pure(elem)
                case None =>
                  Deferred[F, Unit].flatMap { nextElemWaiter =>
                    // one waiting fiber per buffer; the first `peek` that returns completes the waiter,
                    // then all waiting fibers are cancelled and the buffers are polled again in priority order
                    streamBuffers.traverse { buffer =>
                      F.start(buffer.peek.flatMap { _ =>
                        nextElemWaiter.complete(()).void
                      })
                    }.uncancelable.flatMap { fibers =>
                      nextElemWaiter.get.guarantee(fibers.traverse_(_.cancel)) >> waitForNext
                    }
                  }
              }
            }

            // if stream is in finalizing process, we should return all data from the buffers,
            // and only after terminate the stream
            def step: F[Option[Chunk[A]]] = F.uncancelable { _ =>
              streamsDone.get.flatMap {
                case None =>
                  // interrupt waitForNext when all streams are finished
                  F.race(streamsDone.waitUntil(_.nonEmpty), waitForNext).flatMap {
                    case Right(chunk) =>
                      F.pure[Option[Chunk[A]]](Some(chunk))
                    case Left(_) =>
                      // The termination signal won the race, but a producer may have offered
                      // its last chunk right before finishing. Poll the buffers instead of
                      // terminating straight away so that chunk is not dropped; the following
                      // steps keep draining through the `Some(_)` branch below.
                      getNext
                  }
                case Some(_) =>
                  getNext
              }
            }

            Stream.unfoldEval(()) { _ =>
              step.map(_.map((_, ())))
            }
          }
        }

        // consumes `loop` and forwards every chunk to `output`;
        // `output` is closed when this loop terminates for whatever reason
        val runOuter = {
          F.uncancelable { _ =>
            Pull
              .getScope[F]
              .flatMap(Pull.output1)
              .stream
              .flatMap { outerScope =>
                loop(outerScope)
              }
              .evalMap { chunk =>
                output.send(chunk).void
              }
              .interruptWhen(outerDone)
              .compile
              .drain
              .guaranteeCase { x =>
                onOutcome(x, Right(())) >> output.close.void
              }
              .handleError(_ => ())
          }
        }

        Stream
          .bracket {
            for {
              fiberOuter <- F.start(runOuter)
              fiberOutcome <- F.start(outcomeJoiner)
            } yield (fiberOuter, fiberOutcome)
          } { case (fiberOuter, fiberOutcome) =>
            F.uncancelable { _ =>
              for {
                _ <- stop(None)
                // in case of short-circuiting, the `fiberJoiner` would not have had a chance
                // to wait until all fibers have been joined, so we need to do it manually
                // by waiting on the counter
                _ <- running.waitUntil(_ == 0)
                // raises the recorded error, if any; otherwise waits for `outcomeJoiner`
                _ <- signalResult(fiberOutcome)
                _ <- outerDone.complete(Right(()))
                _ <- fiberOuter.joinWithNever
              } yield ()
            }
          }
          .flatMap { _ =>
            output.stream.flatMap(Stream.chunk(_).covary[F])
          }
      })
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment