Last active
July 12, 2023 20:48
-
-
Save dylech30th/5b3bcd36f05939e839309eca0aa0b594 to your computer and use it in GitHub Desktop.
The implementation of unification-based type inference algorithm in pure simply typed lambda calculus with Let Polymorphism
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * An implementation of the unification-based type inference algorithm for the simply typed lambda calculus with
 * Let-Polymorphism.
 *
 * Unification-based type inference is widely used in a huge variety of programming languages; the most famous instance
 * is the Hindley-Milner type system (a.k.a. the Damas-Milner type system) of the ML family, which permits the programmer
 * to omit almost all type annotations. The algorithm is based on two concepts: the Constraint Set and the Unifier.
 *
 * A constraint set consists of several constraints. A constraint is basically a type equation, e.g., X = T, where both
 * X and T are types.
 * A unifier is a set of type substitutions [X -> T1, Y -> T2, ...]. Applied to a type T, it replaces every type variable
 * in its domain with the corresponding type in its co-domain, yielding [X -> T1, Y -> T2, ...]T. A unifier is said to
 * unify a constraint if applying it to both sides solves the equation, in which case it is also called a solution of
 * the equation; a unifier is said to unify a constraint set if it unifies every equation in it.
 *
 * The basic idea is this: when we see something like λx.x 0, we know that x must be of type X -> Y for some X and Y, and
 * since it is applied to the natural number 0, it must also be of type Nat -> T for some T, where Nat stands for the type
 * of natural numbers. Thus a constraint X -> Y = Nat -> T can be generated, and the substitution [X -> Nat, Y -> T]
 * unifies it. Our task is to calculate both the constraint set and the unifier for a lambda term. The algorithm for
 * finding the constraint set is simple: it is syntax-directed, following the constraint typing relation (see Types and
 * Programming Languages p.322). The only subtlety is the treatment of Let-Polymorphism (see Types and Programming
 * Languages p.331), where we choose the more efficient implementation over the naive approach (which has exponential
 * time complexity).
 *
 * The unification algorithm that calculates the unifier for a specific constraint set uses the naive approach found in
 * Types and Programming Languages p.327. One thing worth mentioning is that we prefer to substitute actual types for type
 * placeholders (we use Type.Forall to represent the placeholder types generated by the algorithm that calculates the
 * constraint set). The reason is clear: take the above example again. For the equation X -> Y = Nat -> T,
 * [Nat -> X, T -> Y] is also a solution, but it is meaningless since it generalizes all types, and replacing Nat with a
 * type variable really breaks soundness because 0 has type Nat, not that variable. Sometimes this rule is built into the
 * unifier, so that applying a unifier to Nat always yields Nat again, regardless of any entry of the form Nat -> X, but
 * we do not take that approach.
 */
package ink.sora | |
import scala.:: | |
import scala.collection.mutable | |
/** Type class describing a user-supplied notion of equality between two values of type `T`. */
trait Eq[T]:
  /** Decides whether `left` and `right` are to be treated as equal. */
  def ==(left: T, right: T): Boolean
end Eq
/**
 * The types manipulated by the inference algorithm.
 *
 *  - `Base` is a named type such as `Nat`.
 *  - `Constructor` is the function type `from -> to`.
 *  - `Scheme` is a type `ty` together with the set of types it quantifies over.
 *  - `Forall` is a named placeholder type.
 */
enum Type:
  case Base(n: String)
  case Constructor(from: Type, to: Type)
  case Scheme(quantifiers: Set[Type], ty: Type)
  case Forall(n: String)

  /** Renders the type; a scheme prints the names of its quantifiers back to back, without a separator. */
  override def toString: String =
    this match
      case Type.Base(label) => label
      case Type.Forall(label) => s"∀$label"
      case Type.Constructor(domain, codomain) => s"($domain->$codomain)"
      case Type.Scheme(bound, body) =>
        // Iterate instead of mapping the set so equal names are neither merged nor reordered.
        val boundNames = bound.iterator.map(_.name).mkString
        s"∀$boundNames.($body)"
  end toString

  /**
   * The name carried by a `Base` or `Forall` type.
   *
   * @throws IllegalStateException for `Constructor` and `Scheme`, which carry no name
   */
  def name: String =
    this match
      case Type.Base(label) => label
      case Type.Forall(label) => label
      case _ => throw IllegalStateException()
  end name
end Type
/** The built-in base types used by the arithmetic and boolean terms. */
object TypePrimitives:
  /** The base type named "Nat". */
  val Nat: Type = Type.Base("Nat")

  /** The base type named "Bool". */
  val Bool: Type = Type.Base("Bool")
end TypePrimitives
/** Abstract syntax tree of the terms accepted by the type inference algorithm. */
enum SyntaxNode:
  // Variables, abstractions (with an optional annotation on the binder) and applications.
  case Var(name: String)
  case Abs(binder: String, binderType: Option[Type], body: SyntaxNode)
  case App(function: SyntaxNode, applicant: SyntaxNode)

  // Operators on natural numbers.
  case Successor(nat: SyntaxNode)
  case Predecessor(nat: SyntaxNode)
  case IsZero(nat: SyntaxNode)

  // Conditionals and literal constants.
  case If(condition: SyntaxNode, trueClause: SyntaxNode, falseClause: SyntaxNode)
  case Zero
  case True
  case False

  // `let left = right in body` and sequencing `first; second`.
  case Let(left: String, right: SyntaxNode, body: SyntaxNode)
  case Seq(first: SyntaxNode, second: SyntaxNode)
end SyntaxNode
/**
 * A substitution on types.
 *
 *  - `Plain(from, to)` is a single replacement, printed as `from/to`.
 *  - `Composition(outer, inner)` chains two substitutions.
 *  - `Trivial` is the empty substitution, printed as the empty string.
 */
enum TypeSubstitution:
  case Plain(from: Type, to: Type)
  case Composition(outer: TypeSubstitution, inner: TypeSubstitution)
  case Trivial

  /** Renders the substitution; a composition whose outer part prints as empty shows only its inner part. */
  override def toString: String =
    this match
      case TypeSubstitution.Trivial => ""
      case TypeSubstitution.Plain(from, to) => s"$from/$to"
      case TypeSubstitution.Composition(outer, inner) =>
        val outerText = outer.toString
        if outerText.isEmpty then inner.toString
        else s"$outerText(${inner.toString})"
  end toString
end TypeSubstitution
/** A single type equation: the two types that are required to be equal. */
opaque type Constraint = (Type, Type)

/** Result of constraint generation: the type of the term, the fresh type-variable names, and the constraints. */
opaque type ConstraintSet = (Type, Set[String], Set[Constraint])

/** The typing context: a mutable map from term-variable names to their types. */
opaque type Bindings = mutable.Map[String, Type]
/**
 * Structural equality on types: `Base` and `Forall` compare by name, `Constructor` compares
 * component-wise, and every other combination (including any `Scheme`) is unequal.
 */
given Eq[Type] with
  override def ==(left: Type, right: Type): Boolean =
    left match
      case Type.Base(leftName) =>
        right match
          case Type.Base(rightName) => leftName == rightName
          case _ => false
      case Type.Forall(leftName) =>
        right match
          case Type.Forall(rightName) => leftName == rightName
          case _ => false
      case Type.Constructor(leftFrom, leftTo) =>
        right match
          case Type.Constructor(rightFrom, rightTo) => ==(leftFrom, rightFrom) && ==(leftTo, rightTo)
          case _ => false
      case _ => false
end given
/**
 * Global, stateful generator of type-variable names.
 *
 * Names run through `A` .. `Z`, then `A1` .. `Z1`, `A2` .. `Z2`, and so on.
 */
object FreshNameProvider:
  private var currentChar: Char = 'A'
  private var currentIndex: Int = 0

  /** Produces the next name in the sequence and advances the generator. */
  def freshName(): String =
    val suffix = if currentIndex == 0 then "" else currentIndex.toString
    val produced = s"$currentChar$suffix"
    if currentChar >= 'A' && currentChar < 'Z' then
      currentChar = (currentChar + 1).toChar
    else
      // Ran past 'Z': start over at 'A' with the next numeric suffix.
      currentIndex += 1
      currentChar = 'A'
    produced

  /** Produces the next name in the sequence that is not in `excludes`; skipped names are consumed. */
  def freshName(excludes: Set[String]): String =
    var candidate = freshName()
    while excludes.contains(candidate) do candidate = freshName()
    candidate
end FreshNameProvider
/**
 * Collects every type variable (as reported by `typeVariables`) that occurs on either side of any
 * equation in the given constraint set.
 */
def constraintVariables(constraintSet: Set[(Type, Type)]): Set[Type] =
  constraintSet.flatMap((lhs, rhs) => typeVariables(lhs) ++ typeVariables(rhs))
end constraintVariables
/**
 * Collects the `Base` and `Forall` leaves of a type declaration.
 *
 * For a `Scheme` only the body is inspected, so its quantified variables are reported as well
 * whenever they occur in the body.
 */
def typeVariables(decl: Type): Set[Type] =
  decl match
    case Type.Scheme(_, body) => typeVariables(body)
    case Type.Constructor(domain, codomain) => typeVariables(domain) ++ typeVariables(codomain)
    case leaf @ (Type.Base(_) | Type.Forall(_)) => Set(leaf)
end typeVariables
/**
 * Retrieves the free occurrences of term variables in `root`.
 *
 * A variable is free when it is not under an `Abs` binding its name and not in the body of a
 * `Let` binding its name. The bound term of a `Let` is examined outside the scope of the binder,
 * since `let` is not recursive here.
 *
 * @param root the term to inspect
 * @return the names of all variables occurring free in `root`
 */
def freeVariables(root: SyntaxNode): Set[String] =
  import SyntaxNode.*

  // `bound` is the set of binders in scope at `node`.
  def collect(node: SyntaxNode, bound: Set[String]): Set[String] =
    node match
      case Var(name) => if bound.contains(name) then Set.empty else Set(name)
      case Abs(binder, _, body) => collect(body, bound + binder)
      case App(function, applicant) => collect(function, bound) | collect(applicant, bound)
      case Successor(nat) => collect(nat, bound)
      case Predecessor(nat) => collect(nat, bound)
      case IsZero(nat) => collect(nat, bound)
      case If(condition, ifTrue, ifFalse) =>
        collect(condition, bound) | collect(ifTrue, bound) | collect(ifFalse, bound)
      // The binder of a let scopes over the body only, not over the bound term.
      case Let(left, right, body) => collect(right, bound) | collect(body, bound + left)
      case Seq(first, second) => collect(first, bound) | collect(second, bound)
      case Zero | True | False => Set.empty
  end collect

  collect(root, Set.empty)
end freeVariables
/**
 * Calculates the constraint set of a term according to the Constraint Typing Relation.
 *
 * Fresh type variables are drawn from the global `FreshNameProvider`. `symbols` is mutated while
 * binders are in scope and restored before the function returns normally.
 *
 * For `let left = right in body` the constraints of `right` are solved on the spot to obtain its
 * principal type. Placeholders that do not occur in the (solved) types of the enclosing bindings
 * are generalized into a `Type.Scheme`; every use of `left` in `body` instantiates the scheme with
 * fresh placeholders. The constraints of `right` are also kept in the result, so that whatever
 * they impose on variables of the environment is not lost.
 *
 * @param root    the term to analyse
 * @param symbols the typing context, mapping term variables in scope to their types
 * @return the type of the term, the names of the fresh type variables introduced, and the
 *         constraints that must be unified for the term to be well-typed
 * @throws IllegalArgumentException if an abstraction rebinds a name already in `symbols`, if the
 *                                  name sets of sibling sub-terms overlap, or if the bound term
 *                                  of a `let` is not unifiable
 * @throws NoSuchElementException   if a variable is not bound in `symbols`
 */
def constraintsOf(root: SyntaxNode, symbols: Bindings = mutable.Map.empty)(using Eq[Type]): ConstraintSet =
  import SyntaxNode.*

  // `succ` and `pred` share one rule: the operand is constrained to Nat and the result is Nat.
  def natOperator(operand: SyntaxNode): ConstraintSet =
    val (ty, names, constraints) = constraintsOf(operand, symbols)
    (TypePrimitives.Nat, names, constraints + ((ty, TypePrimitives.Nat)))

  root match
    case Var(name) =>
      symbols(name) match
        case Type.Scheme(quantifiers, ty) =>
          // Instantiate the scheme: every quantified variable is replaced by a fresh placeholder.
          val newTypes: Iterable[Type.Forall] = (0 until quantifiers.size).map(_ => Type.Forall(FreshNameProvider.freshName()))
          val sType = quantifiers.zip(newTypes).foldLeft(ty) {
            case (acc, (quantifier, freshType)) => substitute(acc, quantifier, freshType)
          }
          (sType, newTypes.map(_.name).toSet, Set.empty)
        case monomorphic => (monomorphic, Set.empty, Set.empty)
    case Abs(binder, binderType, body) =>
      if symbols.contains(binder) then
        throw IllegalArgumentException(s"Duplicate symbol found: $binder")
      binderType match
        case Some(bTy) =>
          symbols(binder) = bTy
          val (ty, names, constraints) = constraintsOf(body, symbols)
          symbols -= binder
          (Type.Constructor(bTy, ty), names, constraints)
        case None =>
          // No annotation: assume a fresh placeholder for the binder.
          val freshName = FreshNameProvider.freshName()
          val assumption = Type.Forall(freshName)
          symbols(binder) = assumption
          val (ty, names, constraints) = constraintsOf(body, symbols)
          symbols -= binder
          (Type.Constructor(assumption, ty), names + freshName, constraints)
    case App(function, applicant) =>
      val (ty1, names1, constraints1) = constraintsOf(function, symbols)
      val (ty2, names2, constraints2) = constraintsOf(applicant, symbols)
      val typeNames1 = typeVariables(ty1).map(_.name)
      val typeNames2 = typeVariables(ty2).map(_.name)
      if (names1 & names2).nonEmpty || (names1 & typeNames2).nonEmpty || (names2 & typeNames1).nonEmpty then
        throw IllegalArgumentException("The names are not distinct in two operands of App")
      // The result placeholder must not clash with any name already mentioned by either operand.
      val taken = names1 | names2 | typeNames1 | typeNames2 | constraintVariables(constraints1).map(_.name) | constraintVariables(constraints2).map(_.name) | symbols.keySet | freeVariables(function) | freeVariables(applicant)
      val freshName = FreshNameProvider.freshName(taken)
      val assumption = Type.Forall(freshName)
      (assumption, (names1 | names2) + freshName, (constraints1 | constraints2) + ((ty1, Type.Constructor(ty2, assumption))))
    case Zero => (TypePrimitives.Nat, Set.empty, Set.empty)
    case Successor(nat) => natOperator(nat)
    case Predecessor(nat) => natOperator(nat)
    case IsZero(nat) =>
      val (ty, names, constraints) = constraintsOf(nat, symbols)
      (TypePrimitives.Bool, names, constraints + ((ty, TypePrimitives.Nat)))
    case True | False => (TypePrimitives.Bool, Set.empty, Set.empty)
    case If(condition, trueClause, falseClause) =>
      val (ty1, names1, constrains1) = constraintsOf(condition, symbols)
      val (ty2, names2, constrains2) = constraintsOf(trueClause, symbols)
      val (ty3, names3, constrains3) = constraintsOf(falseClause, symbols)
      // The three name sets have to be pairwise disjoint, not merely free of a common element.
      if (names1 & names2).nonEmpty || (names1 & names3).nonEmpty || (names2 & names3).nonEmpty then
        throw IllegalArgumentException("The names are not distinct between operands of If")
      (ty2, names1 | names2 | names3, constrains1 | constrains2 | constrains3 | Set((ty1, TypePrimitives.Bool), (ty2, ty3)))
    case Let(left, right, body) =>
      val (rightType, rightNames, rightConstraints) = constraintsOf(right, symbols)
      val unifier = unify(rightConstraints)
      val principalOfRight = substitute(unifier, rightType)
      // Variables still reachable from the solved types of the enclosing bindings are shared with the
      // environment and therefore must stay monomorphic. Scheme bodies are solved explicitly because
      // `substitute` does not descend into a scheme.
      val environmentVars: Set[Type] = symbols.values.flatMap {
        case Type.Scheme(_, inner) => typeVariables(substitute(unifier, inner))
        case bound => typeVariables(substitute(unifier, bound))
      }.toSet
      val remainingVars = typeVariables(principalOfRight).filter(_.isInstanceOf[Type.Forall]) -- environmentVars
      val generalizedType = if remainingVars.nonEmpty then Type.Scheme(remainingVars, principalOfRight) else principalOfRight
      // Remember a binding that `left` shadows so that it can be put back afterwards.
      val shadowed = symbols.get(left)
      symbols(left) = generalizedType
      val (bodyType, bodyNames, bodyConstraints) = constraintsOf(body, symbols)
      symbols -= left
      shadowed.foreach(previous => symbols(left) = previous)
      (bodyType, rightNames | bodyNames, rightConstraints | bodyConstraints)
    case Seq(first, second) =>
      // The type of `first` is irrelevant, but its constraints are kept so that it still has to be well-typed.
      val (_, firstNames, firstConstraints) = constraintsOf(first, symbols)
      val (secondType, secondNames, secondConstraints) = constraintsOf(second, symbols)
      (secondType, firstNames | secondNames, firstConstraints | secondConstraints)
end constraintsOf
/**
 * Replaces every occurrence of `from` in `ty` with `to`.
 *
 * Function types are rebuilt from their substituted components and are never compared with `from`
 * as a whole; any other type is replaced only when it equals `from`.
 */
def substitute(ty: Type, from: Type, to: Type)(using Eq[Type]): Type =
  ty match
    case Type.Constructor(domain, codomain) =>
      val newDomain = substitute(domain, from, to)
      val newCodomain = substitute(codomain, from, to)
      Type.Constructor(newDomain, newCodomain)
    case leaf if leaf == from => to
    case leaf => leaf
end substitute
/** Replaces every occurrence of `from` with `to` on both sides of each equation in `constraints`. */
def substitute(constraints: Set[Constraint], from: Type, to: Type): Set[Constraint] =
  constraints.map((lhs, rhs) => (substitute(lhs, from, to), substitute(rhs, from, to)))
end substitute
/**
 * Applies a `TypeSubstitution` to a type.
 *
 * For a composition the inner substitution is applied first and the outer one to its result.
 */
def substitute(substitution: TypeSubstitution, ty: Type): Type =
  substitution match
    case TypeSubstitution.Trivial => ty
    case TypeSubstitution.Plain(from, to) => substitute(ty, from, to)
    case TypeSubstitution.Composition(outer, inner) =>
      val afterInner = substitute(inner, ty)
      substitute(outer, afterInner)
end substitute
/**
 * Calculates the unifier for a given constraint set.
 *
 * Equations are processed one at a time. `Forall` placeholders are bound first, so actual types
 * are preferred over placeholders; `Base` types are bound next, and two function types are
 * decomposed into equations between their components. A variable is never bound to a type that
 * contains it (occurs check), so an equation such as `A = A -> B` is rejected instead of producing
 * a bogus solution.
 *
 * @param constraints the type equations to solve
 * @return a substitution that, applied to both sides of every equation, makes them equal
 * @throws IllegalArgumentException if the constraint set is not unifiable
 */
def unify(constraints: Set[Constraint])(using Eq[Type]): TypeSubstitution =
  import TypeSubstitution.*

  // True when `variable` occurs somewhere inside `within`.
  def occursIn(variable: Type, within: Type): Boolean =
    typeVariables(within).contains(variable)

  // Records `variable := replacement`, rewrites the remaining equations accordingly and keeps solving.
  def bind(variable: Type, replacement: Type, rest: List[Constraint]): TypeSubstitution =
    Composition(helper(substitute(rest.toSet, variable, replacement).toList), Plain(variable, replacement))

  def helper(cs: List[Constraint]): TypeSubstitution =
    cs match
      case Nil => Trivial
      case (left, right) :: tail =>
        (left, right) match
          case _ if left == right =>
            helper(tail)
          // We prefer to substitute the real type for placeholders, hence these two cases come first.
          case (variable: Type.Forall, _) if !occursIn(variable, right) =>
            bind(variable, right, tail)
          case (_, variable: Type.Forall) if !occursIn(variable, left) =>
            bind(variable, left, tail)
          case (leftBase @ Type.Base(leftName), _) if !typeVariables(right).exists(_.name == leftName) =>
            bind(leftBase, right, tail)
          case (_, rightBase @ Type.Base(rightName)) if !typeVariables(left).exists(_.name == rightName) =>
            bind(rightBase, left, tail)
          case (Type.Constructor(leftFrom, leftTo), Type.Constructor(rightFrom, rightTo)) =>
            helper(tail ::: ((leftFrom, rightFrom) :: (leftTo, rightTo) :: Nil))
          case _ =>
            throw new IllegalArgumentException("Unification failed: Cannot find a proper solution, The constraint set is not unifiable")

  helper(constraints.toList)
end unify
/** Runs three sample inferences and prints the resulting types. */
@main
def main(): Unit =
  import SyntaxNode.*

  // λx:X->Y.x 0
  val annotatedBinder = Type.Constructor(Type.Base("X"), Type.Base("Y"))
  val (explicitlyTyped, _, eConstraints) =
    constraintsOf(Abs("x", Some(annotatedBinder), App(Var("x"), Zero)))
  println(substitute(unify(eConstraints), explicitlyTyped))

  // λx.x 0
  val (implicitlyTyped, _, iConstraints) =
    constraintsOf(Abs("x", None, App(Var("x"), Zero)))
  println(substitute(unify(iConstraints), implicitlyTyped))

  // let double = λf.λa.f(f(a)) in
  // let a = double (λx:Nat. succ (succ x)) (succ 0) in
  // let b = double (λx:Bool.x) false in
  // a; b
  val twice = Abs("f", None, Abs("a", None, App(Var("f"), App(Var("f"), Var("a")))))
  val onNat = Abs("x", Some(TypePrimitives.Nat), Successor(Successor(Var("x"))))
  val onBool = Abs("x", Some(TypePrimitives.Bool), Var("x"))
  val program =
    Let("double", twice,
      Let("a", App(App(Var("double"), onNat), Successor(Zero)),
        Let("b", App(App(Var("double"), onBool), False),
          Seq(Var("a"), Var("b")))))
  val (ty, _, _) = constraintsOf(program)
  println(ty)
end main
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment