@OlivierBlanvillain
Last active January 25, 2024 15:59

Revisiting Tagless Final Interpreters

Tagless Final interpreters are an alternative to the traditional Algebraic Data Type (and generalized ADT) based implementation of the interpreter pattern. This document presents the Tagless Final approach with Scala, and shows how Dotty, with its recently added implicit functions, makes the approach even more appealing. All examples are direct translations of their Haskell versions presented in the Typed Tagless Final Interpreters: Lecture Notes (section 2).

The interpreter pattern has recently received a lot of attention in the Scala community. A lot of effort has been invested in trying to address the biggest shortcoming of ADT/GADT based solutions: extensibility. One can first look at cats' Inject typeclass for an implementation of Data Types à la Carte ideas. The Freek library provides facilities to combine more than two algebras using some fairly involved type-level machinery. The solution proposed by the Freer Monads, More Extensible Effects paper also focuses on extensibility, and inspired a number of small Scala libraries such as eff, emm and paperdoll. Tagless Final interpreters take a somewhat dual approach by having typeclasses at their very core instead of the more traditional ADTs/GADTs. They also come with the great advantage of being extensible out of the box, without any apparent pitfalls.

The lecture notes go into great detail in presenting and comparing both ADT/GADT and typeclass based approaches, referring to the former as initial and the latter as final. For the sake of conciseness, this document focuses mainly on final interpreters.

Introduction

We will be working with simple mathematical expressions similar to the ones typed into calculators. Our task will consist not only of computing these expressions, but also of serializing, deserializing, and simplifying them. As lazy engineers, it makes perfect sense for us to represent the domain using an (embedded) domain specific language. Among other things, this saves us from caring about all the possible incorrect representations of our domain: the host language compiler takes care of that for us. Our domain consists of integer literals with addition and negation. The encoding that first comes to mind probably looks like the following:

sealed trait IExp
final case class Lit(i: Int) extends IExp
final case class Neg(e: IExp) extends IExp
final case class Add(l: IExp, r: IExp) extends IExp

A mathematical expression such as 8 - (1 + 2) can be encoded as a value of type IExp:

val fe: IExp = Add(Lit(8), Neg(Add(Lit(1), Lit(2))))

Everything looks easy now, right? Interpreting fe as an integer uses a recursive function of type IExp => Int, serializing is done with an IExp => Json function, deserializing goes the other way around with Json => Option[IExp], and transformations are done with IExp => IExp functions.
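As a point of comparison, the initial interpreters can be sketched as plain recursive functions over IExp. This is a minimal sketch assuming Scala 2.12+ or Scala 3; the names eval and show are hypothetical (not from the article), and the Add fields are named l and r:

```scala
// Initial (ADT) encoding of the expression language
sealed trait IExp
final case class Lit(i: Int) extends IExp
final case class Neg(e: IExp) extends IExp
final case class Add(l: IExp, r: IExp) extends IExp

// Interpreter IExp => Int: one recursive pattern match per interpretation
def eval(e: IExp): Int = e match {
  case Lit(i)    => i
  case Neg(x)    => -eval(x)
  case Add(l, r) => eval(l) + eval(r)
}

// Interpreter IExp => String, using the same printing format as the article
def show(e: IExp): String = e match {
  case Lit(i)    => i.toString
  case Neg(x)    => s"(-${show(x)})"
  case Add(l, r) => s"(${show(l)} + ${show(r)})"
}

val fe: IExp = Add(Lit(8), Neg(Add(Lit(1), Lit(2))))
```

Each new interpretation is one more function of this shape, while each new constructor forces every such function to change.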

In the lecture notes' terminology, the IExp data type corresponds to the initial encoding of our domain. You can forget about it for now, as we will instead be using the final encoding:

trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}

How do we represent 8 - (1 + 2) with Exp? Something like the following:

def tf0[T](implicit e: Exp[T]): T =
  e.add(e.lit(8), e.neg(e.add(e.lit(1), e.lit(2))))

In Haskell (with the proper language extension) tf0 could be a polymorphic value. In Scala we use a function with a type parameter T and an implicit Exp[T] constraint. The syntax can be simplified, for example by getting rid of the e. prefixes with helper functions:

object ExpSyntax {
  def lit[T](i: Int)    (implicit e: Exp[T]): T = e.lit(i)
  def neg[T](t: T)      (implicit e: Exp[T]): T = e.neg(t)
  def add[T](l: T, r: T)(implicit e: Exp[T]): T = e.add(l, r)
}
import ExpSyntax._ // It's safe to always have these in scope

def tf1[T: Exp]: T =
  add(lit(8), neg(add(lit(1), lit(2))))

At this point you are probably wondering how to write interpreters for this tf1 thing. The answer is simple: by creating instances of Exp!

implicit val evalExp: Exp[Int] = new Exp[Int] {
  def lit(i: Int): Int = i
  def neg(t: Int): Int = -t
  def add(l: Int, r: Int): Int = l + r
}
implicit val printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}

Interpretation is done by specifying types. Let's look at tf1 as an Int:

scala> tf1[Int]
res0: Int = 5

What about tf1 as a String?

scala> tf1[String]
res1: String = (8 + (-(1 + 2)))
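Because an interpreter is selected purely by type, adding one means choosing a carrier type and writing an instance for it. As a hypothetical illustration (Size and sizeExp are our own names, not the article's), this sketch counts the nodes of an expression; the count gets its own wrapper type because an Exp[Int] instance already means evaluation:

```scala
// The final encoding and its syntax helpers, as defined above
trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}

object ExpSyntax {
  def lit[T](i: Int)(implicit e: Exp[T]): T = e.lit(i)
  def neg[T](t: T)(implicit e: Exp[T]): T = e.neg(t)
  def add[T](l: T, r: T)(implicit e: Exp[T]): T = e.add(l, r)
}
import ExpSyntax._

def tf1[T: Exp]: T = add(lit(8), neg(add(lit(1), lit(2))))

// Hypothetical carrier type: wrapping the count keeps this instance from
// clashing with an Exp[Int] evaluator such as evalExp
final case class Size(nodes: Int)

// A third interpreter that counts the nodes of an expression
implicit val sizeExp: Exp[Size] = new Exp[Size] {
  def lit(i: Int): Size = Size(1)
  def neg(t: Size): Size = Size(t.nodes + 1)
  def add(l: Size, r: Size): Size = Size(l.nodes + r.nodes + 1)
}
```

tf1[Size] then counts the six nodes of 8 + (-(1 + 2)): two additions, one negation and three literals.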

Extensibility

What if we decide to extend our mathematical expressions with multiplication? With the initial (ADT based) IExp encoding, we would be confronted with two inconvenient options: update the definition of the IExp data type and all the interpreters we wrote so far, or rely on some sort of Data Types à la Carte style lifting of IExp values and interpreters into coproducts. This is where the final tagless approach truly shines: it can naturally be extended without breaking (or even recompiling) any of the existing code. We introduce a new, completely independent type class for multiplication:

trait Mult[T] {
  def mul(l: T, r: T): T
}

object MultSyntax {
  def mul[T](l: T, r: T)(implicit e: Mult[T]): T = e.mul(l, r)
}
import MultSyntax._

Expressions using multiplication need an additional Mult constraint (an additional Mult[T] implicit argument, that is). Here is how we define tfm1 = 7 - 1 * 2 and tfm2 = 7 * (8 - (1 + 2)):

def tfm1[T: Exp : Mult]: T = add(lit(7), neg(mul(lit(1), lit(2))))
def tfm2[T: Exp : Mult]: T = mul(lit(7), tf1)

If you are not satisfied with writing : Exp : Mult everywhere, here is a spoiler for the end of this article: these constraints can be factored out in Dotty using an implicit function type.

To interpret these newly created expressions we need to provide Mult instances for Int and String:

implicit val evalMult: Mult[Int] = new Mult[Int] {
  def mul(l: Int, r: Int): Int = l * r
}

implicit val printMult: Mult[String] = new Mult[String] {
  def mul(l: String, r: String): String = s"$l * $r"
}

Without any additional wiring, Exp and Mult instances are automatically combined during interpretation:

scala> tfm1[String]
res2: String = (7 + (-1 * 2))
scala> tfm1[Int]
res3: Int = 5
scala> tfm2[String]
res4: String = 7 * (8 + (-(1 + 2)))
scala> tfm2[Int]
res5: Int = 35
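The following self-contained sketch gathers the snippets above into two separate objects, ExpModule and MultModule (hypothetical names), to make the extensibility claim concrete: the Mult extension only imports the original module and never edits it. It assumes nothing beyond the standard library:

```scala
// The original language, compiled on its own
object ExpModule {
  trait Exp[T] {
    def lit(i: Int): T
    def neg(t: T): T
    def add(l: T, r: T): T
  }

  def lit[T](i: Int)(implicit e: Exp[T]): T = e.lit(i)
  def neg[T](t: T)(implicit e: Exp[T]): T = e.neg(t)
  def add[T](l: T, r: T)(implicit e: Exp[T]): T = e.add(l, r)

  implicit val evalExp: Exp[Int] = new Exp[Int] {
    def lit(i: Int): Int = i
    def neg(t: Int): Int = -t
    def add(l: Int, r: Int): Int = l + r
  }

  implicit val printExp: Exp[String] = new Exp[String] {
    def lit(i: Int): String = i.toString
    def neg(t: String): String = s"(-$t)"
    def add(l: String, r: String): String = s"($l + $r)"
  }

  def tf1[T: Exp]: T = add(lit(8), neg(add(lit(1), lit(2))))
}

// The extension: it only imports ExpModule and never modifies it
object MultModule {
  import ExpModule._

  trait Mult[T] {
    def mul(l: T, r: T): T
  }

  def mul[T](l: T, r: T)(implicit m: Mult[T]): T = m.mul(l, r)

  implicit val evalMult: Mult[Int] = new Mult[Int] {
    def mul(l: Int, r: Int): Int = l * r
  }

  implicit val printMult: Mult[String] = new Mult[String] {
    def mul(l: String, r: String): String = s"$l * $r"
  }

  def tfm1[T: Exp : Mult]: T = add(lit(7), neg(mul(lit(1), lit(2))))
  def tfm2[T: Exp : Mult]: T = mul(lit(7), tf1[T])
}

import ExpModule._
import MultModule._
```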

Deserialization

Let's move on to the more complicated problem of deserialization. The target format is a Json-like tree structure defined as follows:

sealed trait Tree
final case class Leaf(s: String) extends Tree
final case class Node(s: String, ts: List[Tree]) extends Tree

Transforming expressions to this Json-like format is no more complicated than serializing to String. Depending on where Exp and Mult instances are defined, it's also possible to group them together:

implicit val toTree: Exp[Tree] with Mult[Tree] = new Exp[Tree] with Mult[Tree] {
  def lit(i: Int): Tree = Node("Lit", List(Leaf(i.toString)))
  def neg(t: Tree): Tree = Node("Neg", List(t))
  def add(l: Tree, r: Tree): Tree = Node("Add", List(l , r))
  def mul(l: Tree, r: Tree): Tree = Node("Mult", List(l , r))
}
scala> val tf1Tree = tf1[Tree]
tf1Tree: Tree = Node(Add,List(Node(Lit,List(Leaf(8))), Node(Neg,List(Node(Add,List(Node(Lit,List(Leaf(1))), Node(Lit,List(Leaf(2)))))))))

For deserialization, we need to write a fromTree function that converts a Json-like Tree to its final encoding. Given that our finally encoded values are functions of type [T] => Exp[T] => T (when this was written, Dotty used this syntax for the type lambda ({type L[T] = Exp[T] => T})#L; in Scala 3 the same notation denotes a polymorphic function type, and type lambdas are written [T] =>> Exp[T] => T), our first intuition might be to define fromTree as def fromTree[T](t: Tree)(implicit e: Exp[T]): Either[ErrMsg, T]:

type ErrMsg = String

def readInt(s: String): Either[ErrMsg, Int] = {
  import scala.util.{Try, Success, Failure}
  Try(s.toInt) match {
    case Success(i) => Right(i)
    case Failure(f) => Left(f.toString)
  }
}

def fromTree[T](t: Tree)(implicit e: Exp[T]): Either[ErrMsg, T] =
  t match {
    case Node("Lit", List(Leaf(n))) =>
      readInt(n).right.map(e.lit)

    case Node("Neg", List(t)) =>
      fromTree[T](t).right.map(e.neg)

    case Node("Add", List(l, r)) =>
      for(lt <- fromTree[T](l).right; rt <- fromTree[T](r).right)
      yield e.add(lt, rt)

    case _ => Left(s"Invalid tree $t")
  }
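The .right projections above are needed because Either is not right-biased in Scala 2.11. As a sketch assuming Scala 2.12 or later (or Scala 3), where Either has map and flatMap and Try has toEither, the same deserializer can drop them. The recursive calls pass fromTree[T] explicitly so they never rely on inferring T:

```scala
import scala.util.Try

sealed trait Tree
final case class Leaf(s: String) extends Tree
final case class Node(s: String, ts: List[Tree]) extends Tree

trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}

type ErrMsg = String

implicit val printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}

implicit val toTree: Exp[Tree] = new Exp[Tree] {
  def lit(i: Int): Tree = Node("Lit", List(Leaf(i.toString)))
  def neg(t: Tree): Tree = Node("Neg", List(t))
  def add(l: Tree, r: Tree): Tree = Node("Add", List(l, r))
}

// Try#toEither (Scala 2.12+) replaces the explicit Success/Failure match
def readInt(s: String): Either[ErrMsg, Int] =
  Try(s.toInt).toEither.left.map(_.toString)

// Same logic as fromTree above, using the right-biased Either directly
def fromTree[T](t: Tree)(implicit e: Exp[T]): Either[ErrMsg, T] = t match {
  case Node("Lit", List(Leaf(n))) => readInt(n).map(e.lit)
  case Node("Neg", List(sub))     => fromTree[T](sub).map(e.neg)
  case Node("Add", List(l, r))    =>
    for { lt <- fromTree[T](l); rt <- fromTree[T](r) } yield e.add(lt, rt)
  case _                          => Left(s"Invalid tree $t")
}

def tf1[T](implicit e: Exp[T]): T =
  e.add(e.lit(8), e.neg(e.add(e.lit(1), e.lit(2))))
```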

This would work, but because T and Exp[T] both need to be fully specified when calling fromTree, polymorphism is lost and the result of fromTree can only have a single interpretation. We can work around this issue by wrapping results using the following:

trait Wrapped {
  def value[T](implicit e: Exp[T]): T
}

The lecture notes then suggest that fromTree could be rewritten with a new signature, def fromTree(t: Tree): Either[ErrMsg, Wrapped], but I think they missed that we can actually achieve the same result by defining an Exp[Wrapped] instance and reusing our first fromTree implementation:

implicit val wrappingExp: Exp[Wrapped] = new Exp[Wrapped] {
  def lit(i: Int) = new Wrapped {
    def value[T](implicit e: Exp[T]): T = e.lit(i)
  }
  def neg(t: Wrapped) = new Wrapped {
    def value[T](implicit e: Exp[T]): T = e.neg(t.value)
  }
  def add(l: Wrapped, r: Wrapped) = new Wrapped {
    def value[T](implicit e: Exp[T]): T = e.add(l.value, r.value)
  }
}

This is enough to fake first-class polymorphism!

scala> fromTree[Wrapped](tf1Tree) match {
     |  case Left(err) =>
     |    println(err)
     |
     |  case Right(t) =>
     |    println(t.value[Int])
     |    println(t.value[String])
     |    println
     |}
5
(8 + (-(1 + 2)))
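A self-contained version of this trick, as a sketch assuming Scala 2.12+ or Scala 3 for the right-biased Either. Tree, Exp and Wrapped repeat the definitions above, fromTree is the single-interpretation deserializer, and tf1Tree is built with a toTree instance restricted to Exp. One call to fromTree[Wrapped] yields a value that is then read both as an Int and as a String:

```scala
import scala.util.Try

sealed trait Tree
final case class Leaf(s: String) extends Tree
final case class Node(s: String, ts: List[Tree]) extends Tree

trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}

implicit val evalExp: Exp[Int] = new Exp[Int] {
  def lit(i: Int): Int = i
  def neg(t: Int): Int = -t
  def add(l: Int, r: Int): Int = l + r
}
implicit val printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}
implicit val toTree: Exp[Tree] = new Exp[Tree] {
  def lit(i: Int): Tree = Node("Lit", List(Leaf(i.toString)))
  def neg(t: Tree): Tree = Node("Neg", List(t))
  def add(l: Tree, r: Tree): Tree = Node("Add", List(l, r))
}

// A value that can still be interpreted at any T that has an Exp[T] instance
trait Wrapped {
  def value[T](implicit e: Exp[T]): T
}

implicit val wrappingExp: Exp[Wrapped] = new Exp[Wrapped] {
  def lit(i: Int): Wrapped = new Wrapped {
    def value[T](implicit e: Exp[T]): T = e.lit(i)
  }
  def neg(t: Wrapped): Wrapped = new Wrapped {
    def value[T](implicit e: Exp[T]): T = e.neg(t.value[T])
  }
  def add(l: Wrapped, r: Wrapped): Wrapped = new Wrapped {
    def value[T](implicit e: Exp[T]): T = e.add(l.value[T], r.value[T])
  }
}

// T is fixed at each call site, so the result has a single interpretation
def fromTree[T](t: Tree)(implicit e: Exp[T]): Either[String, T] = t match {
  case Node("Lit", List(Leaf(n))) =>
    Try(n.toInt).toEither.left.map(_.toString).map(e.lit)
  case Node("Neg", List(sub))  => fromTree[T](sub).map(e.neg)
  case Node("Add", List(l, r)) =>
    for { lt <- fromTree[T](l); rt <- fromTree[T](r) } yield e.add(lt, rt)
  case _ => Left(s"Invalid tree $t")
}

def tf1[T](implicit e: Exp[T]): T =
  e.add(e.lit(8), e.neg(e.add(e.lit(1), e.lit(2))))

val tf1Tree: Tree = tf1[Tree]
```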

We have only solved half of the problem: our deserializer still lacks extensibility. In order to allow multiplication to be added after the fact, fromTree needs to be rewritten in the open recursion style. That's another one of these scary names for a very simple idea: all recursive calls of fromTree go through an additional recur parameter instead:

// Note that `recur` and `fromTree _` have the same type!
def fromTreeExt[T]
  (recur: => (Tree => Either[ErrMsg, T]))
  (implicit e: Exp[T])
  : Tree => Either[ErrMsg, T] = {
    tree => tree match {
      case Node("Lit", List(Leaf(n))) =>
        readInt(n).right.map(e.lit)

      case Node("Neg", List(t)) =>
        recur(t).right.map(e.neg)

      case Node("Add", List(l , r)) =>
        for(lt <- recur(l).right; rt <- recur(r).right)
        yield e.add(lt, rt)

      case t => Left(s"Invalid tree $t")
    }
  }

The recursive knot is then tied using a fixed-point operator:

def fix[A](f: (=> A) => A): A = f(fix(f))
def fromTree2[T: Exp](t: Tree): Either[ErrMsg, T] = fix(fromTreeExt[T] _)(t)
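To see the knot-tying in isolation, here is the same fix operator applied to a hypothetical open-recursive factorial (factExt and fact are illustrative names, not from the article). Because recur is a by-name parameter, fix(f) is only evaluated when a recursive call is actually made, which is what keeps the definition from looping forever:

```scala
// Fixed-point operator: the by-name argument delays `fix(f)` until `f` forces it
def fix[A](f: (=> A) => A): A = f(fix(f))

// Factorial in open recursion style: the recursive call goes through `recur`
def factExt(recur: => (Int => Long)): Int => Long =
  n => if (n <= 1) 1L else n * recur(n - 1)

// Tying the knot
val fact: Int => Long = fix[Int => Long](r => factExt(r))
```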

This way, it becomes possible to define the multiplication deserializer separately, and tie the knot a second time:

def fromTreeExt2[T]
  (recur: => (Tree => Either[ErrMsg, T]))
  (implicit e: Exp[T], m: Mult[T])
  : Tree => Either[ErrMsg, T] = {
    case Node("Mult", List(l , r)) =>
      for(lt <- recur(l).right; rt <- recur(r).right)
      yield m.mul(lt, rt)

    case t => fromTreeExt(recur).apply(t)
  }
def fromTree3[T: Exp : Mult](t: Tree): Either[ErrMsg, T] = fix(fromTreeExt2[T] _)(t)
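Below is a self-contained sketch of the open-recursion deserializers, assuming Scala 2.12+ or Scala 3 so that Either can be used without .right. It ties the knot through explicit lambdas (r => fromTreeExt[T](r)) rather than method values, and the separate treeExp and treeMult instances (hypothetical names) stand in for the combined toTree above:

```scala
import scala.util.Try

sealed trait Tree
final case class Leaf(s: String) extends Tree
final case class Node(s: String, ts: List[Tree]) extends Tree

trait Exp[T] { def lit(i: Int): T; def neg(t: T): T; def add(l: T, r: T): T }
trait Mult[T] { def mul(l: T, r: T): T }

type ErrMsg = String

implicit val printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}
implicit val printMult: Mult[String] = new Mult[String] {
  def mul(l: String, r: String): String = s"$l * $r"
}
implicit val treeExp: Exp[Tree] = new Exp[Tree] {
  def lit(i: Int): Tree = Node("Lit", List(Leaf(i.toString)))
  def neg(t: Tree): Tree = Node("Neg", List(t))
  def add(l: Tree, r: Tree): Tree = Node("Add", List(l, r))
}
implicit val treeMult: Mult[Tree] = new Mult[Tree] {
  def mul(l: Tree, r: Tree): Tree = Node("Mult", List(l, r))
}

def readInt(s: String): Either[ErrMsg, Int] = Try(s.toInt).toEither.left.map(_.toString)

def fix[A](f: (=> A) => A): A = f(fix(f))

// Open recursion: every recursive call goes through `recur`
def fromTreeExt[T](recur: => (Tree => Either[ErrMsg, T]))(implicit e: Exp[T])
    : Tree => Either[ErrMsg, T] = {
  case Node("Lit", List(Leaf(n))) => readInt(n).map(e.lit)
  case Node("Neg", List(t))       => recur(t).map(e.neg)
  case Node("Add", List(l, r))    => for { lt <- recur(l); rt <- recur(r) } yield e.add(lt, rt)
  case t                          => Left(s"Invalid tree $t")
}

// Mult extension; anything else is delegated to fromTreeExt with the same `recur`
def fromTreeExt2[T](recur: => (Tree => Either[ErrMsg, T]))(implicit e: Exp[T], m: Mult[T])
    : Tree => Either[ErrMsg, T] = {
  case Node("Mult", List(l, r)) => for { lt <- recur(l); rt <- recur(r) } yield m.mul(lt, rt)
  case t                        => fromTreeExt[T](recur).apply(t)
}

def fromTree2[T: Exp](t: Tree): Either[ErrMsg, T] =
  fix[Tree => Either[ErrMsg, T]](r => fromTreeExt[T](r))(t)

def fromTree3[T: Exp : Mult](t: Tree): Either[ErrMsg, T] =
  fix[Tree => Either[ErrMsg, T]](r => fromTreeExt2[T](r))(t)

def tfm2[T](implicit e: Exp[T], m: Mult[T]): T =
  m.mul(e.lit(7), e.add(e.lit(8), e.neg(e.add(e.lit(1), e.lit(2)))))
```

fromTree2 still rejects trees containing Mult, while fromTree3 accepts them; only the knot differs.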

We can now check that serialization and deserialization are inverse operations for any expression e, for instance with tf1 and tfm2:

assert(fromTree2[String](tf1[Tree]) == Right(tf1[String]))
assert(fromTree3[String](tfm2[Tree]) == Right(tfm2[String]))

Note that in Scala this implementation is not stack-safe. We would need to add trampolining to use this open recursion trick with large data structures.
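One way to add that trampolining, sketched here under the assumption that the standard library's scala.util.control.TailCalls is an acceptable trampoline (Scala 2.12+ or Scala 3), is to make every recursive call return a suspended TailRec computation. fromTreeExtT, fromTreeT and negations are hypothetical names, and only the Exp part of the deserializer is shown:

```scala
import scala.util.Try
import scala.util.control.TailCalls._

sealed trait Tree
final case class Leaf(s: String) extends Tree
final case class Node(s: String, ts: List[Tree]) extends Tree

trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}

type ErrMsg = String

implicit val evalExp: Exp[Int] = new Exp[Int] {
  def lit(i: Int): Int = i
  def neg(t: Int): Int = -t
  def add(l: Int, r: Int): Int = l + r
}

def readInt(s: String): Either[ErrMsg, Int] = Try(s.toInt).toEither.left.map(_.toString)

def fix[A](f: (=> A) => A): A = f(fix(f))

// Same open-recursive shape as fromTreeExt, but each recursive call is
// suspended with `tailcall`, so TailRec's trampoline runs it in a loop
def fromTreeExtT[T](recur: => (Tree => TailRec[Either[ErrMsg, T]]))(implicit e: Exp[T])
    : Tree => TailRec[Either[ErrMsg, T]] = {
  case Node("Lit", List(Leaf(n))) => done(readInt(n).map(e.lit))
  case Node("Neg", List(t))       => tailcall(recur(t)).map(_.map(e.neg))
  case Node("Add", List(l, r))    =>
    for {
      le <- tailcall(recur(l))
      re <- tailcall(recur(r))
    } yield for { lt <- le; rt <- re } yield e.add(lt, rt)
  // The message omits the tree: toString on a deeply nested case class recurses too
  case _                          => done(Left("Invalid tree"))
}

def fromTreeT[T: Exp](t: Tree): Either[ErrMsg, T] =
  fix[Tree => TailRec[Either[ErrMsg, T]]](r => fromTreeExtT[T](r))(t).result

// Hypothetical helper: n nested negations around the literal 1, built iteratively
def negations(n: Int): Tree =
  (1 to n).foldLeft(Node("Lit", List(Leaf("1"))): Tree)((acc, _) => Node("Neg", List(acc)))
```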

Transformation

We have seen how to transform our mathematical expressions into various other representations (interpretation) and how to build them from another representation (deserialization), but what about transforming a mathematical expression into another mathematical expression?

We consider the transformation that pushes negation all the way down to the literals, such that 8 - (1 + 2) becomes 8 + ((-1) + (-2)). The task sounds easy with the initial (ADT based) encoding; a simple IExp => IExp function would do the job:

def pushNeg(e: IExp): IExp = e match {
  case Lit(_) => e
  case Neg(Lit(_)) => e
  case Neg(Neg(n)) => pushNeg(n)
  case Neg(Add(l, r)) => Add(pushNeg(Neg(l)), pushNeg(Neg(r)))
  case Add(l, r) => Add(pushNeg(l), pushNeg(r))
}
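A quick runnable check of this ADT transformation, written so that the double-negation case recurses into n (returning n unchanged would leave a triple negation such as Neg(Neg(Neg(Add(...)))) only partly pushed down). The test expression fe is the one defined at the start of the article:

```scala
sealed trait IExp
final case class Lit(i: Int) extends IExp
final case class Neg(e: IExp) extends IExp
final case class Add(l: IExp, r: IExp) extends IExp

def pushNeg(e: IExp): IExp = e match {
  case Lit(_)         => e
  case Neg(Lit(_))    => e
  case Neg(Neg(n))    => pushNeg(n) // n may itself still contain negations
  case Neg(Add(l, r)) => Add(pushNeg(Neg(l)), pushNeg(Neg(r)))
  case Add(l, r)      => Add(pushNeg(l), pushNeg(r))
}

val fe: IExp = Add(Lit(8), Neg(Add(Lit(1), Lit(2))))
```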

This looks impossible with the final encoding. Pattern matching is very convenient here for context-dependent transformations; how can we achieve the same inside an instance of Exp? Instead of manipulating an Exp[T] like we did so far, the trick is to work with an Exp[Ctx => T] for an appropriate context type Ctx. In this case, the context is quite simple: all we need to know is whether or not the current node is being negated:

sealed trait NCtx
final case object PosCtx extends NCtx
final case object NegCtx extends NCtx

The transformation is expressed as Exp[NCtx => T]:

implicit def negDownExp[T](implicit e: Exp[T]): Exp[NCtx => T] = new Exp[NCtx => T] {
  def lit(i: Int): NCtx => T = {
    case PosCtx => e.lit(i)
    case NegCtx => e.neg(e.lit(i))
  }

  def neg(x: NCtx => T): NCtx => T = {
    case PosCtx => x(NegCtx)
    case NegCtx => x(PosCtx)
  }

  def add(l: NCtx => T, r: NCtx => T): NCtx => T =
    c => e.add(l(c), r(c))
}

To apply the transformation, one first interprets an expression as an NCtx => T function, then calls it with the initial context:

scala> tf1[NCtx => String].apply(PosCtx)
(8 + ((-1) + (-2)))

The initial context can also be factored out in a function:

scala> def pushNeg[T](e: NCtx => T): T = e(PosCtx)
pushNeg: [T](e: NCtx => T)T

scala> pushNeg(tf1[NCtx => String])
(8 + ((-1) + (-2)))

Unfortunately, scalac's type inference requires the inner type parameter in this case, which can become quite ugly when composing several transformations: pushNeg(pushNeg(pushNeg(tf1[NCtx => NCtx => NCtx => String]))). Improvements in Dotty's type inference make it possible to write pushNeg(pushNeg(pushNeg(tf1))): String, similarly to what you would do in Haskell. See Dotty and types: the story so far for an introduction to Dotty's type inference.

This transformation can naturally be extended to multiplication by defining the complementary Mult[NCtx => T] instance:

implicit def negDownMult[T](implicit e: Mult[T]): Mult[NCtx => T] = new Mult[NCtx => T] {
  def mul(l: NCtx => T, r: NCtx => T): NCtx => T = {
    case PosCtx => e.mul(l(PosCtx), r(PosCtx))
    case NegCtx => e.mul(l(PosCtx), r(NegCtx))
  }
}
scala> pushNeg(tfm1[NCtx => String])
(7 + 1 * (-2))

scala> pushNeg(tfm2[NCtx => String])
7 * (8 + ((-1) + (-2)))
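Putting the pieces of this section together, here is a self-contained sketch (assuming Scala 2.12+ or Scala 3) that checks the outputs shown above, plus a double application of pushNeg with the explicit nested type parameter that scalac needs. The definitions repeat the article's, with tf1 and tfm1 written against the instances directly:

```scala
trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}
trait Mult[T] {
  def mul(l: T, r: T): T
}

implicit val printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}
implicit val printMult: Mult[String] = new Mult[String] {
  def mul(l: String, r: String): String = s"$l * $r"
}

// Context: is the current node under an odd number of negations?
sealed trait NCtx
case object PosCtx extends NCtx
case object NegCtx extends NCtx

implicit def negDownExp[T](implicit e: Exp[T]): Exp[NCtx => T] = new Exp[NCtx => T] {
  def lit(i: Int): NCtx => T = {
    case PosCtx => e.lit(i)
    case NegCtx => e.neg(e.lit(i))
  }
  def neg(x: NCtx => T): NCtx => T = {
    case PosCtx => x(NegCtx)
    case NegCtx => x(PosCtx)
  }
  def add(l: NCtx => T, r: NCtx => T): NCtx => T =
    c => e.add(l(c), r(c))
}

implicit def negDownMult[T](implicit m: Mult[T]): Mult[NCtx => T] = new Mult[NCtx => T] {
  // -(a * b) is rewritten as a * (-b)
  def mul(l: NCtx => T, r: NCtx => T): NCtx => T = {
    case PosCtx => m.mul(l(PosCtx), r(PosCtx))
    case NegCtx => m.mul(l(PosCtx), r(NegCtx))
  }
}

def pushNeg[T](e: NCtx => T): T = e(PosCtx)

def tf1[T](implicit e: Exp[T]): T =
  e.add(e.lit(8), e.neg(e.add(e.lit(1), e.lit(2))))
def tfm1[T](implicit e: Exp[T], m: Mult[T]): T =
  e.add(e.lit(7), e.neg(m.mul(e.lit(1), e.lit(2))))
```

Applying pushNeg twice gives the same string as applying it once, since the second pass finds negations only on literals.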

The lecture notes continue with another example of transformation using a similar contextualization trick: the flattening of additions. The transformations we have seen so far were quite creative, and you might wonder whether everything that we can write using the final encoding can also be expressed with the initial encoding, and vice versa. The two representations are in fact equivalent, which can be shown by the existence of a bijection:

// Going from type class encoding to ADT encoding
implicit def initialize: Exp[IExp] = new Exp[IExp] {
  def lit(i: Int): IExp = Lit(i)
  def neg(t: IExp): IExp = Neg(t)
  def add(l: IExp, r: IExp): IExp = Add(l, r)
}

// Going from ADT encoding to type class encoding
def finalize[T](i: IExp)(implicit e: Exp[T]): T = i match {
  case Lit(l) => e.lit(l)
  case Neg(n) => e.neg(finalize[T](n))
  case Add(l, r) => e.add(finalize[T](l), finalize[T](r))
}
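The two conversions can be checked against each other with a short sketch (Scala 2.12+ or Scala 3). It repeats the definitions above, except that finalize is renamed finalizeExp, a hypothetical name, so that it does not overload the finalize method every object inherits from java.lang.Object:

```scala
sealed trait IExp
final case class Lit(i: Int) extends IExp
final case class Neg(e: IExp) extends IExp
final case class Add(l: IExp, r: IExp) extends IExp

trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}

implicit val printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}

// Final to initial: interpret the final encoding in the IExp type
implicit val initialize: Exp[IExp] = new Exp[IExp] {
  def lit(i: Int): IExp = Lit(i)
  def neg(t: IExp): IExp = Neg(t)
  def add(l: IExp, r: IExp): IExp = Add(l, r)
}

// Initial to final: fold the ADT with any Exp instance
def finalizeExp[T](i: IExp)(implicit e: Exp[T]): T = i match {
  case Lit(n)    => e.lit(n)
  case Neg(x)    => e.neg(finalizeExp[T](x))
  case Add(l, r) => e.add(finalizeExp[T](l), finalizeExp[T](r))
}

def tf1[T](implicit e: Exp[T]): T =
  e.add(e.lit(8), e.neg(e.add(e.lit(1), e.lit(2))))
```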

Grouping typeclass constraints with implicit function types

Implicit functions are a recent addition to the Dotty compiler. The idea is to extend the syntax available for functions to support functions with implicit arguments. As you probably know, Scala functions are defined (in simplified form) as follows:

trait Function1[A, B] {
  def apply(a: A): B
}

The compiler implements syntactic sugar to transform A => B types into Function1[A, B] and lets users define function values concisely. Implicit functions are analogous: implicit A => B becomes a valid type, desugared into ImplicitFunction1[A, B]:

trait ImplicitFunction1[A, B] {
  def apply(implicit a: A): B
}

A method whose return type is an implicit function type benefits from additional desugaring that automatically puts the implicits in scope:

def f: implicit Ctx => Unit = ???

Desugars into:

def f: implicit Ctx => Unit = { implicit $e1: Ctx => ???: Unit }

This syntactic sugar might not look very useful in this simple example, but with a type alias it already becomes more interesting:

type Contextualized[T] = implicit Ctx => T

def f: Contextualized[Unit] = ???

With more than one implicit involved, implicit function types enable something that was not quite possible before: abstracting over implicit parameters.

type Constrained[T] = implicit (TC1[T], TC2[T], TC3[T]) => T

def f: Constrained[Int] = ???

Desugars into:

def f: Constrained[Int] = { ($e1: TC1[Int], $e2: TC2[Int], $e3: TC3[Int]) =>
  implicit val $e4: TC1[Int] = $e1
  implicit val $e5: TC2[Int] = $e2
  implicit val $e6: TC3[Int] = $e3
  ???: Int
}

Getting back to our final encoding of mathematical expressions, Dotty's implicit function types allow the encoding to be extended with minimal syntactic overhead:

type Ring[T] = implicit (Exp[T], Mult[T]) => T

def tfm1[T]: Ring[T] = add(lit(7), neg(mul(lit(1), lit(2))))
def tfm2[T]: Ring[T] = mul(lit(7), tf1)
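The implicit (...) => T syntax above is the Dotty syntax of that time. In the released Scala 3, implicit function types became context function types, written with ?=>, and implicit parameters are usually declared with using. A sketch of the same Ring alias under that assumption (it requires Scala 3 and is not the article's code):

```scala
// Requires Scala 3: context function types are written with `?=>`
trait Exp[T] {
  def lit(i: Int): T
  def neg(t: T): T
  def add(l: T, r: T): T
}
trait Mult[T] {
  def mul(l: T, r: T): T
}

def lit[T](i: Int)(using e: Exp[T]): T = e.lit(i)
def neg[T](t: T)(using e: Exp[T]): T = e.neg(t)
def add[T](l: T, r: T)(using e: Exp[T]): T = e.add(l, r)
def mul[T](l: T, r: T)(using m: Mult[T]): T = m.mul(l, r)

given evalExp: Exp[Int] = new Exp[Int] {
  def lit(i: Int): Int = i
  def neg(t: Int): Int = -t
  def add(l: Int, r: Int): Int = l + r
}
given evalMult: Mult[Int] = new Mult[Int] {
  def mul(l: Int, r: Int): Int = l * r
}
given printExp: Exp[String] = new Exp[String] {
  def lit(i: Int): String = i.toString
  def neg(t: String): String = s"(-$t)"
  def add(l: String, r: String): String = s"($l + $r)"
}
given printMult: Mult[String] = new Mult[String] {
  def mul(l: String, r: String): String = s"$l * $r"
}

// Both constraints are abstracted behind a single alias
type Ring[T] = (Exp[T], Mult[T]) ?=> T

def tf1[T]: Exp[T] ?=> T = add(lit(8), neg(add(lit(1), lit(2))))
def tfm1[T]: Ring[T] = add(lit(7), neg(mul(lit(1), lit(2))))
// Explicit type arguments make the expected type T, so tf1 is applied here
def tfm2[T]: Ring[T] = mul[T](lit(7), tf1[T])
```

Ascribing a result type, as in (tfm2[Int]: Int), asks the compiler to supply the Exp[Int] and Mult[Int] givens at that point.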

This concludes our revisiting of Tagless Final interpreters (see here for all the (Scala 2.11) snippets of this article)!

If you want to know more about Tagless Final interpreters, I strongly encourage you to continue with sections 3 and 4 of the lecture notes, which cover encodings of a simply typed lambda calculus.

@bblfish

bblfish commented Sep 15, 2019

Your example presents a final algebra.
Does this also work for coalgebras?
A widely used example for Tagless Final in the Scala community is

trait Console[T[_]] { 
   def readLn: T[String]  
   def writeLn(l: String): T[Unit]
}

The above is a coalgebra, as the results rely on a hidden state (the state of the world).
They are of the form X => F[X] rather than the algebraic F[X] => X.

One finds many other coalgebraic examples of tagless final in the Scala community.
Yet the paper on which this is based only speaks of algebras. Is this a problem?

Looking around I found a 2019 article Codata in Action which shows how one can map algebras to final coalgebras (using the visitor pattern!) and how one can map those back. It states that codata is much more prevalent in the OO community.
Is final tagless on coalgebras perhaps just codata?

I asked a couple of questions on this topic on StackExchange:
