Skip to content

Instantly share code, notes, and snippets.

@igor-ramazanov
Last active March 29, 2022 11:26
Show Gist options
  • Save igor-ramazanov/3bf9e0ca2363973b6cbaffa86f78005d to your computer and use it in GitHub Desktop.
Save igor-ramazanov/3bf9e0ca2363973b6cbaffa86f78005d to your computer and use it in GitHub Desktop.
Small notes I took while reading Oleg Kiselyov's papers about Tagless Final: http://okmij.org/ftp/tagless-final/course/lecture.pdf
//==============================
//Part 1: Introduction
//==============================
//Introduction to the "initial" encoding (the approach behind Free monads/applicatives)
//and the "final" encoding (also called "Tagless Final")
//Main goals of both approaches:
//1. create strict, statically typed DSLs
//2. write a "description" of a program in such a DSL
//3. (optional) inspect or optimise the created "description" of the program
//4. define multiple ways of interpreting the encoded program
//5. run the program
//We will model a simple typed arithmetic DSL
//and embed it into the host language (Scala) using different encoding styles
//"Initial" encoding of the language - uses ADT (Algebraic Data Types)
/** Abstract syntax of the arithmetic DSL in the "initial" encoding: a program is a data value.
  * Left unsealed on purpose: `Mult` (Part 2) extends it later in this file.
  */
trait Expression

/** An integer constant. */
final case class Literal(x: Int) extends Expression

/** Arithmetic negation of a sub-expression. */
final case class Negation(v: Expression) extends Expression

/** Sum of two sub-expressions. */
final case class Addition(a: Expression, b: Expression) extends Expression
//The example we will work with:
//8 + (- (1 + 2))
//Program in "initial" encoding
// 8 + (-(1 + 2)) as a plain data value; the inner sum is named to make the nesting readable.
val ti1 = {
  val onePlusTwo = Addition(Literal(1), Literal(2))
  Addition(Literal(8), Negation(onePlusTwo))
}
//Evaluation of the program
/** Interprets an [[Expression]] as its integer value.
  * Only `Literal`, `Negation` and `Addition` are matched; any other subtype throws a `MatchError`.
  */
def eval(e: Expression): Int =
  e match {
    case Addition(lhs, rhs) => eval(lhs) + eval(rhs)
    case Negation(inner)    => -eval(inner)
    case Literal(n)         => n
  }
eval(ti1) // 5: Int
//"Final" encoding of the language using expressions
// Each "constructor" here directly computes the meaning of the term,
// so the representation is fixed to Int: the program *is* its value.
type Representation = Int

/** A constant denotes itself. */
def literal(n: Int): Representation = n

/** Negation denotes the arithmetic negation of its operand's value. */
def negation(e: Representation): Representation = 0 - e

/** Addition denotes the sum of its operands' values. */
def addition(a: Representation, b: Representation): Representation = a + b
//Evaluation of the program as expressions
// 8 + (-(1 + 2)) built from the "final" functions; the value is computed immediately.
val tf1 = {
  val inner = addition(literal(1), literal(2))
  addition(literal(8), negation(inner))
} // 5: Int
//A different interpretation of the "initially"-encoded program: pretty-printing
/** Pretty-prints an [[Expression]] in fully parenthesised infix form, e.g. `(8+(-(1+2)))`.
  * Only `Literal`, `Negation` and `Addition` are matched; any other subtype throws a `MatchError`.
  */
def view(e: Expression): String = e match {
  case Addition(lhs, rhs) => s"(${view(lhs)}+${view(rhs)})"
  case Negation(inner)    => s"(-${view(inner)})"
  case Literal(n)         => n.toString
}
view(ti1) // (8+(-(1+2))): String
//The "final" encoding above fixes the representation to Int, so it supports only one interpreter.
//To parameterise over interpreters, we now define the "final" encoding using type classes
/** Type class giving the "final" (tagless) encoding of the arithmetic DSL:
  * its methods are the syntax, and each instance is one interpretation of it.
  *
  * @tparam Repr the type a program is interpreted into (e.g. `Int` to evaluate, `String` to print)
  */
trait ExprSymantics[Repr] {
  /** Sum of two terms. */
  def add(a: Repr, b: Repr): Repr

  /** Negation of a term. */
  def neg(a: Repr): Repr

  /** Integer constant. */
  def lit(v: Int): Repr
}
//Different interpreters, selected by the result type
/** Evaluator: interprets every term as its integer value. */
implicit val expSymInt: ExprSymantics[Int] = new ExprSymantics[Int] {
  def lit(v: Int): Int = v
  def neg(a: Int): Int = 0 - a
  def add(a: Int, b: Int): Int = a + b
}
/** Pretty-printer: renders terms in fully parenthesised infix form, e.g. `(8+(-(1+2)))`. */
implicit val expSymString: ExprSymantics[String] = new ExprSymantics[String] {
  def lit(v: Int): String = s"$v"
  def neg(a: String): String = s"(-$a)"
  def add(a: String, b: String): String = s"($a+$b)"
}
//The example program above, parameterised by the result type
/** The example program 8 + (-(1 + 2)), written once against [[ExprSymantics]]
  * and interpretable by whichever instance is found for `Repr`.
  */
def expr1[Repr: ExprSymantics]: Repr = {
  val sym     = implicitly[ExprSymantics[Repr]]
  val eight   = sym.lit(8)
  val negated = sym.neg(sym.add(sym.lit(1), sym.lit(2)))
  sym.add(eight, negated)
}
// Choosing an interpreter is just choosing the type argument
expr1[Int]    // 5: Int
expr1[String] // (8+(-(1+2))): String
//==============================
//Part 2: Extending the language
//==============================
//Extending the language using the "initial" encoding requires:
//1. adding a new branch to the ADT
//2. updating every evaluation function
//which causes changes to, or at least recompilation of, all the dependent code.
//This is the "expression problem": with an ADT it is easy to add new operations on the data, but hard to add new data variants
/** Multiplication: a data variant added after the fact. The existing interpreters
  * (`eval`, `view`) do not match it and throw a `MatchError` until they are updated.
  */
final case class Mult(a: Expression, b: Expression) extends Expression
//Extending the language using the "final" encoding:
//1. add a new type class
//2. add new interpreters (instances) for it
//3. combine the new interpreters with the existing ones
/** DSL extension with multiplication, kept in its own type class so that programs and
  * interpreters written against [[ExprSymantics]] stay untouched.
  */
trait MultSYM[Repr] { def mul(a: Repr, b: Repr): Repr }
//example: (7 + (-(1 * 2)))
/** The program 7 + (-(1 * 2)); needs both the base and the multiplication interpreters. */
def expr2[Repr: ExprSymantics : MultSYM]: Repr = {
  val sym     = implicitly[ExprSymantics[Repr]]
  val mult    = implicitly[MultSYM[Repr]]
  val seven   = sym.lit(7)
  val product = mult.mul(sym.lit(1), sym.lit(2))
  sym.add(seven, sym.neg(product))
}
//We have to implement additional interpreters (Scala's SAM feature is used)
// Both instances are function literals converted to the single-method trait (SAM conversion)
implicit val multSymInt: MultSYM[Int] = _ * _
implicit val multSymString: MultSYM[String] = (lhs, rhs) => s"($lhs*$rhs)"
expr2[Int]    // 5
expr2[String] // (7+(-(1*2)))
//As you can see, extending the language
//using the "final" encoding doesn't require changes to existing code:
//the DSL becomes easily extensible,
//and extension mismatches (e.g. a missing interpreter instance) are caught by the type checker
//==============================
//Part 3: The de-serialization problem
//==============================
//Main statement:
//serialisation is simple (like above converting of the program to String)
//but deserialization is much harder
//Our target JSON-like format for serialisation:
/** Generic serialised form of a program: a minimal JSON-like tree. */
trait Tree

/** A leaf carrying a raw string payload (e.g. the digits of a literal). */
final case class Leaf(v: String) extends Tree

/** An internal node: a constructor name plus its children. */
final case class Node(v: String, ts: List[Tree]) extends Tree
//Serialisation:
//Serializer-interpreter:
/** Serialiser: interprets each term as a [[Tree]] node tagged with its constructor name. */
implicit val expSymTree: ExprSymantics[Tree] = new ExprSymantics[Tree] {
  def lit(v: Int): Tree = Node("Literal", Leaf(v.toString) :: Nil)
  def neg(a: Tree): Tree = Node("Negation", a :: Nil)
  def add(a: Tree, b: Tree): Tree = Node("Addition", a :: b :: Nil)
}
// 8 + (-(1 + 2))
val tree: Tree = expr1 // interpreter chosen from the expected type
//Node(Addition,List(
// Node(Literal,List(Leaf(8))),
// Node(Negation,List(
// Node(Addition,List(
// Node(Literal,List(Leaf(1))),
// Node(Literal,List(Leaf(2))))))))): Tree
//Deserialization:
//The input "Tree" structure may be invalid, so we need to handle errors
//We'll use the "Either" for this
type ErrMsg = String

/** Parses `s` with `toInt`, reporting a failure as a `Left` message instead of throwing. */
def safeReadInt(s: String): Either[ErrMsg, Int] =
  scala.util.Try(s.toInt).fold[Either[ErrMsg, Int]](
    _ => Left(s"Couldn't parse to Int: '$s'"),
    n => Right(n)
  )
//Here we want to convert a Tree directly into a result of a certain interpreter
//The 'A' type parameter is used for interpreter substitution
/** Deserialises a [[Tree]] directly into the result of the interpreter selected by `A`.
  *
  * @param tree serialised program using the node names "Literal", "Negation" and "Addition"
  * @return the interpreted program, or a `Left` for the first malformed literal or node
  *         encountered (sub-trees are visited left to right)
  */
def fromTree[A: ExprSymantics](tree: Tree): Either[ErrMsg, A] = {
  val E = implicitly[ExprSymantics[A]]
  import E._
  tree match {
    case Node("Literal", List(Leaf(n))) =>
      safeReadInt(n).map(lit)
    case Node("Negation", List(subTree)) =>
      // Explicit [A]: the receiver of `.map` has no expected type, so without it A would
      // have to be inferred through implicit search among every ExprSymantics in scope.
      fromTree[A](subTree).map(neg)
    case Node("Addition", List(leftSubTree, rightSubTree)) =>
      for {
        lt <- fromTree[A](leftSubTree)
        rt <- fromTree[A](rightSubTree)
      } yield add(lt, rt)
    case _ =>
      Left("Invalid tree")
  }
}
fromTree[Int](tree)    // Right(5): Either[ErrMsg, Int]
fromTree[String](tree) // Right((8+(-(1+2)))): Either[ErrMsg,String]
//This works, but the interpreter must be chosen at deserialisation time:
//we cannot deserialise once into a tagless-final value and interpret it in several ways later.
//Let's define a wrapper type for that
/** A deserialised program not yet committed to any interpreter:
  * `value[A]` interprets it with the `ExprSymantics` instance for `A`.
  */
trait Wrapped {
  def value[A](implicit interpreter: ExprSymantics[A]): A
}
//And write an interpreter from the DSL into Wrapped
/** Interpreter into [[Wrapped]]: each term defers the choice of interpreter and, whenever
  * `value[A]` is called, interprets itself and its sub-terms with `ExprSymantics[A]`.
  */
implicit val WrappedInterpreter: ExprSymantics[Wrapped] = new ExprSymantics[Wrapped] {
  def lit(v: Int): Wrapped = new Wrapped {
    def value[A](implicit sym: ExprSymantics[A]): A = sym.lit(v)
  }
  def neg(a: Wrapped): Wrapped = new Wrapped {
    def value[A](implicit sym: ExprSymantics[A]): A = sym.neg(a.value[A])
  }
  def add(a: Wrapped, b: Wrapped): Wrapped = new Wrapped {
    def value[A](implicit sym: ExprSymantics[A]): A = sym.add(a.value[A], b.value[A])
  }
}
// Deserialise once, then interpret the same Wrapped value with several interpreters.
fromTree[Wrapped](tree) match {
  case Left(err) => println(err)
  case Right(wrapped) =>
    // Print both results so neither is discarded; every branch of the match is Unit.
    println(wrapped.value[Int])    // 5
    println(wrapped.value[String]) // (8+(-(1+2)))
}
//Everything works, but we still lack extensibility:
//to add MultSYM we would have to rewrite and recompile the 'fromTree' function,
//adding a new 'case' clause.
//This problem can be solved using the open-recursion style.
//Let's rewrite the 'fromTree' function in that style
/** Open-recursive version of `fromTree`: handles the base DSL constructors but delegates
  * every recursive step to `recur` instead of calling itself, so an extended deserialiser
  * can be tied in later with `fix`.
  *
  * @param recur deserialiser used for sub-trees; by-name, so it is evaluated only when
  *              (and each time) a sub-tree is deserialised
  * @return a deserialiser for "Literal", "Negation" and "Addition" nodes; any other tree
  *         yields `Left("Invalid tree")`
  */
def fromTreeExt[A: ExprSymantics]
    (recur: => (Tree => Either[ErrMsg, A]))
    : Tree => Either[ErrMsg, A] = {
  val E = implicitly[ExprSymantics[A]]
  import E._
  {
    case Node("Literal", List(Leaf(n))) =>
      safeReadInt(n).map(lit)
    case Node("Negation", List(subTree)) =>
      recur(subTree).map(neg)
    case Node("Addition", List(leftSubTree, rightSubTree)) =>
      for {
        lt <- recur(leftSubTree)
        rt <- recur(rightSubTree)
      } yield add(lt, rt)
    case _ =>
      Left("Invalid tree")
  }
}
/** Fixed-point operator: ties the knot of an open-recursive definition `f`.
  * The argument is passed by-name, so `fix(f)` is unfolded again only when `f` uses it.
  */
def fix[A](f: (=> A) => A): A = {
  def unfoldOnce: A = fix(f)
  f(unfoldOnce)
}
/** Deserialiser for the base DSL, obtained by closing the open-recursive `fromTreeExt` with `fix`. */
def fromTree2[A: ExprSymantics](t: Tree): Either[ErrMsg, A] = {
  val deserialise: Tree => Either[ErrMsg, A] = fix(fromTreeExt[A] _)
  deserialise(t)
}
fromTree2[Int](tree)    // Right(5)
fromTree2[String](tree) // Right((8+(-(1+2))))
//Here we define new deserialisation logic without touching the previous code
/** Open-recursive deserialiser for the DSL extended with multiplication.
  * "Multiplication" nodes are handled here; every other tree falls back to `fromTreeExt`,
  * which is given the same `recur`, so nested sub-trees always go through the caller's knot.
  */
def fromTreeExt2[A: ExprSymantics: MultSYM]
    (recur: => (Tree => Either[ErrMsg, A]))
    : Tree => Either[ErrMsg, A] = {
  val mult = implicitly[MultSYM[A]]
  // Building the fallback does not force `recur`: fromTreeExt takes it by-name.
  val base: Tree => Either[ErrMsg, A] = fromTreeExt[A](recur)
  val multiplication: PartialFunction[Tree, Either[ErrMsg, A]] = {
    case Node("Multiplication", List(leftSubTree, rightSubTree)) =>
      for {
        lt <- recur(leftSubTree)
        rt <- recur(rightSubTree)
      } yield mult.mul(lt, rt)
  }
  t => multiplication.applyOrElse(t, base)
}
/** Deserialiser for the DSL extended with multiplication: `fromTreeExt2` closed with `fix`. */
def fromTree3[A: ExprSymantics: MultSYM](t: Tree): Either[ErrMsg, A] = {
  val deserialise: Tree => Either[ErrMsg, A] = fix(fromTreeExt2[A] _)
  deserialise(t)
}
// Serialiser for multiplication; emits the "Multiplication" node name (SAM conversion of a function literal).
implicit val multSymTree: MultSYM[Tree] = (a, b) => Node("Multiplication", a :: b :: Nil)
/** The program 10 * (2 + 3), using both the base DSL and the multiplication extension. */
def richProgram[A: ExprSymantics : MultSYM]: A = {
  val sym  = implicitly[ExprSymantics[A]]
  val mult = implicitly[MultSYM[A]]
  val ten  = sym.lit(10)
  mult.mul(ten, sym.add(sym.lit(2), sym.lit(3)))
}

// Serialise once, then deserialise into two different interpreters.
val treeWithMult = richProgram[Tree]
fromTree3[String](treeWithMult) // Right((10*(2+3)))
fromTree3[Int](treeWithMult)    // Right(50)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment