Instantly share code, notes, and snippets.

# tovbinm/RecursionSchemes.scala (created May 26, 2017)

Solutions to the exercises in the "Recursion Schemes" slides: http://slides.com/jedesah/recursion-schemes
 package rs

import scala.language.higherKinds

import matryoshka.data._
import matryoshka.implicits._
import scalaz._, Scalaz._

/** Pattern functor for a tiny arithmetic language.
  *
  * `A` is the type found in the recursive positions: `Fix[Expr]` for a
  * concrete tree, or the carrier type (e.g. `Int`) inside an algebra.
  * The trait is sealed so the compiler checks that every algebra and the
  * traversal below handle all three node kinds.
  */
sealed trait Expr[A]

/** Integer literal; it has no recursive children. */
final case class NumLit[A](value: Int) extends Expr[A]

/** Sum of two sub-expressions. */
final case class Add[A](left: A, right: A) extends Expr[A]

/** Integer quotient `num / denum` of two sub-expressions. */
final case class Div[A](num: A, denum: A) extends Expr[A]

/** Error returned by `Main.eval` when a divisor evaluates to zero.
  *
  * @param div the `Div` node (with already-evaluated operands) that failed
  */
final case class DivisionByZero(div: Div[Int])

object Expr {

  /** `Traverse` instance for [[Expr]].
    *
    * Being a scalaz `Traverse`, it also supplies the `Functor` and `Foldable`
    * structure used by the recursion schemes and by `fold` in [[Main]].
    */
  implicit val traverse: Traverse[Expr] = new Traverse[Expr] {
    def traverseImpl[G[_]: Applicative, A, B](expr: Expr[A])(f: A => G[B]): G[Expr[B]] =
      expr match {
        // A literal has no children: re-type it at `B` and lift it unchanged.
        case NumLit(value)    => (NumLit(value): Expr[B]).point[G]
        case Add(left, right) => (f(left) |@| f(right))(Add(_, _))
        case Div(num, denum)  => (f(num) |@| f(denum))(Div(_, _))
      }
  }
}

object Main extends App {

  /** Depth of the expression tree.
    *
    * A literal counts as 1, and each operator adds 1 to the deeper of its
    * two children.
    */
  def complexity(expr: Fix[Expr]): Int =
    expr.cata[Int] {
      case NumLit(_)        => 1
      case Add(left, right) => 1 + Math.max(left, right)
      case Div(num, denum)  => 1 + Math.max(num, denum)
    }

  /** Returns a copy of the tree with every literal increased by 1; the shape is unchanged. */
  def incr(expr: Fix[Expr]): Fix[Expr] =
    expr.cata[Fix[Expr]] {
      case NumLit(value) => Fix(NumLit(value + 1))
      case other         => Fix(other)
    }

  /** Either a [[DivisionByZero]] (left) or a successful result (right). */
  type ErrorOr[A] = DivisionByZero \/ A

  /** Evaluates the tree bottom-up with plain `Int` arithmetic.
    *
    * Division truncates and `Int` overflow is not detected. A zero divisor
    * anywhere in the tree makes the overall result a left `DivisionByZero`.
    */
  def eval(expr: Fix[Expr]): ErrorOr[Int] =
    expr.cataM[ErrorOr, Int] {
      case NumLit(value)       => value.right
      case Add(left, right)    => (left + right).right
      case d @ Div(num, denum) =>
        if (denum != 0) (num / denum).right
        else DivisionByZero(d).left
    }

  /** Lists every literal node in the tree, from left to right. */
  def collect(a: Fix[Expr]): List[NumLit[_]] =
    a.cata[List[NumLit[_]]] {
      case n @ NumLit(_) => List(n)
      // Operator nodes concatenate their children's lists via Foldable and the List monoid.
      case other         => other.fold
    }

  /** Builds a balanced tree of `Add` nodes over `NumLit(1)` leaves.
    *
    * The resulting tree's depth (as computed by the `complexity` method) is
    * exactly the given value.
    *
    * @param complexity requested depth
    * @return the generated tree, or `None` when `complexity` is zero or negative
    */
  def gen(complexity: Int): Option[Fix[Expr]] =
    complexity.anaM[Fix[Expr]][Option, Expr] {
      // Guard every non-positive seed. Matching only 0 would let a negative
      // seed descend through `n - 1` without bound (stack overflow).
      case n if n <= 0 => None
      case 1           => Some(NumLit(1))
      case n           => Some(Add(n - 1, n - 1))
    }

  /** Shorthand for a literal leaf in the sample trees below. */
  private def lit(value: Int): Fix[Expr] = Fix(NumLit(value))

  val expr: Fix[Expr] = Fix(Add(lit(5), lit(10)))

  println(s"expr := $expr")
  println(s"complexity := ${complexity(expr)}")
  println(s"incr := ${incr(expr)}")
  println(s"eval 1+1 := ${eval(Fix(Add(lit(1), lit(1))))}")
  println(s"eval 6/3 := ${eval(Fix(Div(lit(6), lit(3))))}")
  println(s"eval 5/0 := ${eval(Fix(Div(lit(5), lit(0))))}")
  println(s"collect := ${collect(Fix(Add(lit(1), Fix(Add(lit(2), lit(3))))))}")
  println(s"gen(0) := ${gen(0).map(complexity)}")
  println(s"gen(1) := ${gen(1).map(complexity)}")
  println(s"gen(2) := ${gen(2).map(complexity)}")
  println(s"gen(3) := ${gen(3).map(complexity)}")
}