@dylech30th
Last active July 12, 2023 20:48
An implementation of a unification-based type inference algorithm for the pure simply typed lambda calculus with let-polymorphism
/**
 * An implementation of a unification-based type inference algorithm for the simply typed lambda calculus with
 * let-polymorphism
 *
 * Unification-based type inference is used in a wide variety of programming languages. The most famous instance is the
 * Hindley-Milner type system (a.k.a. the Damas-Milner type system) of the ML family, which lets the programmer omit almost
 * all type annotations. The algorithm is built on two concepts: constraint sets and unifiers.
 *
 * A constraint set consists of several constraints. A constraint is a type equation such as X = T, where both X and T are
 * types.
 * A unifier is a set of type substitutions [X ↦ T1, Y ↦ T2, ...]. Applied to a type T, it replaces every type variable in
 * its domain with the corresponding type in its range, yielding [X ↦ T1, Y ↦ T2, ...]T. A unifier is said to unify a
 * constraint if applying it to both sides makes them identical; in that case it is also called a solution of the equation.
 * A unifier unifies a constraint set if it unifies every equation in it.
 *
 * The basic idea is as follows. When we see a term such as λx.x 0, we know that x must have type X -> Y for some X and Y;
 * since x is applied to the natural number 0, it must also have type Nat -> T for some T, where Nat is the type of natural
 * numbers. Hence the constraint X -> Y = Nat -> T is generated, and the substitution [X ↦ Nat, Y ↦ T] unifies it. Our task
 * is to calculate both the constraint set and the unifier for a lambda term. Computing the constraint set is simple: it is
 * syntax-directed, following the constraint typing relation (see Types and Programming Languages p.322). The only subtlety
 * is the treatment of let-polymorphism (see Types and Programming Languages p.331), where we choose the more efficient
 * implementation over the naive approach, which has exponential time complexity.
 *
 * The unification algorithm that calculates the unifier for a given constraint set follows the naive approach described in
 * Types and Programming Languages p.327. One point worth mentioning is that we prefer to substitute actual types for type
 * placeholders (Type.Forall represents the placeholder types generated while calculating the constraint set). The reason
 * becomes clear if we take the above example again: for the equation X -> Y = Nat -> T, [Nat ↦ X, T ↦ Y] is also a
 * solution, but it is meaningless because it generalizes every type, and mapping Nat to X breaks soundness, since 0 has
 * type Nat rather than X. Some implementations build this rule into the application of a unifier, so that applying a
 * unifier to Nat always yields Nat regardless of any entry of the form Nat ↦ X; we do not take that approach.
 */
package ink.sora

import scala.collection.mutable

// A type class for equality on values of type T
trait Eq[T]:
  def ==(left: T, right: T): Boolean
end Eq
enum Type:
  case Base(n: String) extends Type
  case Constructor(from: Type, to: Type) extends Type
  case Scheme(quantifiers: Set[Type], ty: Type) extends Type
  case Forall(n: String) extends Type

  override def toString: String =
    import Type.*
    this match
      case Base(name) => name
      case Constructor(from, to) => s"($from->$to)"
      case Scheme(quantifiers, ty) => s"∀${quantifiers.foldLeft("")((acc, base) => acc + base.name)}.($ty)"
      case Forall(name) => s"∀$name"
  end toString

  def name: String =
    import Type.*
    this match
      case Base(name) => name
      case Forall(name) => name
      case _ => throw IllegalStateException()
  end name
end Type
object TypePrimitives:
  val Nat: Type = Type.Base("Nat")
  val Bool: Type = Type.Base("Bool")
end TypePrimitives
enum SyntaxNode:
  case Var(name: String)
  case Abs(binder: String, binderType: Option[Type], body: SyntaxNode)
  case App(function: SyntaxNode, applicant: SyntaxNode)
  case Successor(nat: SyntaxNode)
  case Predecessor(nat: SyntaxNode)
  case IsZero(nat: SyntaxNode)
  case If(condition: SyntaxNode, trueClause: SyntaxNode, falseClause: SyntaxNode)
  case Zero, True, False
  case Let(left: String, right: SyntaxNode, body: SyntaxNode)
  case Seq(first: SyntaxNode, second: SyntaxNode)
end SyntaxNode
enum TypeSubstitution:
  case Plain(from: Type, to: Type)
  case Composition(outer: TypeSubstitution, inner: TypeSubstitution)
  case Trivial

  override def toString: String =
    import TypeSubstitution.*
    this match
      case Plain(from, to) => s"$from/$to"
      case Composition(outer, inner) => if outer.toString.nonEmpty then s"${outer.toString}(${inner.toString})" else inner.toString
      case Trivial => ""
  end toString
end TypeSubstitution
opaque type Constraint = (Type, Type)
opaque type ConstraintSet = (Type, Set[String], Set[Constraint])
opaque type Bindings = mutable.Map[String, Type]
given Eq[Type] with
  override def ==(left: Type, right: Type): Boolean = (left, right) match
    case (Type.Base(leftName), Type.Base(rightName)) => leftName == rightName
    case (Type.Constructor(leftFrom, leftTo), Type.Constructor(rightFrom, rightTo)) => ==(leftFrom, rightFrom) && ==(leftTo, rightTo)
    case (Type.Forall(leftName), Type.Forall(rightName)) => leftName == rightName
    case _ => false
end given
object FreshNameProvider:
  private var currentChar: Char = 'A'
  private var currentIndex: Int = 0

  // Generates A, B, ..., Z, then A1, B1, ..., Z1, then A2, and so on
  def freshName(): String =
    val str = s"$currentChar${if (currentIndex != 0) currentIndex else ""}"
    currentChar = currentChar match
      case x if ('A' until 'Z').contains(x) => (currentChar + 1).toChar
      case _ => currentIndex += 1; 'A'
    return str

  // Generates a fresh name that does not occur in `excludes`
  def freshName(excludes: Set[String]): String =
    val name = freshName()
    return if !excludes.contains(name) then name else freshName(excludes)
end FreshNameProvider
// Collects every Base and Forall type that occurs in the given constraint set
def constraintVariables(constraintSet: Set[(Type, Type)]): Set[Type] =
  return constraintSet.flatMap {
    case (left, right) => typeVariables(left) | typeVariables(right)
  }
end constraintVariables
// Collects every Base and Forall type that occurs in a type; schemes are looked through
def typeVariables(decl: Type): Set[Type] =
  return decl match
    case base: Type.Base => Set(base)
    case Type.Constructor(from, to) => typeVariables(from) | typeVariables(to)
    case Type.Scheme(_, ty) => typeVariables(ty)
    case forall: Type.Forall => Set(forall)
end typeVariables
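// Retrieves the variables that occur free in a term
def freeVariables(root: SyntaxNode): Set[String] =
  import SyntaxNode.*
  val symbols = mutable.Stack[String]()
  // Collects the free variables of `body` while `binder` is temporarily marked as bound
  def bound(binder: String, body: SyntaxNode): Set[String] =
    symbols.push(binder)
    val fvs = helper(body)
    symbols.pop()
    fvs
  def helper(r: SyntaxNode): Set[String] =
    return r match
      case Var(name) => if !symbols.contains(name) then Set(name) else Set.empty
      case Abs(binder, _, body) => bound(binder, body)
      case App(function, applicant) => helper(function) | helper(applicant)
      case Successor(nat) => helper(nat)
      case Predecessor(nat) => helper(nat)
      case IsZero(nat) => helper(nat)
      case If(condition, ifTrue, ifFalse) => helper(condition) | helper(ifTrue) | helper(ifFalse)
      case Let(left, right, body) => helper(right) | bound(left, body)
      case Seq(first, second) => helper(first) | helper(second)
      case Zero | True | False => Set.empty
  end helper
  return helper(root)
end freeVariables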
// Calculates the constraint set of a term according to the Constraint Typing Relation
def constraintsOf(root: SyntaxNode, symbols: Bindings = mutable.Map.empty)(using Eq[Type]): ConstraintSet =
  import SyntaxNode.*
  // succ, pred and iszero all require their operand to be a Nat
  def natOperand(nat: SyntaxNode, resultType: Type): ConstraintSet =
    val (ty, names, constraints) = constraintsOf(nat, symbols)
    (resultType, names, constraints + ((ty, TypePrimitives.Nat)))
  return root match
    case Var(name) =>
      symbols(name) match
        case Type.Scheme(quantifiers, ty) =>
          // instantiate every quantified variable with a fresh placeholder
          val newTypes: Iterable[Type.Forall] = (0 until quantifiers.size).map(_ => Type.Forall(FreshNameProvider.freshName()))
          val sType = quantifiers.zip(newTypes).foldLeft(ty) {
            case (acc, (quantifier, freshType)) => substitute(acc, quantifier, freshType)
          }
          (sType, newTypes.map(_.name).toSet, Set.empty)
        case _ => (symbols(name), Set.empty, Set.empty)
    case Abs(binder, binderType, body) =>
      if symbols.contains(binder) then
        throw IllegalArgumentException(s"Duplicate symbol found: $binder")
      else
        binderType match
          case Some(bTy) =>
            symbols(binder) = bTy
            val (ty, names, constraints) = constraintsOf(body, symbols)
            symbols -= binder
            (Type.Constructor(bTy, ty), names, constraints)
          case None =>
            val freshName = FreshNameProvider.freshName()
            val assumption = Type.Forall(freshName)
            symbols(binder) = assumption
            val (ty, names, constraints) = constraintsOf(body, symbols)
            symbols -= binder
            (Type.Constructor(assumption, ty), names + freshName, constraints)
    case App(function, applicant) =>
      val (ty1, names1, constraints1) = constraintsOf(function, symbols)
      val (ty2, names2, constraints2) = constraintsOf(applicant, symbols)
      if (names1 & names2).nonEmpty || (names1 & typeVariables(ty2).map(_.name)).nonEmpty || (names2 & typeVariables(ty1).map(_.name)).nonEmpty then
        throw IllegalArgumentException("The names are not distinct in two operands of App")
      else
        val freshName = FreshNameProvider.freshName(names1 | names2 | typeVariables(ty1).map(_.name) | typeVariables(ty2).map(_.name) | constraintVariables(constraints1).map(_.name) | constraintVariables(constraints2).map(_.name) | symbols.keySet | freeVariables(function) | freeVariables(applicant))
        val assumption = Type.Forall(freshName)
        (assumption, (names1 | names2) + freshName, (constraints1 | constraints2) + ((ty1, Type.Constructor(ty2, assumption))))
    case Zero => (TypePrimitives.Nat, Set.empty, Set.empty)
    case Successor(nat) => natOperand(nat, TypePrimitives.Nat)
    case Predecessor(nat) => natOperand(nat, TypePrimitives.Nat)
    case IsZero(nat) => natOperand(nat, TypePrimitives.Bool)
    case True | False => (TypePrimitives.Bool, Set.empty, Set.empty)
    case If(condition, trueClause, falseClause) =>
      val (ty1, names1, constraints1) = constraintsOf(condition, symbols)
      val (ty2, names2, constraints2) = constraintsOf(trueClause, symbols)
      val (ty3, names3, constraints3) = constraintsOf(falseClause, symbols)
      // the fresh names of the three operands must be pairwise disjoint
      if (names1 & names2).nonEmpty || (names1 & names3).nonEmpty || (names2 & names3).nonEmpty then
        throw IllegalArgumentException("The names are not distinct between operands of If")
      (ty2, names1 | names2 | names3, constraints1 | constraints2 | constraints3 | Set((ty1, TypePrimitives.Bool), (ty2, ty3)))
    case Let(left, right, body) =>
      val (ty, names, constraints) = constraintsOf(right, symbols)
      val unifier = unify(constraints)
      val principalOfRight = substitute(unifier, ty)
      // variables mentioned in the (solved) context are tied to the environment and must not be generalized
      val environmentVars = symbols.values.flatMap {
        case Type.Scheme(quantifiers, t) => typeVariables(substitute(unifier, t)) -- quantifiers
        case t => typeVariables(substitute(unifier, t))
      }.toSet
      val remainingVars = typeVariables(principalOfRight).filter(v => v.isInstanceOf[Type.Forall] && !environmentVars.contains(v))
      val generalizedType = if remainingVars.nonEmpty then Type.Scheme(remainingVars, principalOfRight) else principalOfRight
      symbols(left) = generalizedType
      val (ty1, names1, constraints1) = constraintsOf(body, symbols)
      symbols -= left
      // keep the constraints of the right-hand side, since they may still constrain variables of the context
      (ty1, names | names1, constraints | constraints1)
    case Seq(first, second) =>
      // the first term must be well-typed, so its constraints are kept even though its type is discarded
      val (_, names1, constraints1) = constraintsOf(first, symbols)
      val (ty2, names2, constraints2) = constraintsOf(second, symbols)
      (ty2, names1 | names2, constraints1 | constraints2)
end constraintsOf
// Replaces every occurrence of `from` in `ty` with `to`
def substitute(ty: Type, from: Type, to: Type)(using Eq[Type]): Type =
  return ty match
    case Type.Constructor(fr, t) => Type.Constructor(substitute(fr, from, to), substitute(t, from, to))
    case t => if t == from then to else t
end substitute
// Replaces every occurrence of `from` in `constraints` with `to`
def substitute(constraints: Set[Constraint], from: Type, to: Type): Set[Constraint] =
  return constraints.map { case (fr, t) => (substitute(fr, from, to), substitute(t, from, to)) }
end substitute
// Applies the given TypeSubstitution to a type, innermost substitution first
def substitute(substitution: TypeSubstitution, ty: Type): Type =
  import TypeSubstitution.*
  return substitution match
    case Plain(from, to) => substitute(ty, from, to)
    case Composition(outer, inner) => substitute(outer, substitute(inner, ty))
    case Trivial => ty
end substitute
// Calculates the unifier for a given constraint set
def unify(constraints: Set[Constraint])(using Eq[Type]): TypeSubstitution =
  import TypeSubstitution.*
  // primitive types are never treated as variables, so e.g. Nat = Bool is rejected
  val primitives: Set[Type] = Set(TypePrimitives.Nat, TypePrimitives.Bool)
  def helper(cs: List[Constraint]): TypeSubstitution =
    return cs match
      case (left, right) :: tail =>
        (left, right) match
          case (left, right) if left == right =>
            helper(tail)
          // we prefer to substitute the real type for placeholders; the occurs check rejects X = X -> Y
          case (left: Type.Forall, right) if !typeVariables(right).contains(left) =>
            Composition(helper(substitute(tail.toSet, left, right).toList), Plain(left, right))
          case (left, right: Type.Forall) if !typeVariables(left).contains(right) =>
            Composition(helper(substitute(tail.toSet, right, left).toList), Plain(right, left))
          // user-written base types such as X are treated as variables, but Nat and Bool are not
          case (leftBase @ Type.Base(leftName), _) if !primitives.contains(leftBase) && !typeVariables(right).exists(_.name == leftName) =>
            Composition(helper(substitute(tail.toSet, leftBase, right).toList), Plain(leftBase, right))
          case (_, rightBase @ Type.Base(rightName)) if !primitives.contains(rightBase) && !typeVariables(left).exists(_.name == rightName) =>
            Composition(helper(substitute(tail.toSet, rightBase, left).toList), Plain(rightBase, left))
          case (Type.Constructor(leftFrom, leftTo), Type.Constructor(rightFrom, rightTo)) =>
            helper(tail ::: ((leftFrom, rightFrom) :: (leftTo, rightTo) :: Nil))
          case _ =>
            throw IllegalArgumentException("Unification failed: the constraint set is not unifiable")
      case Nil => Trivial
  end helper
  return helper(constraints.toList)
end unify
@main
def main(): Unit =
  // λx:X->Y.x 0
  val (explicitlyTyped, _, eConstraints) = constraintsOf(SyntaxNode.Abs(
    "x", Some(Type.Constructor(Type.Base("X"), Type.Base("Y"))), SyntaxNode.App(SyntaxNode.Var("x"), SyntaxNode.Zero)
  ))
  println(substitute(unify(eConstraints), explicitlyTyped))
  // λx.x 0
  val (implicitlyTyped, _, iConstraints) = constraintsOf(SyntaxNode.Abs(
    "x", None, SyntaxNode.App(SyntaxNode.Var("x"), SyntaxNode.Zero)
  ))
  println(substitute(unify(iConstraints), implicitlyTyped))
  // let double = λf.λa.f(f(a)) in
  // let a = double (λx:Nat. succ (succ x)) (succ 0) in
  // let b = double (λx:Bool.x) false in
  // a; b
  val syntaxNode = SyntaxNode.Let("double",
    SyntaxNode.Abs("f", None,
      SyntaxNode.Abs("a", None,
        SyntaxNode.App(SyntaxNode.Var("f"), SyntaxNode.App(SyntaxNode.Var("f"), SyntaxNode.Var("a"))))),
    SyntaxNode.Let("a", SyntaxNode.App(SyntaxNode.App(SyntaxNode.Var("double"),
      SyntaxNode.Abs("x", Some(TypePrimitives.Nat), SyntaxNode.Successor(SyntaxNode.Successor(SyntaxNode.Var("x"))))), SyntaxNode.Successor(SyntaxNode.Zero)),
      SyntaxNode.Let("b", SyntaxNode.App(SyntaxNode.App(SyntaxNode.Var("double"), SyntaxNode.Abs("x", Some(TypePrimitives.Bool), SyntaxNode.Var("x"))), SyntaxNode.False),
        SyntaxNode.Seq(SyntaxNode.Var("a"), SyntaxNode.Var("b")))))
  val (ty, _, constraints) = constraintsOf(syntaxNode)
  println(substitute(unify(constraints), ty))
end main
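The unifier described in the header comment is, mathematically, a finite map from type variables to types. The following is a minimal standalone sketch, not the Gist's own code: the names `Ty`, `Subst`, `applySubst` and `compose` are hypothetical, and the Gist itself represents a unifier as nested `TypeSubstitution.Composition` values rather than as a map. It shows how [X ↦ T1, Y ↦ T2, ...]T is computed and how two substitutions are composed.

```scala
// A standalone model of types (hypothetical; independent of the Gist's `Type` enum)
enum Ty:
  case Var(name: String)       // a type variable such as X
  case Con(name: String)       // a primitive type such as Nat
  case Arrow(from: Ty, to: Ty) // the function type from -> to

// A substitution [X ↦ T1, Y ↦ T2, ...] as a map from variable names to types
type Subst = Map[String, Ty]

// Computes σT: every variable in σ's domain is replaced, everything else is kept
def applySubst(s: Subst, t: Ty): Ty = t match
  case Ty.Var(n)          => s.getOrElse(n, t)
  case Ty.Con(_)          => t
  case Ty.Arrow(from, to) => Ty.Arrow(applySubst(s, from), applySubst(s, to))

// Computes s2 ∘ s1, i.e. the substitution that applies s1 first and then s2
def compose(s2: Subst, s1: Subst): Subst =
  s1.map { case (n, t) => n -> applySubst(s2, t) } ++ s2.filter { case (n, _) => !s1.contains(n) }
```

With σ = [X ↦ Nat, Y ↦ T], `applySubst` turns X -> Y into Nat -> T, which is why σ solves the header's equation X -> Y = Nat -> T. `compose(s2, s1)` plays the role of the Gist's `Composition(outer, inner)`: the inner substitution is applied first.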
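The constraint typing relation mentioned in the header can be illustrated on the running example λx.x 0. The sketch below is hypothetical and covers only variables, unannotated abstractions, application and 0. It threads an immutable environment and a closure-based supply of fresh variables (X1, X2, ...), whereas the Gist uses a mutable `Bindings` map and the global `FreshNameProvider`.

```scala
// Hypothetical standalone model, reduced to what λx.x 0 needs
enum Ty:
  case Var(name: String)
  case Con(name: String)
  case Arrow(from: Ty, to: Ty)

enum Term:
  case Var(name: String)
  case Abs(binder: String, body: Term) // unannotated λbinder.body
  case App(function: Term, argument: Term)
  case Zero

// A supply of fresh type variables X1, X2, ... (each call creates an independent counter)
def freshSupply(): () => Ty =
  var counter = 0
  () => { counter += 1; Ty.Var(s"X$counter") }

// Returns the type of `term` together with the constraints collected along the way
def constraints(term: Term, env: Map[String, Ty], fresh: () => Ty): (Ty, List[(Ty, Ty)]) =
  term match
    case Term.Var(n) => (env(n), Nil)
    case Term.Zero   => (Ty.Con("Nat"), Nil)
    case Term.Abs(x, body) =>
      // the binder gets a fresh variable that the body may constrain
      val xTy = fresh()
      val (bodyTy, cs) = constraints(body, env + (x -> xTy), fresh)
      (Ty.Arrow(xTy, bodyTy), cs)
    case Term.App(f, a) =>
      // f must be a function from a's type to a fresh result type
      val (fTy, cs1) = constraints(f, env, fresh)
      val (aTy, cs2) = constraints(a, env, fresh)
      val result = fresh()
      (result, (cs1 ++ cs2) :+ ((fTy, Ty.Arrow(aTy, result))))
```

For λx.x 0 this returns the type X1 -> X2 together with the single constraint X1 = Nat -> X2, which expresses the same fact as the header's X -> Y = Nat -> T: x must be a function that accepts a Nat.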
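The unification step (Types and Programming Languages p.327, as cited in the header) can be sketched on a similar standalone model. This is an illustration under stated assumptions, not the Gist's `unify`: it returns an `Option` instead of throwing, it only ever binds `Ty.Var` values (so a primitive such as Nat is never substituted), and it performs the occurs check that rejects equations such as X = X -> Y. All names are hypothetical.

```scala
// Hypothetical standalone model of types
enum Ty:
  case Var(name: String)
  case Con(name: String)
  case Arrow(from: Ty, to: Ty)

type Subst = Map[String, Ty]

def applySubst(s: Subst, t: Ty): Ty = t match
  case Ty.Var(n)          => s.getOrElse(n, t)
  case Ty.Con(_)          => t
  case Ty.Arrow(from, to) => Ty.Arrow(applySubst(s, from), applySubst(s, to))

// The occurs check: does variable `name` appear anywhere inside `t`?
def occurs(name: String, t: Ty): Boolean = t match
  case Ty.Var(n)          => n == name
  case Ty.Con(_)          => false
  case Ty.Arrow(from, to) => occurs(name, from) || occurs(name, to)

// Solves a list of equations; None when no unifier exists
def unify(cs: List[(Ty, Ty)]): Option[Subst] = cs match
  case Nil => Some(Map.empty)
  case (s, t) :: rest if s == t => unify(rest)
  case (Ty.Var(x), t) :: rest if !occurs(x, t) => bind(x, t, rest)
  case (s, Ty.Var(x)) :: rest if !occurs(x, s) => bind(x, s, rest)
  case (Ty.Arrow(s1, s2), Ty.Arrow(t1, t2)) :: rest => unify((s1, t1) :: (s2, t2) :: rest)
  case _ => None // clashing primitives (Nat = Bool), Nat = X -> Y, or a failed occurs check

// Eliminates x from the remaining equations, solves them, and returns (their solution) ∘ [x ↦ t]
def bind(x: String, t: Ty, rest: List[(Ty, Ty)]): Option[Subst] =
  val single: Subst = Map(x -> t)
  val remaining = rest.map { case (l, r) => (applySubst(single, l), applySubst(single, r)) }
  unify(remaining).map(s => s + (x -> applySubst(s, t)))
```

On the header's equation X -> Y = Nat -> T this produces [X ↦ Nat, Y ↦ T], while Nat = Bool and X = X -> Y both yield `None`.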
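The efficient treatment of let-polymorphism referenced in the header (Types and Programming Languages p.331) comes down to two operations: generalize the principal type of the right-hand side over the variables that are not free in the environment, and instantiate the resulting scheme with fresh variables at every use. A minimal hypothetical sketch follows; `Scheme`, `generalize`, `instantiate` and `freshNames` are illustrative names, not part of the Gist.

```scala
// Hypothetical standalone model of types and type schemes
enum Ty:
  case Var(name: String)
  case Con(name: String)
  case Arrow(from: Ty, to: Ty)

// ∀vars.body
final case class Scheme(vars: Set[String], body: Ty)

def freeVars(t: Ty): Set[String] = t match
  case Ty.Var(n)          => Set(n)
  case Ty.Con(_)          => Set.empty
  case Ty.Arrow(from, to) => freeVars(from) ++ freeVars(to)

// Replaces variables according to `m`
def rename(t: Ty, m: Map[String, Ty]): Ty = t match
  case Ty.Var(n)          => m.getOrElse(n, t)
  case Ty.Con(_)          => t
  case Ty.Arrow(from, to) => Ty.Arrow(rename(from, m), rename(to, m))

// Quantifies the variables of `principal` that are not free anywhere in the environment
def generalize(principal: Ty, env: Map[String, Scheme]): Scheme =
  val envVars = env.values.flatMap(s => freeVars(s.body) -- s.vars).toSet
  Scheme(freeVars(principal) -- envVars, principal)

// Gives each quantified variable a fresh name, so every use of a let-bound name is typed independently
def instantiate(s: Scheme, fresh: () => String): Ty =
  rename(s.body, s.vars.map(v => v -> Ty.Var(fresh())).toMap)

// A supply of fresh names T1, T2, ...
def freshNames(): () => String =
  var counter = 0
  () => { counter += 1; s"T$counter" }
```

This mirrors what the Gist's `Let` case does when it builds a `Type.Scheme`, and what its `Var` case does when it replaces quantifiers with fresh `Type.Forall` placeholders. For `double` in `main`, whose principal type has the shape (X -> X) -> (X -> X), each use receives its own copy of the variable, which is what allows it to be applied at both Nat and Bool.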