Skip to content

Instantly share code, notes, and snippets.

@sergey-scherbina
Last active February 7, 2019 13:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sergey-scherbina/28f06c26e29128b747433e7aaa49bf63 to your computer and use it in GitHub Desktop.
Save sergey-scherbina/28f06c26e29128b747433e7aaa49bf63 to your computer and use it in GitHub Desktop.
import cats._
import cats.implicits._
import scala.annotation.tailrec
import scala.util.control.TailCalls._
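/*
Scala port of Oleg Kiselyov's zipper over an arbitrary Traversable, obtained by
suspending the traversal with delimited continuations (shift/reset).
Original description and Haskell code:
http://okmij.org/ftp/continuations/zipper.html
http://okmij.org/ftp/Haskell/ZipperTraversable.hs
*/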
/** A continuation monad whose steps are trampolined through
  * `scala.util.control.TailCalls`, so long chains of binds are evaluated by
  * `TailRec.result` rather than by nested JVM calls.
  *
  * Type aliases:
  *  - `^[A]`     a trampolined computation producing an `A`.
  *  - `A =>> B`  a trampolined function from `A` to `B`.
  *  - `R <<= A`  a (trampolined) computation that, given a continuation
  *               `A =>> R`, produces a trampolined `R`.
  */
object Cont {
  type ^[A] = TailRec[A]
  type =>>[A, B] = A => ^[B]
  // `=>>` does not end in ':', so it is left-associative:
  // `A =>> R =>> R` means `(A =>> R) =>> R`, i.e. a function taking a continuation.
  type <<=[R, A] = ^[A =>> R =>> R]

  /** Lifts a plain value: the computation simply feeds `a` to its continuation. */
  def ^[R, A](a: A): R <<= A = done(k => k(a))

  /** Hands the current continuation (everything bound after this step, up to the
    * enclosing [[reset]]) to `e`, which decides whether/how to resume it.
    */
  def shift[R, A](e: A =>> R =>> R): R <<= A = done(e)

  /** Delimits a computation: runs it with `done` as the final continuation and
    * then evaluates the resulting trampoline to a value.
    */
  def reset[R](k: R <<= R): R = k.flatMap(run => run(done)).result

  /** Monad instance for `R <<= ?` (answer type `R` fixed).
    *
    * `flatMap` builds a new continuation-taking function: run `fa` with a
    * continuation that passes its value to `f` and then runs the result with
    * the outer continuation `k`. The explicit result type keeps implicit
    * resolution independent of definition order (and is required by Scala 3).
    */
  implicit def monad[R]: StackSafeMonad[R <<= ?] = new StackSafeMonad[R <<= ?] {
    override final def pure[A](a: A): R <<= A = ^(a)

    override final def flatMap[A, B](fa: R <<= A)(f: A => R <<= B): R <<= B =
      done(k => fa.flatMap(runA => runA(a => f(a).flatMap(runB => runB(k)))))
  }
}
/** Construction of [[Zipper]] values and its two states. */
object Zipper {
  import Cont._

  /** Starts traversing `fa`, suspending at every element.
    *
    * The returned zipper is focused on the first element of `fa`. If `fa` has no
    * elements, nothing suspends and the result is a `ZDone` holding `fa`'s
    * rebuilt (empty) container.
    */
  def apply[F[_] : Traverse, A](fa: F[A]): Zipper[F, A] = {
    val walk = fa.traverse(suspend[F, A])
    // cats' `>>=` is used deliberately instead of `.flatMap`: `walk` is a `TailRec`
    // underneath, and TailRec's own `flatMap` would be chosen over the
    // continuation monad's bind.
    reset(walk >>= (rebuilt => ^(ZDone(rebuilt))))
  }

  /** Per-element traversal step: captures the rest of the traversal as `resume`
    * and stops there, exposing the element `a` together with `resume`.
    */
  def suspend[F[_], A](a: A): Zipper[F, A] <<= A =
    shift { (resume: A =>> Zipper[F, A]) => done(ZNext[F, A](a, resume)) }

  /** The traversal has finished; `fa` is the container rebuilt from the
    * (possibly replaced) elements.
    */
  final case class ZDone[F[_], A](fa: F[A]) extends Zipper[F, A]

  /** The traversal is paused at element `a`; `k` resumes it with the value to
    * put in that element's place.
    */
  final case class ZNext[F[_], A](a: A, k: A =>> Zipper[F, A]) extends Zipper[F, A]
}
/** A cursor over a traversal of `F[A]`: either paused at an element (`ZNext`)
  * or finished with the rebuilt container (`ZDone`).
  */
sealed trait Zipper[F[_], A] {
  import Zipper._

  /** The current focus: `Left(element)` while paused, `Right(container)` once
    * the traversal has finished.
    */
  def get: Either[A, F[A]] = this match {
    case ZNext(a, _) => Left(a)
    case ZDone(fa)   => Right(fa)
  }

  /** Replaces the focused element with `f(element)` and resumes the traversal up
    * to the next element (or to completion).
    *
    * @param f transformation for the focused element; defaults to `identity`
    * @return the advanced zipper, or `this` unchanged if already finished
    */
  def next(f: A => A = identity): Zipper[F, A] = this match {
    case ZNext(a, k) => k(f(a)).result
    // Explicit case (no catch-all) so the sealed hierarchy stays exhaustiveness-checked.
    case ZDone(_)    => this
  }

  /** Advances `n` elements without modifying them.
    *
    * Stops early once the traversal has finished, instead of performing the
    * remaining no-op steps. Returns `this` when `n <= 0`.
    */
  def skip(n: Int): Zipper[F, A] = {
    @tailrec def loop(z: Zipper[F, A], remaining: Int): Zipper[F, A] =
      if (remaining <= 0) z
      else z match {
        case ZNext(_, _) => loop(z.next(), remaining - 1)
        case ZDone(_)    => z // finished: further `next()` calls would return `z` anyway
      }
    loop(this, n)
  }
}
/** Demo: walks a zipper over the integers 1..N, edits one element and then runs
  * the traversal to completion.
  *
  * Uses an explicit `main` instead of `extends App`, which defers initialization
  * of the object's vals to `delayedInit`.
  */
object ContTest {
  /** Number of elements in the demo collection. */
  val N = 1000000

  def main(args: Array[String]): Unit = {
    val z = Zipper(Stream.iterate(1)(_ + 1).take(N))
    println(z.get) // first element: Left(1)

    val z1 = z.next()
    println(z1.get) // Left(2)

    val z2 = z1.next(_ + 20) // replaces 2 with 22, then moves on
    println(z2.get) // Left(3)

    val z3 = z2.skip(N - 3) // now focused on the last element
    println(z3.get) // Left(1000000)

    // No stack overflow here: stepping past the last element finishes the
    // traversal and rebuilds all N elements, printing Right(<rebuilt collection>).
    println(z3.next().get)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment