Create a gist now

Instantly share code, notes, and snippets.

Embed
Solutions to Recursion Schemes - http://slides.com/jedesah/recursion-schemes
package rs
import scala.language.higherKinds
import matryoshka.data._
import matryoshka.implicits._
import scalaz._, Scalaz._
/**
 * Pattern functor for a tiny arithmetic language.
 *
 * The type parameter `A` marks the positions of child expressions, so
 * `Fix[Expr]` is a complete expression tree, while `Expr[Int]` is a single
 * layer whose children have already been reduced to `Int`s.
 *
 * Sealed so that every `match` over an `Expr` layer is checked for
 * exhaustiveness by the compiler.
 */
sealed trait Expr[A]

/** Integer literal. It is a leaf, so the child type `A` is unused (phantom). */
final case class NumLit[A](value: Int) extends Expr[A]

/** Sum of the `left` and `right` children. */
final case class Add[A](left: A, right: A) extends Expr[A]

/** Division of the `num` child by the `denum` child. */
final case class Div[A](num: A, denum: A) extends Expr[A]

/** Error value carrying the `Div` layer (operands as `Int`s) whose divisor was zero. */
final case class DivisionByZero(div: Div[Int])
/** Type-class instances for the `Expr` pattern functor. */
object Expr {

  /**
   * `Traverse` instance for `Expr`.
   *
   * It lets generic code walk the child positions of one `Expr` layer with an
   * effect `G`. Examples are monadic folds such as `cataM`, and `Foldable`
   * operations such as `fold`. `NumLit` has no children, so it is just
   * lifted into `G`. The two children of `Add` and `Div` are combined with
   * `G`'s `apply2`.
   */
  implicit val traverse: Traverse[Expr] = new Traverse[Expr] {
    def traverseImpl[G[_] : Applicative, A, B](expr: Expr[A])(f: A => G[B]): G[Expr[B]] = {
      val ap = Applicative[G]
      expr match {
        case NumLit(n) =>
          ap.point[Expr[B]](NumLit(n))
        case Add(l, r) =>
          ap.apply2[B, B, Expr[B]](f(l), f(r))((x, y) => Add(x, y))
        case Div(n, d) =>
          ap.apply2[B, B, Expr[B]](f(n), f(d))((x, y) => Div(x, y))
      }
    }
  }
}
/**
 * Small demo of matryoshka recursion schemes applied to `Fix[Expr]`.
 *
 * It covers plain folds (`cata`), a monadic fold (`cataM`) and a monadic
 * unfold (`anaM`). Running it prints each function's result on a few sample
 * expressions.
 */
object Main extends App {

  /**
   * Height of the expression tree.
   *
   * A literal counts as 1, and each `Add`/`Div` node adds 1 to the larger
   * height of its two children.
   */
  def complexity(expr: Fix[Expr]): Int = expr.cata[Int] {
    case NumLit(_)        => 1
    case Add(left, right) => 1 + math.max(left, right)
    case Div(num, denum)  => 1 + math.max(num, denum)
  }

  /** Returns `expr` with every integer literal incremented by one. */
  def incr(expr: Fix[Expr]): Fix[Expr] = expr.cata[Fix[Expr]] {
    case NumLit(value) => Fix(NumLit(value + 1))
    // `cata` has already rewritten the children, so other nodes are simply re-wrapped.
    case other         => Fix(other)
  }

  /** Either a division-by-zero error (left) or a successful value (right). */
  type ErrorOr[A] = DivisionByZero \/ A

  /**
   * Evaluates `expr` with `Int` arithmetic, so `+` wraps on overflow and `/`
   * truncates toward zero.
   *
   * @return the value on the right. If a divisor evaluates to 0, returns on
   *         the left a [[DivisionByZero]] holding that `Div` layer with its
   *         operands already evaluated.
   */
  def eval(expr: Fix[Expr]): ErrorOr[Int] = expr.cataM[ErrorOr, Int] {
    case NumLit(value)        => value.right
    case Add(left, right)     => (left + right).right
    case d @ Div(num, denum)  => if (denum != 0) (num / denum).right else DivisionByZero(d).left
  }

  /** Lists every integer literal in `a`, visiting the left child before the right. */
  def collect(a: Fix[Expr]): List[NumLit[_]] = a.cata[List[NumLit[_]]] {
    case n @ NumLit(_) => List(n)
    // Concatenate the children's lists via Expr's Foldable and List's Monoid.
    case other         => other.fold
  }

  /**
   * Builds a full binary tree of `Add` nodes with `NumLit(1)` leaves, whose
   * `complexity(...)` equals the `complexity` argument.
   *
   * Returns `None` when the argument is zero or negative, because no such tree
   * exists. The tree has 2^(complexity - 1) leaves, so keep the argument small.
   */
  def gen(complexity: Int): Option[Fix[Expr]] =
    complexity.anaM[Fix[Expr]][Option, Expr] {
      // Guard on `<= 0` rather than `== 0`: negative seeds would otherwise unfold
      // forever, since `n - 1` never reaches 0 or 1.
      case n if n <= 0 => None
      case 1           => Some(NumLit(1))
      case n           => Some(Add(n - 1, n - 1))
    }

  val expr: Fix[Expr] = Fix(Add(Fix(NumLit(5)), Fix(NumLit(10))))
  println(s"expr := $expr")
  println(s"complexity := ${complexity(expr)}")
  println(s"incr := ${incr(expr)}")
  println(s"eval 1+1 := ${eval(Fix(Add(Fix(NumLit(1)), Fix(NumLit(1)))))}")
  println(s"eval 6/3 := ${eval(Fix(Div(Fix(NumLit(6)), Fix(NumLit(3)))))}")
  println(s"eval 5/0 := ${eval(Fix(Div(Fix(NumLit(5)), Fix(NumLit(0)))))}")
  println(s"collect := ${collect(Fix(Add(Fix(NumLit(1)), Fix(Add(Fix(NumLit(2)), Fix(NumLit(3)))))))}")
  println(s"gen(0) := ${gen(0).map(complexity)}")
  println(s"gen(1) := ${gen(1).map(complexity)}")
  println(s"gen(2) := ${gen(2).map(complexity)}")
  println(s"gen(3) := ${gen(3).map(complexity)}")
}
@timjstewart

This comment has been minimized.

Show comment
Hide comment
@timjstewart

timjstewart May 27, 2017

Thanks, this was very helpful!

Thanks, this was very helpful!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment