Skip to content

Instantly share code, notes, and snippets.

@folone
Last active December 17, 2015 01:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save folone/5530838 to your computer and use it in GitHub Desktop.
Save folone/5530838 to your computer and use it in GitHub Desktop.
/*
* http://dysphoria.net/code/hindley-milner/HindleyMilner.scala
* Andrew Forrest
*
* Implementation of basic polymorphic type-checking for a simple language.
* Based heavily on Nikita Borisov’s Perl implementation at
* http://web.archive.org/web/20050420002559/www.cs.berkeley.edu/~nikitab/courses/cs263/hm.html
* which in turn is based on the paper by Luca Cardelli at
* http://lucacardelli.name/Papers/BasicTypechecking.pdf
*
* If you run it with "scala HindleyMilner.scala" it will attempt to report the types
* for a few example expressions. (It uses UTF-8 for output, so you may need to set your
* terminal accordingly.)
*
* Do with it what you will.
*/
/** Abstract syntax tree of the small functional language being type-checked.
  *
  * The hierarchy is sealed so the compiler can check that every match over a
  * `SyntaxNode` handles all five node kinds.
  */
sealed abstract class SyntaxNode

/** Anonymous function: `fn v ⇒ body`. */
case class Lambda(v: String, body: SyntaxNode) extends SyntaxNode

/** Reference to a name. */
case class Ident(name: String) extends SyntaxNode

/** Application of `fn` to a single argument `arg`. */
case class Apply(fn: SyntaxNode, arg: SyntaxNode) extends SyntaxNode

/** Non-recursive binding: `let v = defn in body`. */
case class Let(v: String, defn: SyntaxNode, body: SyntaxNode) extends SyntaxNode

/** Recursive binding: `letrec v = defn in body`. */
case class Letrec(v: String, defn: SyntaxNode, body: SyntaxNode) extends SyntaxNode

/** Pretty-printing helpers for [[SyntaxNode]] trees. */
object SyntaxNode {

  /** Renders `ast` for use as a sub-expression.
    *
    * Identifiers are printed bare; every other node is wrapped in parentheses
    * so the result can be embedded in a larger expression unambiguously.
    */
  def string(ast: SyntaxNode): String = ast match {
    case _: Ident => nakedString(ast)
    case _        => "(" + nakedString(ast) + ")"
  }

  /** Renders `ast` without enclosing parentheses; children are rendered with [[string]]. */
  def nakedString(ast: SyntaxNode): String = ast match {
    case Ident(name)           => name
    // The original matched on an undefined type `Λ`; the node class is `Lambda`.
    case Lambda(v, body)       => s"fn $v ⇒ ${string(body)}"
    case Apply(fn, arg)        => s"${string(fn)} ${string(arg)}"
    case Let(v, defn, body)    => s"let $v = ${string(defn)} in ${string(body)}"
    case Letrec(v, defn, body) => s"letrec $v = ${string(defn)} in ${string(body)}"
  }
}
/** Raised by the type checker when two types cannot be unified
  * (mismatched operators/arity, or a variable occurring inside its own binding).
  */
class TypeError(msg: String) extends Exception(msg)
/** Raised when an identifier is neither bound in the environment nor an integer literal. */
class ParseError(msg: String) extends Exception(msg)
/** Polymorphic type inference over [[SyntaxNode]] trees.
  *
  * Types are either variables (which `unify` binds by mutating their
  * `instance` field) or operators applied to argument types. The object keeps
  * two global counters (for variable ids and for display names), so it is
  * stateful: variable names keep advancing across successive calls.
  */
object TypeSystem {

  /** Typing environment: maps identifier names to their types. */
  type Env = Map[String, Type]

  /** A type term: either a [[Variable]] or an [[Oper]]. */
  sealed abstract class Type

  /** A type variable.
    *
    * Equality and hashing are the case-class defaults and therefore depend on
    * `id` only; the mutable `instance` field does not affect them, which is
    * what allows variables to be used as set elements and map keys.
    */
  case class Variable(id: Int) extends Type {
    /** The type this variable has been bound to by `unify`, if any. */
    var instance: Option[Type] = None

    /** Display name; lazy so that a letter is consumed only when the variable is first printed. */
    lazy val name: String = nextUniqueName
  }

  /** A type operator `name` applied to `args` (no arguments for base types). */
  case class Oper(name: String, args: Seq[Type]) extends Type

  /** Builds the function type `from → to`. */
  def Function(from: Type, to: Type): Oper = Oper("→", Seq(from, to))

  val Integer: Oper = Oper("int", Seq.empty)
  val Bool: Oper = Oper("bool", Seq.empty)

  /** Next character to hand out as a type-variable display name. */
  var _nextVariableName = 'α'

  /** Returns the next display name and advances the name counter by one code point. */
  def nextUniqueName: String = {
    val current = _nextVariableName
    _nextVariableName = (current.toInt + 1).toChar
    current.toString
  }

  /** Next id to hand out from `newVariable`. */
  var _nextVariableId = 0

  /** Allocates a fresh, unbound type variable with a new id. */
  def newVariable: Variable = {
    val id = _nextVariableId
    _nextVariableId += 1
    Variable(id)
  }

  /** Renders a type for display.
    *
    * Bound variables print as the type they are bound to; unbound variables
    * print their (lazily allocated) name. Nullary operators print as their
    * name, binary operators infix in parentheses, and any other arity prefix.
    */
  def string(t: Type): String = t match {
    case v: Variable =>
      // getOrElse is by-name, so a display name is allocated only for unbound variables.
      v.instance.map(string).getOrElse(v.name)
    case Oper(name, args) =>
      if (args.length == 0)
        name
      else if (args.length == 2)
        "(" + string(args(0)) + " " + name + " " + string(args(1)) + ")"
      else
        // Render each argument with `string`; the original used the default
        // toString of the argument objects here (e.g. "Variable(3)").
        args.map(string).mkString(name + " ", " ", "")
  }

  /** Infers the type of `ast` in `env`, starting with no non-generic variables.
    *
    * @throws TypeError  if unification fails
    * @throws ParseError if an identifier is undefined
    */
  def analyse(ast: SyntaxNode, env: Env): Type =
    analyse(ast, env, Set.empty[Variable]) // the original called `Set.<∅>`, which is not a member of Set

  /** Infers the type of `ast` in `env`.
    *
    * @param nongen type variables that must not be generalised (copied) when
    *               an identifier's type is looked up; lambda parameters and
    *               the variable standing for a `letrec` binding while its
    *               definition is analysed are added here
    * @throws TypeError  if unification fails
    * @throws ParseError if an identifier is undefined
    */
  def analyse(ast: SyntaxNode, env: Env, nongen: Set[Variable]): Type = ast match {
    case Ident(name) =>
      gettype(name, env, nongen)

    case Apply(fn, arg) =>
      val funtype = analyse(fn, env, nongen)
      val argtype = analyse(arg, env, nongen)
      val resulttype = newVariable
      // fn must be a function from the argument's type to some result type.
      unify(Function(argtype, resulttype), funtype)
      resulttype

    case Lambda(arg, body) =>
      val argtype = newVariable
      val resulttype = analyse(body, env + (arg -> argtype), nongen + argtype)
      Function(argtype, resulttype)

    case Let(v, defn, body) =>
      val defntype = analyse(defn, env, nongen)
      analyse(body, env + (v -> defntype), nongen)

    case Letrec(v, defn, body) =>
      val newtype = newVariable
      val newenv = env + (v -> newtype)
      // While analysing the definition, the binding's own type is non-generic.
      val defntype = analyse(defn, newenv, nongen + newtype)
      unify(newtype, defntype)
      analyse(body, newenv, nongen)
  }

  /** Looks up the type of identifier `name`.
    *
    * Names bound in `env` yield a fresh copy of their type (see [[fresh]]);
    * otherwise an all-digit name is an integer literal.
    *
    * @throws ParseError if `name` is neither bound nor an integer literal
    */
  def gettype(name: String, env: Env, nongen: Set[Variable]): Type = {
    if (env.contains(name))
      fresh(env(name), nongen)
    else if (isIntegerLiteral(name))
      Integer
    else
      throw new ParseError("Undefined symbol " + name)
  }

  /** Copies `t`, replacing each generic variable with a new variable.
    *
    * Every occurrence of the same generic variable maps to the same new
    * variable; non-generic variables (those occurring in `nongen`) are shared
    * with the original rather than copied.
    */
  def fresh(t: Type, nongen: Set[Variable]): Type = {
    import scala.collection.mutable
    val mappings = new mutable.HashMap[Variable, Variable]

    def freshrec(tp: Type): Type = prune(tp) match {
      case v: Variable =>
        if (isgeneric(v, nongen)) mappings.getOrElseUpdate(v, newVariable)
        else v
      case Oper(name, args) =>
        Oper(name, args.map(freshrec))
    }

    freshrec(t)
  }

  /** Makes `t1` and `t2` equal by binding type variables in place.
    *
    * @throws TypeError if the types have different operators or arities, or
    *                   if a variable would have to be bound to a type that
    *                   contains itself
    */
  def unify(t1: Type, t2: Type): Unit =
    (prune(t1), prune(t2)) match {
      case (a: Variable, b) =>
        if (a != b) {
          if (occursintype(a, b))
            throw new TypeError("recursive unification")
          a.instance = Some(b)
        }
      case (a: Oper, b: Variable) =>
        unify(b, a)
      case (a: Oper, b: Oper) =>
        if (a.name != b.name || a.args.length != b.args.length)
          throw new TypeError("Type mismatch: " + string(a) + "≠" + string(b))
        // Arities are equal here, so zip pairs up every argument.
        a.args.zip(b.args).foreach { case (x, y) => unify(x, y) }
    }

  /** Returns the currently defining instance of `t`.
    *
    * As a side effect, collapses chains of bound variables so that each
    * variable visited points directly at the final instance.
    */
  def prune(t: Type): Type = t match {
    case v: Variable =>
      v.instance match {
        case Some(bound) =>
          val root = prune(bound)
          v.instance = Some(root)
          root
        case None => v
      }
    case _ => t
  }

  /** True if `v` does not occur in any of the non-generic variables.
    * Note: must be called with `v` pre-pruned.
    */
  def isgeneric(v: Variable, nongen: Set[Variable]): Boolean = !occursin(v, nongen)

  /** True if `v` occurs anywhere inside `type2` (after pruning).
    * Note: must be called with `v` pre-pruned.
    */
  def occursintype(v: Variable, type2: Type): Boolean =
    prune(type2) match {
      case `v`           => true
      case Oper(_, args) => occursin(v, args)
      case _             => false
    }

  /** True if `t` occurs in at least one of the types in `list`. */
  def occursin(t: Variable, list: Iterable[Type]): Boolean =
    list.exists(candidate => occursintype(t, candidate))

  val checkDigits = "^(\\d+)$".r

  /** True if `name` consists solely of decimal digits. */
  def isIntegerLiteral(name: String): Boolean = checkDigits.findFirstIn(name).isDefined
}
/** Demo driver: type-checks a handful of example expressions and prints
  * each expression followed by its inferred type or the error message.
  */
object HindleyMilner {

  /** Builds the example environment and expressions, then reports the type of each.
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    val var1 = TypeSystem.newVariable
    val var2 = TypeSystem.newVariable
    val pairtype = TypeSystem.Oper("×", Array(var1, var2))
    val var3 = TypeSystem.newVariable

    // Built directly with Map(...); the original called `Map.<∅>`, which is
    // not a member of the standard Map companion.
    val myenv: TypeSystem.Env = Map(
      "pair"  -> TypeSystem.Function(var1, TypeSystem.Function(var2, pairtype)),
      "true"  -> TypeSystem.Bool,
      "cond"  -> TypeSystem.Function(TypeSystem.Bool, TypeSystem.Function(var3, TypeSystem.Function(var3, var3))),
      // int → bool predicate used as the base-case test in factorial below
      // (presumably an is-zero test; the name is kept exactly as in the source).
      "∅"     -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Bool),
      "pred"  -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer),
      "times" -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer))
    )

    // (pair (f 4)) (f true) -- `f` is free here; it is used on its own and as a let body below.
    val pairOfF =
      Apply(Apply(Ident("pair"), Apply(Ident("f"), Ident("4"))), Apply(Ident("f"), Ident("true")))

    val examples = Array[SyntaxNode](
      // factorial
      Letrec("factorial",                              // letrec factorial =
        Lambda("n",                                    //   fn n ⇒
          Apply(
            Apply(                                     //     cond (∅ n) 1
              Apply(Ident("cond"),                     //     cond (∅ n)
                Apply(Ident("∅"), Ident("n"))),
              Ident("1")),
            Apply(                                     //     times n (factorial (pred n))
              Apply(Ident("times"), Ident("n")),
              Apply(Ident("factorial"),
                Apply(Ident("pred"), Ident("n")))
            )
          )
        ),                                             // in
        Apply(Ident("factorial"), Ident("5"))
      ),

      // Should fail:
      // fn x ⇒ (pair (x 3)) (x true)
      Lambda("x",
        Apply(
          Apply(Ident("pair"),
            Apply(Ident("x"), Ident("3"))),
          Apply(Ident("x"), Ident("true")))),

      // (pair (f 4)) (f true) -- f is undefined at this point
      pairOfF,

      // let f = (fn x ⇒ x) in ((pair (f 4)) (f true))
      Let("f", Lambda("x", Ident("x")), pairOfF),

      // fn f ⇒ f f (fail)
      Lambda("f", Apply(Ident("f"), Ident("f"))),

      // let g = fn f ⇒ 5 in g g
      Let("g",
        Lambda("f", Ident("5")),
        Apply(Ident("g"), Ident("g"))),

      // example that demonstrates generic and non-generic variables:
      // fn g ⇒ let f = fn x ⇒ g in (pair (f 3)) (f true)
      Lambda("g",
        Let("f",
          Lambda("x", Ident("g")),
          Apply(
            Apply(Ident("pair"),
              Apply(Ident("f"), Ident("3"))
            ),
            Apply(Ident("f"), Ident("true"))))),

      // Function composition
      // fn f ⇒ fn g ⇒ fn arg ⇒ g (f arg)
      Lambda("f", Lambda("g", Lambda("arg", Apply(Ident("g"), Apply(Ident("f"), Ident("arg"))))))
    )

    // Type names use non-ASCII characters, so print through a UTF-8 stream.
    // Console.withOut scopes the redirection to this block and replaces the
    // deprecated Console.setOut.
    Console.withOut(new java.io.PrintStream(Console.out, true, "utf-8")) {
      examples.foreach(eg => tryexp(myenv, eg))
    }
  }

  /** Prints `ast`, then " : ", then its inferred type in `env`.
    *
    * If analysis raises a `ParseError` or `TypeError`, the exception message
    * is printed in place of the type. Always finishes the line.
    */
  def tryexp(env: TypeSystem.Env, ast: SyntaxNode): Unit = {
    print(SyntaxNode.string(ast) + " : ")
    try {
      val inferred = TypeSystem.analyse(ast, env)
      print(TypeSystem.string(inferred))
    } catch {
      case e: ParseError => print(e.getMessage)
      case e: TypeError  => print(e.getMessage)
    }
    println()
  }
}
// Script entry point: runs the examples when this file is executed as a script
// (e.g. `scala HindleyMilner.scala`).
// NOTE(review): `argv` is not defined in this file; presumably it is supplied by the
// script runner -- confirm for the Scala version in use (some runners expose `args` instead).
HindleyMilner.main(argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment