Skip to content

Instantly share code, notes, and snippets.

@ddrone
Created September 28, 2011 10:48
Show Gist options
  • Save ddrone/1247631 to your computer and use it in GitHub Desktop.
Save ddrone/1247631 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
/**
 * Type inference for the pure lambda calculus in the style of Hindley-Milner: every
 * subexpression is given a type, and constraints between types are solved by first-order
 * unification with an occurs check. There is no let-polymorphism, so inferred types are
 * simple types built from type variables and arrows.
 *
 * <p>Limitation: the typing environment is keyed by the textual form of expressions, so two
 * lambdas binding the same name (or a free variable sharing a binder's name) cannot be told
 * apart; such terms are rejected with an {@link InferenceError} where detected.
 */
public class Main {

    // ---------------------------------------------------------------- Expressions

    /**
     * Base class of lambda-calculus terms. Ordering (and therefore identity as a TreeMap key)
     * is by {@code toString()}, so textually identical subexpressions share one environment
     * entry.
     */
    static abstract class Expression implements Comparable<Expression> {
        @Override
        public int compareTo(Expression arg0) {
            return this.toString().compareTo(arg0.toString());
        }
    }

    /** A variable reference, printed as its name. */
    static class VariableExpression extends Expression {
        String varName;

        public VariableExpression(String name) {
            varName = name;
        }

        @Override
        public String toString() {
            return varName;
        }
    }

    /** A lambda abstraction {@code \varName.abstrBody}. */
    static class AbstractionExpression extends Expression {
        String varName;
        Expression abstrBody;

        public AbstractionExpression(String name, Expression body) {
            varName = name;
            abstrBody = body;
        }

        @Override
        public String toString() {
            return "\\" + varName + "." + abstrBody.toString();
        }
    }

    /** An application {@code (applFunction applArgument)}. */
    static class ApplicationExpression extends Expression {
        Expression applFunction;
        Expression applArgument;

        public ApplicationExpression(Expression func, Expression arg) {
            applFunction = func;
            applArgument = arg;
        }

        @Override
        public String toString() {
            return "(" + applFunction.toString() + " " + applArgument.toString() + ")";
        }
    }

    // ---------------------------------------------------------------- Types

    /** Base class of simple types. */
    static abstract class Type {
    }

    /** A type variable, printed as its name. */
    static class VariableType extends Type {
        String varName;

        public VariableType(String name) {
            varName = name;
        }

        @Override
        public String toString() {
            return varName;
        }
    }

    /** A function type {@code arrowLeft -> arrowRight}; arrows associate to the right. */
    static class ArrowType extends Type {
        Type arrowLeft;
        Type arrowRight;

        public ArrowType(Type left, Type right) {
            arrowLeft = left;
            arrowRight = right;
        }

        @Override
        public String toString() {
            String result;
            // Only a left-hand arrow needs parentheses, since "->" is right-associative.
            if (arrowLeft instanceof ArrowType) {
                result = "(" + arrowLeft.toString() + ")";
            } else {
                result = arrowLeft.toString();
            }
            result = result + " -> " + arrowRight.toString();
            return result;
        }
    }

    // ---------------------------------------------------------------- Name generation

    /** Source of fresh type-variable names. */
    static interface NameGenerator {
        public String getNext();
    }

    /** Yields a, b, ..., z, then a1, b1, ..., z1, then a2, ... */
    static class SimpleNameGenerator implements NameGenerator {
        char curChar;
        int curNumber;

        public SimpleNameGenerator() {
            curChar = 'a';
            curNumber = 0;
        }

        @Override
        public String getNext() {
            String result = Character.toString(curChar);
            if (curNumber > 0) {
                result += Integer.toString(curNumber);
            }
            if (curChar == 'z') {
                curChar = 'a';
                curNumber++;
            } else {
                curChar++;
            }
            return result;
        }
    }

    // ---------------------------------------------------------------- Type inference

    /**
     * Infers a type for {@code expr} and records it (and the types of its subexpressions and
     * variables) in {@code environment}. Every unification rewrites all entries of the
     * environment, so after this call the stored types are mutually consistent.
     *
     * @param expr        expression to type
     * @param gen         source of fresh type-variable names
     * @param environment expression (by text) to type map; read and updated in place
     * @throws UnificationError if two types cannot be unified (clash or infinite type)
     * @throws InferenceError   if a lambda re-binds a name already in use, or an arrow
     *                          type was expected but not found
     */
    static void inferType(Expression expr, NameGenerator gen, TreeMap<Expression, Type> environment) throws UnificationError, InferenceError {
        if (expr instanceof VariableExpression) {
            // First occurrence gets a fresh type variable; later occurrences reuse it.
            if (!environment.containsKey(expr)) {
                environment.put(expr, new VariableType(gen.getNext()));
            }
        } else if (expr instanceof ApplicationExpression) {
            ApplicationExpression e = (ApplicationExpression) expr;
            inferType(e.applArgument, gen, environment);
            Type argType = environment.get(e.applArgument);
            if (e.applFunction instanceof VariableExpression) {
                VariableExpression v = (VariableExpression) e.applFunction;
                if (environment.containsKey(v)) {
                    Type funcType = environment.get(v);
                    unifyTypes(new ArrowType(argType, new VariableType(gen.getNext())), funcType, environment);
                    environment.put(expr, expectArrow(environment.get(v)).arrowRight);
                } else {
                    // Unknown function variable: it must map argType to a fresh result type.
                    String curName = gen.getNext();
                    environment.put(expr, new VariableType(curName));
                    environment.put(e.applFunction, new ArrowType(argType, new VariableType(curName)));
                }
            } else if (e.applFunction instanceof ApplicationExpression) {
                ApplicationExpression appl = (ApplicationExpression) e.applFunction;
                inferType(appl, gen, environment);
                Type funcType = environment.get(appl);
                // Inferring the function may have substituted into the argument's type.
                argType = environment.get(e.applArgument);
                unifyTypes(funcType, new ArrowType(argType, new VariableType(gen.getNext())), environment);
                environment.put(expr, expectArrow(environment.get(appl)).arrowRight);
            } else if (e.applFunction instanceof AbstractionExpression) {
                AbstractionExpression abstr = (AbstractionExpression) e.applFunction;
                inferType(abstr, gen, environment);
                ArrowType arr = expectArrow(environment.get(abstr));
                // Inferring the abstraction may have substituted into the argument's type.
                argType = environment.get(e.applArgument);
                if (!areTypesCompatible(arr.arrowLeft, argType)) {
                    throw new InferenceError();
                }
                unifyTypes(arr.arrowLeft, argType, environment);
                environment.put(expr, expectArrow(environment.get(abstr)).arrowRight);
            }
        } else if (expr instanceof AbstractionExpression) {
            AbstractionExpression e = (AbstractionExpression) expr;
            VariableExpression bound = new VariableExpression(e.varName);
            // The environment is keyed by name, so a binder may not reuse a name that is
            // already known or that is re-bound inside its own body (shadowing).
            if (environment.containsKey(bound) || bindsVariable(e.abstrBody, e.varName)) {
                System.err.println("Duplicate bound variable in lambda abstraction!");
                throw new InferenceError();
            }
            inferType(e.abstrBody, gen, environment);
            Type bodyType = environment.get(e.abstrBody);
            Type argType;
            if (environment.containsKey(bound)) {
                argType = environment.get(bound);
            } else {
                // The body never mentions the bound variable: it gets a fresh type.
                argType = new VariableType(gen.getNext());
                environment.put(bound, argType);
            }
            environment.put(expr, new ArrowType(argType, bodyType));
        }
    }

    /** Returns true if {@code expr} contains a lambda whose binder is named {@code name}. */
    private static boolean bindsVariable(Expression expr, String name) {
        if (expr instanceof AbstractionExpression) {
            AbstractionExpression a = (AbstractionExpression) expr;
            return a.varName.equals(name) || bindsVariable(a.abstrBody, name);
        }
        if (expr instanceof ApplicationExpression) {
            ApplicationExpression a = (ApplicationExpression) expr;
            return bindsVariable(a.applFunction, name) || bindsVariable(a.applArgument, name);
        }
        return false;
    }

    /** Casts {@code t} to an arrow, or reports "ArrowType expected" and throws. */
    private static ArrowType expectArrow(Type t) throws InferenceError {
        if (t instanceof ArrowType) {
            return (ArrowType) t;
        }
        System.err.println("ArrowType expected");
        throw new InferenceError();
    }

    /**
     * Infers the type of a closed expression with a fresh environment and name generator.
     *
     * @return the inferred type of {@code expr}
     */
    static Type getType(Expression expr) throws UnificationError, InferenceError {
        TreeMap<Expression, Type> env = new TreeMap<Expression, Type>();
        inferType(expr, new SimpleNameGenerator(), env);
        return env.get(expr);
    }

    /**
     * Cheap structural pre-check: returns false only when the two types have clashing shapes
     * (a variable on either side is compatible with anything). Does not detect clashes
     * caused by shared variables or infinite types; {@link #unifyTypes} does.
     */
    static boolean areTypesCompatible(Type t1, Type t2) {
        if (t1 instanceof VariableType || t2 instanceof VariableType) {
            return true;
        }
        if (t1 instanceof ArrowType && t2 instanceof ArrowType) {
            ArrowType a1 = (ArrowType) t1;
            ArrowType a2 = (ArrowType) t2;
            return areTypesCompatible(a1.arrowLeft, a2.arrowLeft) && areTypesCompatible(a1.arrowRight, a2.arrowRight);
        }
        return false;
    }

    /**
     * Returns {@code haystack} with every occurrence of type variable {@code var} replaced by
     * {@code replacement}. Returns null for Type subclasses other than variable and arrow.
     */
    static Type substituteTypeVariable(String var, Type replacement, Type haystack) {
        if (haystack instanceof VariableType) {
            VariableType t = (VariableType) haystack;
            return t.varName.equals(var) ? replacement : haystack;
        } else if (haystack instanceof ArrowType) {
            ArrowType t = (ArrowType) haystack;
            return new ArrowType(substituteTypeVariable(var, replacement, t.arrowLeft),
                    substituteTypeVariable(var, replacement, t.arrowRight));
        }
        return null;
    }

    /** Thrown when two types cannot be unified (shape clash or infinite type). */
    static class UnificationError extends Exception {
        private static final long serialVersionUID = 728975700118640646L;
    }

    /** Thrown when inference fails for a reason other than unification. */
    static class InferenceError extends Exception {
        private static final long serialVersionUID = 9186955675628337698L;
    }

    /**
     * Unifies {@code t1} with {@code t2} and applies the resulting most general unifier to
     * every type stored in {@code env}. On failure {@code env} is left unchanged.
     *
     * @throws UnificationError on a shape clash or when a variable would occur in its own
     *                          binding (infinite type)
     */
    static void unifyTypes(Type t1, Type t2, TreeMap<Expression, Type> env) throws UnificationError {
        Map<String, Type> subst = new HashMap<String, Type>();
        unifyInto(t1, t2, subst);
        for (Map.Entry<Expression, Type> entry : env.entrySet()) {
            entry.setValue(applySubstitution(entry.getValue(), subst));
        }
    }

    /**
     * Extends {@code subst} so that it unifies {@code t1} and {@code t2}. Both sides are
     * resolved against the bindings found so far, so bindings made while unifying the left
     * halves of two arrows are seen when unifying the right halves.
     */
    private static void unifyInto(Type t1, Type t2, Map<String, Type> subst) throws UnificationError {
        Type left = applySubstitution(t1, subst);
        Type right = applySubstitution(t2, subst);
        if (left instanceof VariableType) {
            bindVariable((VariableType) left, right, subst);
        } else if (right instanceof VariableType) {
            bindVariable((VariableType) right, left, subst);
        } else if (left instanceof ArrowType && right instanceof ArrowType) {
            ArrowType a1 = (ArrowType) left;
            ArrowType a2 = (ArrowType) right;
            unifyInto(a1.arrowLeft, a2.arrowLeft, subst);
            unifyInto(a1.arrowRight, a2.arrowRight, subst);
        } else {
            throw new UnificationError();
        }
    }

    /** Binds unbound variable {@code var} to resolved type {@code t}, with occurs check. */
    private static void bindVariable(VariableType var, Type t, Map<String, Type> subst) throws UnificationError {
        if (t instanceof VariableType && ((VariableType) t).varName.equals(var.varName)) {
            return; // a ~ a: nothing to do
        }
        if (occursIn(var.varName, t)) {
            throw new UnificationError(); // would create an infinite type
        }
        subst.put(var.varName, t);
    }

    /** Returns true if type variable {@code name} appears anywhere in {@code t}. */
    private static boolean occursIn(String name, Type t) {
        if (t instanceof VariableType) {
            return ((VariableType) t).varName.equals(name);
        }
        if (t instanceof ArrowType) {
            ArrowType a = (ArrowType) t;
            return occursIn(name, a.arrowLeft) || occursIn(name, a.arrowRight);
        }
        return false;
    }

    /** Returns {@code t} with all bound variables in {@code subst} replaced, transitively. */
    private static Type applySubstitution(Type t, Map<String, Type> subst) {
        if (t instanceof VariableType) {
            Type bound = subst.get(((VariableType) t).varName);
            return bound == null ? t : applySubstitution(bound, subst);
        }
        if (t instanceof ArrowType) {
            ArrowType a = (ArrowType) t;
            return new ArrowType(applySubstitution(a.arrowLeft, subst), applySubstitution(a.arrowRight, subst));
        }
        return t;
    }

    /** Prints {@code expr :: type}, or {@code expr :: Type inference error} on failure. */
    static void printTypedExpression(Expression expr) {
        Type t = null;
        try {
            t = getType(expr);
        } catch (UnificationError ignored) {
            // t stays null and is reported below
        } catch (InferenceError ignored) {
            // t stays null and is reported below
        }
        System.out.print(expr.toString() + " :: ");
        if (t == null) {
            System.out.println("Type inference error");
        } else {
            System.out.println(t.toString());
        }
    }

    // ---------------------------------------------------------------- Parsing

    /** Thrown on malformed input to {@link ExpressionParser}. */
    static class ParseError extends Exception {
        private static final long serialVersionUID = 4127019321896484621L;
    }

    /**
     * Recursive-descent parser for the grammar
     * {@code e ::= v | \v.e | (e e)} where {@code v} is a single lowercase letter and the
     * application separator is exactly one space.
     */
    static class ExpressionParser {
        char[] str;
        int pos;

        public ExpressionParser(String s) {
            str = s.toCharArray();
            pos = 0;
        }

        /** Variables are single characters in 'a'..'z'. */
        public static boolean isAllowedChar(char c) {
            return ('a' <= c && c <= 'z');
        }

        /** Parses one expression starting at {@code pos}, advancing past it. */
        public Expression parseExpression() throws ParseError {
            if (pos >= str.length) {
                throw new ParseError();
            }
            if (isAllowedChar(str[pos])) {
                return new VariableExpression(Character.toString(str[pos++]));
            } else if (str[pos] == '\\') {
                pos++;
                Expression arg = parseExpression();
                if (!(arg instanceof VariableExpression) || pos >= str.length || str[pos] != '.') {
                    throw new ParseError();
                }
                pos++;
                return new AbstractionExpression(((VariableExpression) arg).varName, parseExpression());
            } else if (str[pos] == '(') {
                pos++;
                Expression func = parseExpression();
                if (pos >= str.length || str[pos] != ' ') {
                    throw new ParseError();
                }
                pos++;
                Expression arg = parseExpression();
                if (pos >= str.length || str[pos] != ')') {
                    throw new ParseError();
                }
                pos++;
                return new ApplicationExpression(func, arg);
            } else {
                throw new ParseError();
            }
        }
    }

    /**
     * Parses {@code str} and prints its type. Malformed input, including trailing characters
     * after a complete expression, is reported on stderr instead of being silently dropped.
     */
    static void printParsedTypedExpression(String str) {
        try {
            ExpressionParser parser = new ExpressionParser(str);
            Expression expr = parser.parseExpression();
            if (parser.pos != parser.str.length) {
                throw new ParseError();
            }
            printTypedExpression(expr);
        } catch (ParseError e) {
            System.err.println("Parse error: " + str);
        }
    }

    public static void main(String[] args) {
        Expression expr = new AbstractionExpression("x",
                new AbstractionExpression("y", new ApplicationExpression(
                        new VariableExpression("x"),
                        new VariableExpression("y"))));
        Expression expr2 = new AbstractionExpression("x",
                new AbstractionExpression("y",
                        new VariableExpression("x")));
        Expression expr3 = new AbstractionExpression("x",
                new AbstractionExpression("y",
                        new AbstractionExpression("z",
                                new ApplicationExpression(
                                        new ApplicationExpression(new VariableExpression("x"), new VariableExpression("z")),
                                        new ApplicationExpression(new VariableExpression("y"), new VariableExpression("z"))))));
        printTypedExpression(expr);
        printTypedExpression(expr2);
        printTypedExpression(expr3);
        printParsedTypedExpression("\\x.\\y.\\z.((y x) z)");
        printParsedTypedExpression("\\x.\\x.(x x)");
        printParsedTypedExpression("\\y.(\\x.(y x) y)");
        printParsedTypedExpression("\\x.\\y.((x y) y)");
        printParsedTypedExpression("\\x.\\x.x");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment