Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Solutions for the Recursion Schemes slide deck: http://slides.com/jedesah/recursion-schemes
package rs
import scala.language.higherKinds
import matryoshka.data._
import matryoshka.implicits._
import scalaz._, Scalaz._
/** Pattern functor for a tiny integer-arithmetic language.
  *
  * `A` is the type stored in child positions: `Fix[Expr]` for a concrete tree,
  * or the carrier type of a fold (e.g. `Int`) when used inside an algebra.
  * The trait is sealed so that every `match` over `Expr` is checked for
  * exhaustiveness by the compiler instead of failing with a `MatchError`.
  */
sealed trait Expr[A]

/** Integer literal. It has no children, so `A` is a phantom parameter here. */
final case class NumLit[A](value: Int) extends Expr[A]

/** Sum of two child positions. */
final case class Add[A](left: A, right: A) extends Expr[A]

/** Integer division of `num` by `denum`. */
final case class Div[A](num: A, denum: A) extends Expr[A]

/** Evaluation error carrying the `Div` node (with children already evaluated
  * to `Int`) whose divisor was zero.
  */
final case class DivisionByZero(div: Div[Int])
/** Companion of [[Expr]]; holds its scalaz type-class instances. */
object Expr {

  /** `Traverse` instance for the `Expr` pattern functor.
    *
    * A `NumLit` holds no `A`, so it is rebuilt as-is and lifted into `G`.
    * For `Add` and `Div`, `f` is applied to both children and the two
    * effectful results are combined with `Apply.apply2` to rebuild the node.
    */
  implicit val traverse: Traverse[Expr] = new Traverse[Expr] {
    def traverseImpl[G[_] : Applicative, A, B](expr: Expr[A])(f: A => G[B]): G[Expr[B]] = {
      val app = Applicative[G]
      expr match {
        case Div(n, d) =>
          app.apply2(f(n), f(d))((x, y) => Div(x, y): Expr[B])
        case Add(l, r) =>
          app.apply2(f(l), f(r))((x, y) => Add(x, y): Expr[B])
        case NumLit(v) =>
          app.point(NumLit[B](v): Expr[B])
      }
    }
  }
}
/** Demo program: small recursion-scheme functions over `Fix[Expr]`, whose
  * results are printed when the object is run.
  */
object Main extends App {

  /** Height of the expression tree: a literal counts as 1, and each `Add` or
    * `Div` node adds 1 to the larger height of its two children.
    */
  def complexity(expr: Fix[Expr]): Int = expr.cata[Int] {
    case NumLit(_)        => 1
    case Add(left, right) => 1 + math.max(left, right)
    case Div(num, denum)  => 1 + math.max(num, denum)
  }

  /** Returns a copy of `expr` in which every literal is incremented by one;
    * all other nodes are rebuilt unchanged.
    */
  def incr(expr: Fix[Expr]): Fix[Expr] = expr.cata[Fix[Expr]] {
    case NumLit(value) => Fix(NumLit(value + 1))
    case other         => Fix(other)
  }

  /** Either a division-by-zero error (left) or a computed value (right). */
  type ErrorOr[A] = DivisionByZero \/ A

  /** Evaluates `expr` bottom-up using plain `Int` arithmetic (no overflow checks).
    *
    * @return the value on the right, or on the left a `DivisionByZero` holding a
    *         `Div` node whose divisor evaluated to 0
    */
  def eval(expr: Fix[Expr]): ErrorOr[Int] = expr.cataM[ErrorOr, Int] {
    case NumLit(value)        => value.right
    case Add(left, right)     => (left + right).right
    case d @ Div(num, denum)  =>
      if (denum != 0) (num / denum).right else DivisionByZero(d).left
  }

  /** Collects every `NumLit` leaf of `a` into a list; inner nodes contribute
    * the concatenation of their children's results via `fold`.
    */
  def collect(a: Fix[Expr]): List[NumLit[_]] = a.cata[List[NumLit[_]]] {
    case n @ NumLit(_) => List(n)
    case other         => other.fold
  }

  /** Builds a full binary tree of `Add` nodes with `NumLit(1)` leaves whose
    * [[complexity]] equals the argument. The tree has 2^(complexity - 1)
    * leaves, so keep the argument small.
    *
    * @return `None` when `complexity` is 0 or negative (no tree has height < 1)
    */
  def gen(complexity: Int): Option[Fix[Expr]] =
    complexity.anaM[Fix[Expr]][Option, Expr] {
      // A guard rather than `case 0`: a negative seed would otherwise keep
      // unfolding `Add(n - 1, n - 1)` without ever reaching a base case.
      case n if n <= 0 => None
      case 1           => Some(NumLit(1))
      case n           => Some(Add(n - 1, n - 1))
    }

  val expr: Fix[Expr] = Fix(Add(Fix(NumLit(5)), Fix(NumLit(10))))
  println(s"expr := $expr")
  println(s"complexity := ${complexity(expr)}")
  println(s"incr := ${incr(expr)}")
  println(s"eval 1+1 := ${eval(Fix(Add(Fix(NumLit(1)), Fix(NumLit(1)))))}")
  println(s"eval 6/3 := ${eval(Fix(Div(Fix(NumLit(6)), Fix(NumLit(3)))))}")
  println(s"eval 5/0 := ${eval(Fix(Div(Fix(NumLit(5)), Fix(NumLit(0)))))}")
  println(s"collect := ${collect(Fix(Add(Fix(NumLit(1)), Fix(Add(Fix(NumLit(2)), Fix(NumLit(3)))))))}")
  println(s"gen(0) := ${gen(0).map(complexity)}")
  println(s"gen(1) := ${gen(1).map(complexity)}")
  println(s"gen(2) := ${gen(2).map(complexity)}")
  println(s"gen(3) := ${gen(3).map(complexity)}")
}

Thanks, this was very helpful!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.