Skip to content

Instantly share code, notes, and snippets.

@igor-ramazanov
Created May 19, 2022 13:57
Show Gist options
  • Save igor-ramazanov/9e2217a35c36b4c8b1c021eaf8d7143a to your computer and use it in GitHub Desktop.
Save igor-ramazanov/9e2217a35c36b4c8b1c021eaf8d7143a to your computer and use it in GitHub Desktop.
Toy IO/Fiber system
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Timer, TimerTask}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Promise}
import scala.util.chaining._
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}
/** Execution resources used to run [[IO]] programs.
  *
  * @param timer      used only to delay work; when a delay elapses the task is handed to `threadPool`
  * @param threadPool runs every task submitted through [[sendToPool]] (and therefore all delayed tasks)
  */
class Context(private val timer: Timer, private val threadPool: ExecutionContext) {

  /** Runs `task` on the thread pool once `fd` has elapsed.
    *
    * The timer thread only re-submits `task` to the pool; it never evaluates `task` itself.
    * A negative duration is treated as zero.
    *
    * @param task evaluated later, on a pool thread
    * @param fd   delay before `task` is submitted to the pool
    */
  def schedule(task: => Unit, fd: FiniteDuration): Unit = {
    val timerTask = new TimerTask {
      def run(): Unit = sendToPool(task)
    }
    // java.util.Timer#schedule throws IllegalArgumentException for a negative delay. Without the
    // clamp, IO.sleep with a negative duration would throw inside the interpreter and never complete.
    timer.schedule(timerTask, math.max(0L, fd.toMillis))
  }

  /** Submits `task` to the thread pool; whatever `task` evaluates to is discarded.
    *
    * @param task evaluated asynchronously on a pool thread
    */
  def sendToPool[A](task: => A): Unit = threadPool.execute(() => { val _ = task })
}
/** A lazy description of a computation that produces an `A`.
  *
  * Building an `IO` performs no work; it is interpreted only by [[runAsync]] or [[runSync]].
  */
sealed trait IO[A] extends Product with Serializable {

  /** Runs this IO and blocks the calling thread until a result is available.
    *
    * Implemented in terms of [[runAsync]]. Evaluation is submitted to the context's pool, so calling this
    * from one of that pool's own threads may deadlock (e.g. with a single-thread pool).
    *
    * @return the outcome of the computation
    */
  def runSync()(implicit context: Context): Try[A] = {
    val promise = Promise[A]()
    // tryComplete (not complete): an extra callback invocation, e.g. from a misbehaving IO.async,
    // is ignored instead of throwing IllegalStateException on some pool thread.
    this.runAsync { result => promise.tryComplete(result); () }
    // Await.ready returns only once the future is completed, so `value` is always defined here.
    Await.ready(promise.future, Duration.Inf).value.get
  }

  /** Starts interpreting this IO on the context's thread pool and returns immediately.
    *
    * @param callback receives the outcome of the computation; it is intended to be called once
    */
  def runAsync(callback: Try[A] => Unit)(implicit context: Context): Unit =
    IO.runLoop(this)(callback.asInstanceOf[Try[_] => Unit])

  /** Transforms the successful result with `f`. */
  def map[B](f: A => B): IO[B] = IO.Map(this, f)

  /** Replaces the successful result with `b`. */
  def as[B](b: B): IO[B] = IO.Map(this, (_: A) => b)

  /** Sequences a dependent computation built from this IO's result. */
  def flatMap[B](f: A => IO[B]): IO[B] = IO.FlatMap(this, f)

  /** Runs `f` after this IO succeeds, discarding this IO's result. */
  def >>[B](f: IO[B]): IO[B] = IO.FlatMap(this, (_: A) => f)

  /** Replaces failures matched by `f` with a value. */
  def recover(f: PartialFunction[Throwable, A]): IO[A] = IO.Recover(this, f)

  /** Replaces failures matched by `f` with another IO. */
  def recoverWith(f: PartialFunction[Throwable, IO[A]]): IO[A] = IO.RecoverWith(this, f)

  /** Starts this IO concurrently when run, yielding a [[IO.Fiber]] that can later be joined. */
  def fork(): IO[IO.Fiber[A]] = IO.Fork(this)
}
/** Constructors, the internal instruction set, [[Fiber]] and the interpreter for [[IO]]. */
object IO {

  /** Suspends `thunk`; it is evaluated each time the resulting IO is run. A non-fatal exception
    * thrown by `thunk` fails the IO.
    */
  def apply[A](thunk: => A): IO[A] = IO.Delay(() => thunk)

  /** Wraps an already-computed value. */
  def pure[A](a: A): IO[A] = IO.Pure(a)

  /** Bridges a callback-based API. When the IO runs, `cb` receives the completion callback.
    *
    * Only the first invocation of that callback has any effect. If `cb` throws a non-fatal exception
    * before invoking it, the IO fails with that exception.
    */
  def async[A](cb: (Try[A] => Unit) => Unit): IO[A] = IO.Async(cb)

  /** Completes with `()` after `duration`. Waiting happens on the context's timer, not on a pool thread. */
  def sleep(duration: FiniteDuration): IO[Unit] = IO.Sleep(duration)

  /** Fails with `e` when run. */
  def raiseError(e: Throwable): IO[Unit] = IO.Error(e)

  /** Prints `s`, prefixed with the name of the thread that runs it. */
  def putStrLn(s: String): IO[Unit] = IO(println(s"${Thread.currentThread().getName}: $s"))

  /** Yields the current thread: the rest of the computation is re-submitted to the pool. */
  def shift: IO[Unit] = IO.Shift

  final private case class Pure[A](value: A) extends IO[A]
  final private case class Delay[A](thunk: () => A) extends IO[A]
  final private case class Async[A](callback: (Try[A] => Unit) => Unit) extends IO[A]
  final private case class FlatMap[A, B](prev: IO[A], f: A => IO[B]) extends IO[B]
  final private case class Map[A, B](prev: IO[A], f: A => B) extends IO[B]
  final private case class Recover[A](prev: IO[A], f: PartialFunction[Throwable, A]) extends IO[A]
  final private case class RecoverWith[A](prev: IO[A], f: PartialFunction[Throwable, IO[A]]) extends IO[A]
  final private case class Error(e: Throwable) extends IO[Unit]
  final private case class Sleep[A](duration: FiniteDuration) extends IO[A]
  final private case class Fork[A](io: IO[A]) extends IO[Fiber[A]]
  final private case class Join[A](fiber: Fiber[A]) extends IO[A]
  final private case object Shift extends IO[Unit]

  /** Handle to a computation started by [[IO.fork]].
    *
    * Each callback registered through `join` is invoked exactly once with the fiber's outcome. Callbacks
    * are invoked outside the fiber's lock, because they run arbitrary continuations.
    */
  class Fiber[A] {
    // Guarded by `this`: callbacks still waiting for the result, most recently registered first.
    private var callbacks = List.empty[Try[A] => Unit]
    // Guarded by `this`: the outcome, set once by `finish`.
    private var result = Option.empty[Try[A]]

    /** An IO that completes with this fiber's outcome; it may be joined any number of times. */
    def join(): IO[A] = IO.Join(this)

    /** Invokes `cb` with the outcome now if it is known; otherwise queues `cb` for `finish`. */
    private[IO] def register(cb: Try[A] => Unit): Unit = {
      val completed = synchronized {
        if (result.isEmpty) callbacks = cb :: callbacks
        result
      }
      // Invoke outside the lock: `cb` may run an entire continuation of the program.
      completed.foreach(cb)
    }

    /** Records the outcome and notifies waiting callbacks in registration order. A second call is ignored. */
    private[IO] def finish(res: Try[A]): Unit = {
      val waiting = synchronized {
        if (result.isDefined) Nil
        else {
          result = Some(res)
          val pending = callbacks.reverse
          callbacks = Nil
          pending
        }
      }
      waiting.foreach(_(res))
    }
  }

  /** Optimised for maximum throughput, fairness must be ensured by the end developer by using [[IO.shift]].
    *
    * NOTE(review): synchronous steps are interpreted by direct recursion on the current thread, so very long
    * chains of map/flatMap may overflow the stack.
    */
  private def runLoop(io: IO[_])(done: Try[_] => Unit)(implicit context: Context): Unit =
    // Evaluation should run in an intended thread pool since the beginning.
    // Otherwise, the first computations would run in a default 'main' JVM thread.
    context.sendToPool(eval(io)(done))

  /** Interprets `io` and passes its outcome to `done`. User-supplied functions that throw non-fatally
    * produce a `Failure`, rather than an exception that would leave `done` uncalled.
    */
  private def eval(io: IO[_])(done: Try[_] => Unit)(implicit context: Context): Unit =
    io match {
      case IO.Pure(value) => done(Success(value))
      // Try(...) turns a throwing thunk into a Failure. Otherwise the exception would escape the pool
      // task, `done` would never be called, and runSync would block forever.
      case IO.Delay(thunk) => done(Try(thunk()))
      case IO.Async(subscribe) =>
        val completed = new AtomicBoolean(false)
        // Forward only the first result so `done` runs at most once.
        val complete: Try[Any] => Unit = res => if (completed.compareAndSet(false, true)) done(res)
        try subscribe(complete)
        catch {
          // A registration that fails before completing would otherwise leave the IO pending forever.
          // Exceptions raised after completion (i.e. by the continuation) still propagate as before.
          case NonFatal(e) if !completed.get => complete(Failure(e))
        }
      case IO.FlatMap(prev, f) =>
        eval(prev) {
          case Success(value) =>
            Try(f.asInstanceOf[Any => IO[_]](value)) match {
              case Success(next) => eval(next)(done)
              case Failure(e)    => done(Failure(e))
            }
          case failure => done(failure)
        }
      // Try#map already converts a throwing `f` into a Failure.
      case IO.Map(prev, f) => eval(prev)(res => done(res.map(f.asInstanceOf[Any => Any])))
      case IO.Recover(prev, f) => eval(prev)(res => done(res.recover(f)))
      case IO.RecoverWith(prev, f) =>
        eval(prev) {
          case Failure(e) if f.isDefinedAt(e) =>
            Try(f(e)) match {
              case Success(next)         => eval(next)(done)
              case Failure(handlerError) => done(Failure(handlerError))
            }
          case other => done(other)
        }
      case IO.Error(e) => done(Failure(e))
      case IO.Sleep(duration) => context.schedule(done(Success(())), duration)
      case fork: IO.Fork[_] =>
        val fiber = new Fiber[Any]
        // The forked IO runs as an independent pool task; the parent continues immediately with the fiber.
        context.sendToPool(eval(fork.io)(fiber.finish))
        done(Success(fiber))
      case IO.Join(fiber) => fiber.register(done)
      // Resume the continuation in a fresh pool task. The original re-evaluated Shift itself,
      // which rescheduled forever and never called `done`.
      case IO.Shift => context.sendToPool(done(Success(())))
    }
}
/** Demo of the toy IO runtime. */
object Main {

  /** Entry point for the demo.
    *
    * Forks a background computation, prints "3" while it runs, and joins the fiber twice. It then raises
    * an error carrying both joined values and recovers by printing the message. Finally it prints the
    * `Try` returned by `runSync`. Everything runs on a single daemon pool thread named "Pet IO 0".
    */
  def main(args: Array[String]): Unit = {
    // Sleeps ~1s, prints "1", sleeps ~3s, prints "2", then yields 42.
    val background: IO[Int] =
      IO.sleep(1.second) >> IO.putStrLn("1") >> IO.sleep(3.second) >> IO.putStrLn("2") >> IO.pure(42)

    val failing: IO[Unit] =
      for {
        fiber  <- background.fork()
        _      <- IO.putStrLn("3")
        first  <- fiber.join()
        second <- fiber.join()
        _      <- IO.raiseError(new RuntimeException(s"Boom! $first $second"))
      } yield ()

    val program: IO[Unit] = failing.recoverWith { case e => IO.putStrLn(e.getMessage) }

    val threadIds = new AtomicInteger(0)
    implicit val context: Context = {
      val timer: Timer = new Timer("Pet IO Timer", true)
      val pool: ExecutionContextExecutor = ExecutionContext.fromExecutor(
        Executors.newFixedThreadPool(
          1,
          (runnable: Runnable) =>
            new Thread(runnable).tap { thread =>
              thread.setDaemon(true)
              thread.setName(s"Pet IO ${threadIds.getAndIncrement()}")
            }
        )
      )
      new Context(timer, pool)
    }

    println(program.runSync())
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment