Skip to content

Instantly share code, notes, and snippets.

@polytypic
Last active December 8, 2020 18:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save polytypic/fd85880ff1081acac308084b619a3178 to your computer and use it in GitHub Desktop.
Save polytypic/fd85880ff1081acac308084b619a3178 to your computer and use it in GitHub Desktop.
Curious case of GADTs in Scala
// Below is a simple attempt at using GADTs in Scala, based on the Idris
// higher-order abstract syntax (HOAS) example at
//
// https://github.com/palladin/idris-snippets/blob/master/src/HOAS.idr
object Hoas {
  // Case classes let one write down, quite directly, what reads like a GADT:
  // every constructor chooses the type index of the `Expr` it produces.
  sealed trait Expr[A]

  // A host-language value embedded in an expression.
  final case class Val[A](value: A) extends Expr[A]

  // A binary host-language function applied to two sub-expressions.
  final case class Bin[A, B, C](
      bin: (A, B) => C,
      lhs: Expr[A],
      rhs: Expr[B]
  ) extends Expr[C]

  // A conditional; both branches share the result type.
  final case class If[A](
      condition: Expr[Boolean],
      onTrue: Expr[A],
      onFalse: Expr[A]
  ) extends Expr[A]

  // Application of a function-valued expression to an argument expression.
  final case class App[A, B](
      function: Expr[A => B],
      argument: Expr[A]
  ) extends Expr[B]

  // A lambda in higher-order abstract syntax: the binder is a Scala function.
  final case class Lam[A, B](lambda: Expr[A] => Expr[B]) extends Expr[A => B]

  // A fixed point, used to express recursive functions.
  final case class Fix[A, B](fix: Expr[(A => B) => A => B])
      extends Expr[A => B]

  // The factorial function written as an expression:
  //   fix (\self -> \n -> if n == 0 then 1 else n * self (n - 1))
  val fact: Expr[Int => Int] = Fix(
    Lam((self: Expr[Int => Int]) =>
      Lam((n: Expr[Int]) =>
        If(
          Bin((a: Int, b: Int) => a == b, n, Val(0)),
          Val(1),
          Bin(
            (a: Int, b: Int) => a * b,
            n,
            App(self, Bin((a: Int, b: Int) => a - b, n, Val(1)))
          )
        )
      )
    )
  )

  // On the surface this `eval` function looks fine as well:
  def eval[T](expr: Expr[T]): T = expr match {
    case Bin(f, x, y) => f(eval(x), eval(y))
    case If(test, whenTrue, whenFalse) =>
      eval(if (eval(test)) whenTrue else whenFalse)
    case App(fn, arg) => eval(fn)(eval(arg))
    // Wrap the run-time argument back into an expression with `Val`.
    case Lam(body) => arg => eval(body(Val(arg)))
    case Fix(step) =>
      val unrolled = eval(step)
      // Tie the knot: hand the unrolled step a reference to the loop itself.
      def loop(arg: Any): Any = unrolled(loop(_))(arg)
      loop(_)
    case Val(x) => x
  }

  // And it even computes the expected result:
  val one_hundred_twenty: Int = eval(App(fact, Val(5)))

  // What is interesting, though, is that the types inside `eval` are not
  // what one might expect. In the `Bin` case, for instance, `f` has type
  // `(Any, Any) => T` while `x` and `y` both have type `Any`. Consequently
  // the expression `f(eval(x), eval(y))` is not typed precisely: if a bug
  // were introduced by swapping the arguments, as in
  // `f(eval(y), eval(x))`, the code would still pass the type checker.
  // Languages with proper support for GADTs reject that.
  //
  // First of all, I think that this is bad. Code written in the
  // straightforward way simply does not come with the guarantees one
  // expects.
  //
  // Second of all, how should this example be implemented in Scala 2 in a
  // type safe manner? (Asking, because I don't yet know how.)
}
// This is my second attempt to encode GADTs in Scala. The example is based on
//
// https://github.com/palladin/idris-snippets/blob/master/src/HOAS.idr
//
// and the encoding technique is inspired by the paper
//
// GADTs for the OCaml Masses
// http://homepage.cs.uiowa.edu/~astump/papers/icfp09.pdf
//
// and is basically a Scott encoding.
object HoasScott {
  // The identity type constructor. It lets a destructor return a plain `T`
  // where the interface asks for a `Result[T]`; the evaluator below uses it.
  type Id[T] = T

  // The `ExprDestructor` trait defines the cases of an expression, one
  // method per constructor. The parameter `Result` is the type constructor
  // for the result of destructuring an expression; each method fixes the
  // type index of its result just as the corresponding GADT constructor
  // would.
  trait ExprDestructor[Result[_]] {
    def Val[A](value: A): Result[A]
    def Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B]): Result[C]
    def If[A](
        condition: Expr[Boolean],
        onTrue: Expr[A],
        onFalse: Expr[A]
    ): Result[A]
    def App[A, B](function: Expr[A => B], argument: Expr[A]): Result[B]
    def Lam[A, B](lambda: Expr[A] => Expr[B]): Result[A => B]
    def Fix[A, B](fix: Expr[(A => B) => A => B]): Result[A => B]
  }

  // The `Expr` trait is the actual type of expressions. Calling an
  // expression with a handler destructures it: the expression invokes the
  // one handler method that corresponds to the way it was constructed.
  trait Expr[T] {
    def apply[Result[_]](handler: ExprDestructor[Result]): Result[T]
  }

  // The `Expr` object implements `ExprDestructor` for the `Expr` trait
  // itself, which makes its methods the constructors of `Expr` values.
  // Result types are spelled out so that a mistake in a constructor is
  // reported at its definition rather than at some distant use site.
  object Expr extends ExprDestructor[Expr] {
    def Val[A](value: A): Expr[A] = new Expr[A] {
      def apply[Result[_]](handler: ExprDestructor[Result]): Result[A] =
        handler.Val(value)
    }

    def Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B]): Expr[C] =
      new Expr[C] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[C] =
          handler.Bin(bin, lhs, rhs)
      }

    def If[A](
        condition: Expr[Boolean],
        onTrue: Expr[A],
        onFalse: Expr[A]
    ): Expr[A] = new Expr[A] {
      def apply[Result[_]](handler: ExprDestructor[Result]): Result[A] =
        handler.If(condition, onTrue, onFalse)
    }

    def App[A, B](function: Expr[A => B], argument: Expr[A]): Expr[B] =
      new Expr[B] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[B] =
          handler.App(function, argument)
      }

    def Lam[A, B](lambda: Expr[A] => Expr[B]): Expr[A => B] =
      new Expr[A => B] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[A => B] =
          handler.Lam(lambda)
      }

    def Fix[A, B](fix: Expr[(A => B) => A => B]): Expr[A => B] =
      new Expr[A => B] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[A => B] =
          handler.Fix(fix)
      }
  }

  import Expr._

  // The factorial function expression.
  val fact: Expr[Int => Int] = Fix(
    Lam((f: Expr[Int => Int]) =>
      Lam((x: Expr[Int]) =>
        If(
          Bin((_: Int) == (_: Int), x, Val(0)),
          Val(1),
          Bin(
            (_: Int) * (_: Int),
            x,
            App(f, Bin((_: Int) - (_: Int), x, Val(1)))
          )
        )
      )
    )
  )

  // The evaluation handler. With `Result = Id` every case returns the value
  // of the expression directly. The handler captures no state, so a single
  // shared instance serves every (recursive) call of `eval` instead of a
  // fresh object being allocated per call.
  private object Evaluator extends ExprDestructor[Id] {
    def Val[A](value: A): Id[A] = value

    def Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B]): Id[C] =
      bin(eval(lhs), eval(rhs))

    // Only the branch selected by the condition is evaluated.
    def If[A](
        condition: Expr[Boolean],
        onTrue: Expr[A],
        onFalse: Expr[A]
    ): Id[A] = eval(if (eval(condition)) onTrue else onFalse)

    def App[A, B](function: Expr[A => B], argument: Expr[A]): Id[B] =
      eval(function)(eval(argument))

    // `Expr.Val` is qualified on purpose: inside this object the bare name
    // `Val` refers to the evaluator's own case, not to the constructor.
    def Lam[A, B](lambda: Expr[A] => Expr[B]): Id[A => B] =
      (argument: A) => eval(lambda(Expr.Val(argument)))

    def Fix[A, B](fix: Expr[(A => B) => A => B]): Id[A => B] = {
      val fn = eval(fix)
      // Tie the knot: pass the recursive function to its own definition.
      def rec(x: A): B = fn(rec(_))(x)
      rec(_)
    }
  }

  // Typed interpreter for typed expressions: destructures `expr` with the
  // shared evaluation handler and returns the resulting value of type `T`.
  def eval[T](expr: Expr[T]): T = expr.apply[Id](Evaluator)

  // This evaluates to 120.
  val one_hundred_twenty: Int = eval(App(fact, Val(5)))

  // While this encoding basically works, it has several downsides:
  // - Some amount of boilerplate is required to define a GADT this way.
  // - Destructuring does not directly support nesting (or other nice pattern
  //   matching features).
  // - Impossible cases are not filtered out.
  // - The encoding requires constructing an object per constructor.
  // - Destructuring requires a handler object. A stateless handler such as
  //   the evaluator above can be shared, but a handler that captures local
  //   state has to be constructed per call.
  //
  // Are there better ways to encode GADTs in Scala?
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment