Skip to content

Instantly share code, notes, and snippets.

@lrodero
Last active July 19, 2023 19:57
Show Gist options
  • Save lrodero/552f59a7157987c7ed69dba3b9404f37 to your computer and use it in GitHub Desktop.
Save lrodero/552f59a7157987c7ed69dba3b9404f37 to your computer and use it in GitHub Desktop.
//> using scala "2.13.11"
//> using lib "org.typelevel::cats-effect::3.5.1"
import cats.effect.{Async, Deferred, ExitCode, IO, IOApp, Ref, Sync}
import cats.effect.std.Console
import cats.instances.list._
import cats.syntax.all._
import java.util.concurrent.ScheduledThreadPoolExecutor
import scala.collection.immutable.Queue
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.DurationInt
/**
 * Multiple-producer / multiple-consumer system built on an unbounded concurrent queue.
 *
 * Second part of the cats-effect tutorial at https://typelevel.org/cats-effect/tutorial/tutorial.html
 *
 * The code to _offer_ and _take_ elements to/from the queue is adapted from CE3's Queue
 * implementation, including its cancellation handling: the state transition of `offer` is
 * uncancelable, and a consumer canceled while waiting removes its pending taker from the state.
 */
object ProducerConsumer extends IOApp {

  /**
   * Shared queue state.
   *
   * @param queue  elements offered but not yet taken
   * @param takers consumers waiting for an element, oldest first; each is completed with the element it receives
   */
  final case class State[F[_], A](queue: Queue[A], takers: Queue[Deferred[F, A]])

  object State {
    /** State with no buffered elements and no waiting consumers. */
    def empty[F[_], A]: State[F, A] = State(Queue.empty, Queue.empty)
  }

  /**
   * Endlessly produces consecutive integers taken from `counterR` and offers them to the shared state.
   *
   * Logs a line every time the produced value is a multiple of 10000, and cedes after each item
   * so other fibers get a fair share of the thread.
   *
   * @param id       producer identifier, used only in log messages
   * @param counterR shared counter; each read hands out the current value and increments it
   * @param stateR   shared queue state
   * @return an effect that never completes normally
   */
  def producer[F[_]: Async: Console](id: Int, counterR: Ref[F, Int], stateR: Ref[F, State[F, Int]]): F[Unit] = {
    // Hands `i` straight to the oldest waiting consumer if there is one, otherwise buffers it.
    // Uncancelable: once the taker has been dequeued from the state it MUST be completed,
    // otherwise the element would be lost and that consumer would wait forever.
    def offer(i: Int): F[Unit] =
      Async[F].uncancelable { _ =>
        stateR.modify {
          case State(queue, takers) if takers.nonEmpty =>
            val (taker, rest) = takers.dequeue
            State(queue, rest) -> taker.complete(i).void
          case State(queue, takers) =>
            State(queue.enqueue(i), takers) -> Async[F].unit
        }.flatten
      }

    val step: F[Unit] =
      for {
        i <- counterR.getAndUpdate(_ + 1)
        _ <- offer(i)
        _ <- if (i % 10000 == 0) Console[F].println(s"Producer $id has reached $i items") else Async[F].unit
        _ <- Async[F].cede
      } yield ()

    // `foreverM` loops without accumulating continuations; a recursive `for ... yield ()` would
    // stack one pending `map(_ => ())` per iteration that never gets popped.
    step.foreverM
  }

  /**
   * Endlessly takes elements from the shared state.
   *
   * Logs a line every time the taken value is a multiple of 10000, and cedes after each item.
   *
   * @param id     consumer identifier, used only in log messages
   * @param stateR shared queue state
   * @return an effect that never completes normally
   */
  def consumer[F[_]: Async: Console](id: Int, stateR: Ref[F, State[F, Int]]): F[Unit] = {
    // Takes a buffered element if available; otherwise registers a taker and waits for a producer
    // to complete it. Only the wait is cancelable (via `poll`); on cancellation the taker is removed
    // from the state so no producer hands an element to a consumer that is no longer listening.
    // NOTE(review): if cancellation races with a producer that has already dequeued this taker,
    // that single element can still be lost (same trade-off as the tutorial's version).
    val take: F[Int] =
      Deferred[F, Int].flatMap { taker =>
        Async[F].uncancelable { poll =>
          stateR.modify {
            case State(queue, takers) if queue.nonEmpty =>
              val (i, rest) = queue.dequeue
              State(rest, takers) -> Async[F].pure(i)
            case State(queue, takers) =>
              val cleanup = stateR.update(s => s.copy(takers = s.takers.filter(_ ne taker)))
              State(queue, takers.enqueue(taker)) -> Async[F].onCancel(poll(taker.get), cleanup)
          }.flatten
        }
      }

    val step: F[Unit] =
      for {
        i <- take
        _ <- if (i % 10000 == 0) Console[F].println(s"Consumer $id has reached $i items") else Async[F].unit
        _ <- Async[F].cede
      } yield ()

    // See `producer`: `foreverM` avoids the continuation build-up of a recursive `for ... yield ()`.
    step.foreverM
  }

  /**
   * Starts 10 producers and 10 consumers in parallel on a dedicated scheduled thread pool.
   *
   * Runs until cancelled (e.g. CTRL-C). An error raised by any producer or consumer is printed to
   * stderr and turned into `ExitCode.Error`. The thread pool is shut down when the computation ends,
   * whether it succeeds, fails, or is cancelled.
   */
  override def run(args: List[String]): IO[ExitCode] =
    for {
      stateR   <- Ref.of[IO, State[IO, Int]](State.empty[IO, Int])
      counterR <- Ref.of[IO, Int](1)
      producers = List.range(1, 11).map(producer(_, counterR, stateR)) // 10 producers
      consumers = List.range(1, 11).map(consumer(_, stateR))           // 10 consumers
      procs    <- IO(Runtime.getRuntime.availableProcessors())
      computation = (producers ++ consumers)
                      .parSequence
                      .as(ExitCode.Success) // Run producers and consumers in parallel until done (likely by user cancelling with CTRL-C)
                      .handleErrorWith { t =>
                        Console[IO].errorln(s"Error caught: ${t.getMessage}").as(ExitCode.Error)
                      }
      // Pool creation is suspended and paired with shutdown so its threads never outlive the program.
      // At least one core thread: a ScheduledThreadPoolExecutor with corePoolSize 0 may have no
      // thread available to run tasks (e.g. on a single-CPU host where procs / 2 == 0).
      res <- IO(new ScheduledThreadPoolExecutor(math.max(1, procs / 2)))
               .bracket { tpe =>
                 Async[IO].evalOn(computation, ExecutionContext.fromExecutor(tpe))
               } { tpe =>
                 IO(tpe.shutdown())
               }
    } yield res
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment