Skip to content

Instantly share code, notes, and snippets.

@nicmart
Created July 12, 2018 11:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nicmart/08724379cbf66cc78f69b13cf1b48702 to your computer and use it in GitHub Desktop.
Save nicmart/08724379cbf66cc78f69b13cf1b48702 to your computer and use it in GitHub Desktop.
Introduction to Finally Tagless Encodings
//2 + (3 + 4)
// 9 (evaluate)
// "(2 + (3 + 4))" (pretty printing)
/** Syntax of a tiny language made of integer literals and addition.
  *
  * An instance of this trait is an interpreter: it fixes the type `Repr` that
  * every program written against the trait is turned into (e.g. `Int` for
  * evaluation, `String` for pretty printing).
  *
  * @tparam Repr the representation produced by this interpreter
  */
trait AddLanguage[Repr] {

  /** Interprets the integer constant `n`. */
  def literal(n: Int): Repr

  /** Interprets the sum of two already-interpreted operands `m` and `n`. */
  def add(m: Repr, n: Repr): Repr
}
/** The program `2 + (3 + 4)`, written once against [[AddLanguage]].
  *
  * The program does not commit to any meaning: the interpreter passed in
  * decides what literals and additions become.
  *
  * @param lang interpreter supplying `literal` and `add`
  * @tparam T   result type produced by `lang`
  * @return the program interpreted by `lang`
  */
def expr[T](lang: AddLanguage[T]): T = {
  // Same call order as the nested form: literal(2), literal(3), literal(4), inner add, outer add.
  val two = lang.literal(2)
  val threePlusFour = lang.add(lang.literal(3), lang.literal(4))
  lang.add(two, threePlusFour)
}
/** The program `2 + (3 + 4)` packaged as a first-class `Expr` value, so it can
  * be stored and interpreted later via `exprAsValue.expr(lang)`.
  *
  * NOTE(review): `Expr` is declared after this value; that only compiles where
  * forward references to types are allowed (e.g. an object/class body or a
  * non-REPL worksheet) — confirm for the target runner.
  */
val exprAsValue: Expr = new Expr {
  override def expr[T](lang: AddLanguage[T]): T = {
    val two = lang.literal(2)
    val threePlusFour = lang.add(lang.literal(3), lang.literal(4))
    lang.add(two, threePlusFour)
  }
}
/** A program in `AddLanguage`, reified as a value.
  *
  * Wrapping the polymorphic method in a trait lets one program be stored in a
  * `val` and later interpreted with any `AddLanguage` instance.
  */
trait Expr {

  /** Interprets this program with `lang`, producing a value of type `R`. */
  def expr[R](lang: AddLanguage[R]): R
}
/** Interpreter that computes the integer value of an `AddLanguage` program. */
object Evaluator extends AddLanguage[Int] {

  /** A literal evaluates to itself. */
  override def literal(n: Int): Int = n

  /** Integer sum of the two operands. */
  override def add(m: Int, n: Int): Int = n + m
}
/** Interpreter that renders an `AddLanguage` program as a fully parenthesised string. */
object PrettyPrinter extends AddLanguage[String] {

  /** A literal is rendered via its decimal string representation. */
  override def literal(n: Int): String = String.valueOf(n)

  /** Every addition is wrapped in its own pair of parentheses. */
  override def add(m: String, n: String): String = s"($m + $n)"
}
// The same program, run under both interpreters.
expr[Int](Evaluator)        // 9
expr[String](PrettyPrinter) // "(2 + (3 + 4))"
/** Multiplication as a separate language fragment, declared independently of
  * `AddLanguage` so the existing addition interpreters stay unchanged.
  *
  * @tparam Repr the representation produced by this interpreter
  */
trait MultLanguage[Repr] {

  /** Interprets the product of two already-interpreted operands `m` and `n`. */
  def mult(m: Repr, n: Repr): Repr
}
/** The program `2 + (3 * 4)`, combining two independent language fragments.
  *
  * @param langAdd  interpreter for literals and addition
  * @param langMult interpreter for multiplication, over the same `T`
  * @tparam T       result type shared by both interpreters
  * @return the program interpreted by the two fragments together
  */
def expr2[T](langAdd: AddLanguage[T], langMult: MultLanguage[T]): T = {
  // Calls happen in the same order as the nested form.
  val two = langAdd.literal(2)
  val threeTimesFour = langMult.mult(langAdd.literal(3), langAdd.literal(4))
  langAdd.add(two, threeTimesFour)
}
/** Integer interpreter for `MultLanguage`; meant to be paired with `Evaluator`. */
object MultEvaluator extends MultLanguage[Int] {

  /** Integer product of the two operands. */
  override def mult(m: Int, n: Int): Int = n * m
}
expr2[Int](Evaluator, MultEvaluator) // 2 + (3 * 4) = 14
//2 + (3 + 4)
// 9 (evaluate)
// "(2 + (3 + 4))" (pretty printing)
// Direct style: the program is computed immediately and only an Int remains.
def add(n: Int, m: Int): Int = m + n
val s = add(n = 2, m = add(n = 3, m = 4))
/** Initial encoding of the addition language: programs are plain data.
  *
  * Extending `Product with Serializable` keeps inferred types clean. Without it,
  * Scala 2 infers e.g. `List(Literal(1), Add(...))` as
  * `List[Product with Serializable with Expr]` instead of `List[Expr]`.
  */
sealed trait Expr extends Product with Serializable

/** An integer constant. */
final case class Literal(n: Int) extends Expr

/** The sum of the sub-expressions `m` (left) and `n` (right). */
final case class Add(m: Expr, n: Expr) extends Expr
// 2 + (3 + 4) as a data value; nothing is computed yet.
val expr = Add(m = Literal(2), n = Add(m = Literal(3), n = Literal(4)))
/** Computes the integer value of an `Expr` tree.
  *
  * The recursion is not tail recursive, so a very deep tree can overflow the stack.
  */
def evaluate(expr: Expr): Int = expr match {
  case Add(left, right) => evaluate(left) + evaluate(right)
  case Literal(value)   => value
}
evaluate(expr): Int // 9
/** Renders an `Expr` as a fully parenthesised infix string, e.g. `"(2 + (3 + 4))"`.
  *
  * Binder names follow the declared fields of `Add(m, n)`: `m` is the left
  * operand and `n` the right one.
  */
def prettyPrint(expr: Expr): String =
  expr match {
    case Literal(n) => n.toString
    case Add(m, n)  => s"(${prettyPrint(m)} + ${prettyPrint(n)})"
  }
prettyPrint(expr): String // "(2 + (3 + 4))"
// 2 * (3 + 4)
// 2 + (3 * 4)
// Both evaluate to 14; pretty printing the first gives "(2 * (3 + 4))"
/** Attempt to add multiplication to the initial encoding without modifying `Expr`.
  *
  * Existing `Expr` trees can only be embedded through `Wrapped`. Because `Add`
  * only accepts `Expr` operands, a `Mult` can never appear underneath an `Add`.
  */
sealed trait ExprWithMult extends Product with Serializable

/** The product of the sub-expressions `m` (left) and `n` (right). */
final case class Mult(m: ExprWithMult, n: ExprWithMult) extends ExprWithMult

/** Embeds an addition-only `Expr` into the extended language. */
final case class Wrapped(expr: Expr) extends ExprWithMult
// 2 * (3 + 4): the additions live inside Wrapped, the multiplication sits on top.
val expr2 = Mult(m = Wrapped(Literal(2)), n = Wrapped(Add(Literal(3), Literal(4))))
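// Does not compile: Add only accepts Expr operands, and Mult is an ExprWithMult,
// so a multiplication cannot be nested inside an addition:
//val expr3 = Wrapped(Add(Literal(2), Mult(???, ???)))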
/** Computes the integer value of an `ExprWithMult` tree.
  *
  * `Wrapped` sub-trees are delegated to `evaluate`, the interpreter for the
  * addition-only `Expr` language.
  */
def evaluate2(expr: ExprWithMult): Int = expr match {
  // `m` and `n` are statically ExprWithMult already, so no typed pattern is needed.
  case Mult(m, n)           => evaluate2(m) * evaluate2(n)
  case Wrapped(exprWithAdd) => evaluate(exprWithAdd)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment