Skip to content

Instantly share code, notes, and snippets.

@monadplus
Created March 20, 2019 08:41
Show Gist options
  • Save monadplus/90c56dc235717e68eefe30c626c3f55b to your computer and use it in GitHub Desktop.
Save monadplus/90c56dc235717e68eefe30c626c3f55b to your computer and use it in GitHub Desktop.
Exercise: implement cats.effect.concurrent.Deferred from scratch
// Original code from https://github.com/typelevel/cats-effect/blob/1846813109b1e78c5bf36e6e179d7a91419e01d0/core/shared/src/main/scala/cats/effect/concurrent/Deferred.scala#L164
// Exercise: implement deferred
// Bear in mind:
// - always use compareAndSet when updating the atomic reference
// - don't break referential transparency (RT) when touching `ref`: wrap every access in F.delay or F.suspend
// - `get` should be cancelable.
// - `complete` should return an F[Unit] that does not block the current execution context (ec).
// - `register` and `unregister` methods should be tail-recursive.
// Helpers
/** Opaque registration token for a waiting reader.
  *
  * It declares no fields and does not override `equals`, so two tokens are equal only
  * when they are the same instance.
  */
private final class Id()
/** The two possible states of a deferred value. */
private sealed abstract class State[A]

private object State {

  /** Completed: `a` is the value handed to every current and future reader. */
  final case class Set[A](a: A) extends State[A]

  /** Not yet completed.
    *
    * `waiting` maps each suspended reader's [[Id]] to the function that resumes it with
    * the eventual value. The solution stores `a => cb(Right(a))` here.
    */
  final case class Unset[A](waiting: LinkedMap[Id, A => Unit]) extends State[A]
}
/** Exercise skeleton: replace the three `???` bodies with a working implementation.
  *
  * All mutable state lives in `ref`.
  */
private final class ConcurrentDeferred[F[_], A](ref: AtomicReference[State[A]])(implicit F: Concurrent[F])
    extends TryableDeferred[F, A] {

  /** Yields the value if it is already set, otherwise waits until the deferred is completed.
    *
    * Hint: build the result with `F.cancelable` and register the callback in `ref`.
    */
  override def get: F[A] = ???

  /** Returns `Some(value)` if the deferred is completed and `None` otherwise; never waits. */
  override def tryGet: F[Option[A]] = ???

  /** Sets the value and resumes every waiting reader.
    *
    * Must fail (throw) if the deferred has already been completed.
    */
  override def complete(a: A): F[Unit] = ???
}
// Solution:
/** Lock-free [[TryableDeferred]] backed by a single `AtomicReference[State[A]]`.
  *
  * `ref` is only read with `get()` and only written with `compareAndSet`. When another
  * thread wins the race, the operation retries in a tail-recursive loop. Every access to
  * `ref` happens inside `F.delay` or `F.suspend`, so the public methods stay
  * referentially transparent.
  */
private final class ConcurrentDeferred[F[_], A](ref: AtomicReference[State[A]])(implicit F: Concurrent[F])
    extends TryableDeferred[F, A] {

  /** Returns the value if already completed, otherwise suspends until `complete` runs.
    *
    * The suspended branch is built with `F.cancelable`. Its cancellation token removes
    * this reader's callback from the waiting map.
    */
  override def get: F[A] =
    F.suspend {
      ref.get() match {
        case State.Set(a) =>
          F.pure(a)
        case State.Unset(_) =>
          F.cancelable[A] { cb =>
            val id = unsafeRegister(cb)

            // Cancellation: drop our callback. If the deferred has been completed in the
            // meantime there is nothing left to remove.
            @tailrec
            def unregister(): Unit =
              ref.get() match {
                case State.Set(_) => ()
                case s @ State.Unset(waiting) =>
                  val updated = State.Unset(waiting - id)
                  if (ref.compareAndSet(s, updated)) ()
                  else unregister()
              }

            F.delay(unregister())
          }
      }
    }

  /** Returns `Some(value)` if completed and `None` otherwise; never waits. */
  override def tryGet: F[Option[A]] =
    F.delay {
      ref.get() match {
        case State.Set(a)   => Some(a)
        case State.Unset(_) => None
      }
    }

  /** Completes the deferred with `a` and resumes every waiting reader.
    *
    * On a second completion, an `IllegalStateException` is thrown inside `F.suspend`.
    * Sync's laws turn such a throw into a raised error. Each reader's callback is
    * launched with `F.start` instead of being invoked inline.
    */
  override def complete(a: A): F[Unit] =
    F.suspend(unsafeComplete(a))

  /** CAS loop that flips `Unset` to `Set(a)` and returns the effect that notifies the
    * readers that were waiting at that moment.
    */
  @tailrec
  private[this] def unsafeComplete(a: A): F[Unit] =
    ref.get() match {
      case State.Set(_) =>
        // IllegalStateException (a RuntimeException) names this misuse precisely.
        throw new IllegalStateException("Attempting to complete a Deferred that has already been completed")
      case s @ State.Unset(waiting) =>
        if (ref.compareAndSet(s, State.Set(a))) {
          val readers = waiting.values
          if (readers.isEmpty) F.unit
          else notifyReadersLoop(a, readers)
        } else {
          unsafeComplete(a)
        }
    }

  /** Sequences one `F.start(F.delay(reader(a)))` per reader, in iteration order, and
    * discards the resulting fibers.
    */
  private[this] def notifyReadersLoop(a: A, iterable: Iterable[A => Unit]): F[Unit] =
    iterable.foldLeft(F.unit) { (acc, reader) =>
      val task = mapUnit(F.start(F.delay(reader(a))))
      F.flatMap(acc)(_ => task)
    }

  /** Discards the result of `fb`. */
  private[this] def mapUnit[B](fb: F[B]): F[Unit] = F.map(fb)(_ => ())

  /** Adds `cb` to the waiting map under a fresh [[Id]] and returns that id.
    *
    * If the deferred is found already completed, `cb` is invoked immediately with the
    * value instead of being stored.
    */
  private[this] def unsafeRegister(cb: Either[Throwable, A] => Unit): Id = {
    val id = new Id

    @tailrec
    def register(): Option[A] =
      ref.get() match {
        case State.Set(a) => Some(a)
        case s @ State.Unset(waiting) =>
          val updated = State.Unset(waiting.updated(id, (a: A) => cb(Right(a))))
          if (ref.compareAndSet(s, updated)) None
          else register()
      }

    // Completed before we could register: deliver the value right away.
    register().foreach(a => cb(Right(a)))
    id
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment