Last active
March 1, 2023 16:05
-
-
Save zliu41/4d3e5a440339181bfdef57af743c0a1d to your computer and use it in GitHub Desktop.
Factorial computation with recursion schemes in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cats.Functor | |
import cats.implicits._ | |
import scala.language.higherKinds | |
/**
 * Explicitly recursive stack of pending multiplications used for a factorial.
 *
 * `DoneR` is the bottom of the stack; its `result` defaults to 1. `MoreR` stacks one more
 * factor `next` on top of `acc`. `RecursionScheme.foldStackR` folds it by multiplying each
 * `next` onto the folded result below it.
 *
 * The base trait extends `Product with Serializable`. Expressions mixing `DoneR` and `MoreR`
 * then infer `StackR`, not `Product with Serializable with StackR`. That matters for
 * invariant containers such as `Array`.
 */
sealed trait StackR extends Product with Serializable
final case class DoneR(result: Int = 1) extends StackR
final case class MoreR(acc: StackR, next: Int) extends StackR
/**
 * Base functor of [[StackR]]: the same shape, with the recursive `acc` slot of `MoreR`
 * abstracted into the type parameter `A`.
 *
 * `Fix[Stack]` recovers a recursive stack. The companion object's smart constructors
 * (`Stack.done`, `Stack.more`) return the result typed as `Stack[A]`.
 *
 * The base trait extends `Product with Serializable`. `Stack` is invariant, so without this
 * a mix of `Done` and `More` would infer `Product with Serializable with Stack[A]`, which is
 * not a `Stack[A]` inside invariant containers.
 */
sealed trait Stack[A] extends Product with Serializable
final case class Done[A](result: Int) extends Stack[A]
final case class More[A](a: A, next: Int) extends Stack[A]
object Stack {

  /**
   * `Functor` instance for [[Stack]].
   *
   * `More` has its recursive slot transformed by the function. `Done` has no such slot, so it
   * is rebuilt with the same `result` at the new type.
   */
  implicit val stackFunctor: Functor[Stack] = new Functor[Stack] {
    override def map[A, B](sa: Stack[A])(f: A => B): Stack[B] = sa match {
      case More(inner, n) => More[B](f(inner), n)
      case Done(r)        => Done[B](r)
    }
  }

  /** Builds a `Done` layer typed as `Stack[A]`; `result` defaults to 1. */
  def done[A](result: Int = 1): Stack[A] = Done[A](result)

  /** Builds a `More` layer typed as `Stack[A]` rather than `More[A]`. */
  def more[A](a: A, next: Int): Stack[A] = More[A](a, next)
}
/**
 * Explicitly recursive Peano natural number: `ZeroR` or the successor `SuccR(prev)`.
 *
 * The base trait extends `Product with Serializable`, so expressions mixing `ZeroR` and
 * `SuccR` infer `NatR` instead of `Product with Serializable with NatR`.
 */
sealed trait NatR extends Product with Serializable
case object ZeroR extends NatR
final case class SuccR(prev: NatR) extends NatR
/**
 * Base functor of [[NatR]]: `Succ` holds its predecessor in the type parameter `A` instead of
 * a concrete `NatR`, so `Fix[Nat]` recovers a recursive natural number.
 *
 * The base trait extends `Product with Serializable`. `Nat` is invariant, so a mix of `Zero`
 * and `Succ` would otherwise infer `Product with Serializable with Nat[A]`.
 */
sealed trait Nat[A] extends Product with Serializable
final case class Zero[A]() extends Nat[A]
final case class Succ[A](a: A) extends Nat[A]
object Nat {

  /**
   * `Functor` instance for [[Nat]].
   *
   * `Succ` has its payload transformed by the function. `Zero` carries nothing, so it is
   * rebuilt at the new type.
   */
  implicit val natFunctor: Functor[Nat] = new Functor[Nat] {
    override def map[A, B](na: Nat[A])(f: A => B): Nat[B] = na match {
      case Succ(pred) => Succ[B](f(pred))
      case Zero()     => Zero[B]()
    }
  }
}
/**
 * Fixed point of a type constructor `F`: one `F` layer whose recursive positions each hold
 * another `Fix[F]`. `unfix` exposes that outermost layer.
 */
final case class Fix[F[_]](
  unfix: F[Fix[F]]
)
/**
 * A recursive `F`-shaped structure in which every node carries an annotation.
 *
 * `head` is this node's value. `tail` holds the child nodes, each annotated in turn.
 */
final case class Cofree[F[_], A](
  head: A,
  tail: F[Cofree[F, A]]
)
/**
 * A value that is either a seed still to be processed (`Continue`) or an already-built
 * `F` layer whose positions contain further `Free` values (`Combine`).
 *
 * The base trait extends `Product with Serializable`. `Free` is invariant, so mixing the two
 * cases would otherwise infer `Product with Serializable with Free[F, A]`.
 */
sealed trait Free[F[_], A] extends Product with Serializable
final case class Continue[F[_], A](a: A) extends Free[F, A]
final case class Combine[F[_], A](fa: F[Free[F, A]]) extends Free[F, A]

object Free {

  /** Wraps a seed as `Free[F, A]` (typed as the base trait, not `Continue`). */
  def continue[F[_], A](a: A): Free[F, A] = Continue(a)

  /** Wraps a ready-made layer as `Free[F, A]` (typed as the base trait, not `Combine`). */
  def combine[F[_], A](fa: F[Free[F, A]]): Free[F, A] = Combine(fa)
}
/**
 * Recursion schemes over `Fix`, demonstrated by computing factorials with the `Stack` and
 * `Nat` functors. Covered schemes: ana, cata, hylo, para, apo, histo, dyna and futu.
 *
 * All results are `Int`, so factorials overflow for inputs above 12. None of the schemes
 * below are tail-recursive, and very deep inputs can exhaust the JVM call stack.
 */
object RecursionScheme {
  import Free._
  import Stack._

  /**
   * Unfolds `n` by explicit recursion into `MoreR(MoreR(..., n - 1), n)`.
   * Any `n <= 0` becomes `DoneR()`.
   */
  def unfoldStackR: Int => StackR =
    n => if (n > 0) MoreR(unfoldStackR(n - 1), n) else DoneR()

  // Type A => F[A] is also known as Coalgebra.
  /** Anamorphism: repeatedly applies the coalgebra `f` to grow a `Fix[F]` from a seed. */
  def ana[F[_] : Functor, A](f: A => F[A]): A => Fix[F] =
    a => Fix(f(a) map ana(f))

  /** Coalgebra emitting one stack layer: `More(n - 1, n)` for positive `n`, else `Done(1)`. */
  val stackCoalgebra: Int => Stack[Int] =
    n => if (n > 0) more(n - 1, n) else done()

  // Explicit recursion with StackR
  /** Folds a `StackR` by hand, multiplying every `next` onto the result below it. */
  def foldStackR: StackR => Int = {
    case DoneR(result)    => result
    case MoreR(acc, next) => foldStackR(acc) * next
  }

  // Explicit recursion with Fix[Stack]
  /** The same fold as `foldStackR`, over `Fix[Stack]`, still with hand-written recursion. */
  def foldFixStack: Fix[Stack] => Int =
    _.unfix match {
      case Done(result)    => result
      case More(fix, next) => foldFixStack(fix) * next
    }

  // Type F[A] => A is also known as Algebra.
  /** Catamorphism: folds a `Fix[F]` bottom-up, applying the algebra `f` at each layer. */
  def cata[F[_] : Functor, A](f: F[A] => A): Fix[F] => A =
    fix => f(fix.unfix map cata(f))

  /** Algebra for one stack layer: `Done` yields its result; `More` multiplies by `next`. */
  val stackAlgebra: Stack[Int] => Int = {
    case Done(result)    => result
    case More(acc, next) => acc * next
  }

  /** Hylomorphism as `ana` then `cata`; materializes the intermediate `Fix[F]`. */
  def hyloSimple[F[_] : Functor, A, B](f: F[B] => B)(g: A => F[A]): A => B =
    ana(g) andThen cata(f)

  /** Hylomorphism fusing `ana(g)` and `cata(f)`; no intermediate `Fix[F]` is built. */
  def hylo[F[_] : Functor, A, B](f: F[B] => B)(g: A => F[A]): A => B =
    a => f(g(a) map hylo(f)(g))

  /**
   * Paramorphism: like `cata`, but the algebra receives each original child together with
   * that child's folded result.
   */
  def para[F[_] : Functor, A](f: F[(Fix[F], A)] => A): Fix[F] => A =
    fix => f(fix.unfix.map(child => child -> para(f).apply(child)))

  /** `cata` expressed through `para` by dropping the children before calling `f`. */
  def cataViaPara[F[_] : Functor, A](f: F[A] => A): Fix[F] => A =
    para(((_: F[(Fix[F], A)]).map(_._2)) andThen f)

  /**
   * Maps `Zero` to 1 and `Succ(n)` to `n + 1`.
   * Hence `cata(natAlgebra)` on the `Fix[Nat]` built by `ana(natCoalgebra)(k)` returns `k + 1`.
   */
  val natAlgebra: Nat[Int] => Int = {
    case Zero()  => 1
    case Succ(n) => n + 1
  }

  /**
   * Factorial as a paramorphism algebra.
   *
   * At the node for `k`, `fix` is the node for `k - 1`. So `cata(natAlgebra)(fix)` recovers
   * `k`, and `acc` is `(k - 1)!`. Re-running `cata` at every level makes the whole fold
   * quadratic in `k`.
   */
  val natAlgebraPara: Nat[(Fix[Nat], Int)] => Int = {
    case Zero()           => 1
    case Succ((fix, acc)) => cata(natAlgebra).apply(fix) * acc
  }

  /**
   * Unfolds `n` into `n` `Succ` layers ending in `Zero`.
   * Negative `n` is treated like 0, matching `stackCoalgebra`, so unfolding always terminates.
   */
  val natCoalgebra: Int => Nat[Int] =
    n => if (n <= 0) Zero() else Succ(n - 1)

  /**
   * Apomorphism: like `ana`, but the coalgebra may stop early. A `Left` supplies an
   * already-built sub-structure; a `Right` supplies a seed to keep unfolding.
   */
  def apo[F[_] : Functor, A](f: A => F[Either[Fix[F], A]]): A => Fix[F] =
    a => Fix(f(a) map {
      case Left(fix) => fix
      case Right(aa) => apo(f).apply(aa)
    })

  /** `ana` expressed through `apo` by always continuing with `Right`. */
  def anaViaApo[F[_] : Functor, A](f: A => F[A]): A => Fix[F] =
    apo(f andThen (_.map(_.asRight[Fix[F]])))

  /** Pre-built stack for `3 * 2 * 1`, spliced in by `stackCoalgebraApo`. */
  val lastThreeSteps: Fix[Stack] = Fix(More(Fix(More(Fix(More(Fix(Done(1)), 1)), 2)), 3))

  /**
   * Apomorphism coalgebra for factorial.
   *
   * It unfolds exactly like `stackCoalgebra`, except that on reaching 3 it short-circuits
   * with the pre-built `lastThreeSteps`. Inputs below 3 never reach that point and unfold
   * normally, so every input still yields its own factorial.
   */
  val stackCoalgebraApo: Int => Stack[Either[Fix[Stack], Int]] =
    n =>
      if (n == 3) lastThreeSteps.unfix.map(_.asLeft[Int])
      else stackCoalgebra(n).map(_.asRight[Fix[Stack]])

  /**
   * Histomorphism: like `cata`, but for each child the algebra sees a `Cofree`. That `Cofree`
   * holds the results already computed for the child and for all of its descendants.
   */
  def histo[F[_] : Functor, A](f: F[Cofree[F, A]] => A): Fix[F] => A = {
    // Builds the annotated tree in one bottom-up pass. `tail` is computed once and reused,
    // both as the node's children and as the algebra's input, so every node is visited once.
    // Computing `head` by calling `histo` again on `fix` would re-traverse each subtree,
    // costing time exponential in its depth.
    def toCofree(fix: Fix[F]): Cofree[F, A] = {
      val tail = fix.unfix map toCofree
      Cofree(head = f(tail), tail = tail)
    }
    fix => toCofree(fix).head
  }

  /** Dynamorphism as `ana` then `histo`; materializes the intermediate `Fix[F]`. */
  def dynaSimple[F[_] : Functor, A, B](f: F[Cofree[F, B]] => B)(g: A => F[A]): A => B =
    ana(g) andThen histo(f)

  /**
   * Dynamorphism: a hylomorphism whose algebra builds the `Cofree` annotations directly.
   * The result is the root's `head`.
   */
  def dyna[F[_] : Functor, A, B](f: F[Cofree[F, B]] => B)(g: A => F[A]): A => B = {
    // Keeps each layer's annotated children alongside the result computed from them.
    val cofree: F[Cofree[F, B]] => Cofree[F, B] =
      fc => Cofree(f(fc), fc)
    a => hylo(cofree)(g).apply(a).head
  }

  /**
   * Futumorphism: like `ana`, but the coalgebra may emit several layers at once through
   * `Free`. `Combine` holds a ready-made layer; `Continue` holds a seed to keep unfolding.
   */
  def futu[F[_] : Functor, A](f: A => F[Free[F, A]]): A => Fix[F] = {
    def toFix: Free[F, A] => Fix[F] = {
      case Continue(a) => futu(f).apply(a)
      case Combine(fa) => Fix(fa map toFix)
    }
    a => Fix(f(a) map toFix)
  }

  /** First layers for `n == 5`: multiplies by 5 and 4, then continues unfolding from 3. */
  val firstThreeSteps: Stack[Free[Stack, Int]] = more(combine(more(continue(3), 4)), 5)

  /**
   * Futumorphism coalgebra for factorial.
   *
   * For 5 it emits `firstThreeSteps`. Otherwise it emits one layer like `stackCoalgebra`,
   * with the seed wrapped in `Continue`.
   */
  val stackCoalgebraFutu: Int => Stack[Free[Stack, Int]] =
    n =>
      if (n == 5) firstThreeSteps
      else if (n > 0) more(n - 1, n) map continue
      else done() map continue
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is the code used in the article "Recursion Schemes in Scala - An Absolutely Elementary Introduction".