Last active
November 3, 2022 08:04
-
-
Save zehnpaard/85f75056937be8b577f3c6e187000aa0 to your computer and use it in GitHub Desktop.
Hindley Milner Type Inference with Unit, Bool, Int, Tuple
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** Types of the object language.

    [ty] and [tvar] are mutually recursive: a type variable is a mutable
    cell ([tvar ref]) whose content records whether the variable is still
    unresolved, has been resolved to another type, or has been generalized. *)
type ty =
  | TVar of tvar ref  (** type variable, held in a mutable cell *)
  | TArrow of ty * ty  (** function type: parameter type, result type *)
  | TUnit
  | TBool
  | TInt
  | TTuple of ty list  (** tuple type, one entry per component *)

(** Content of a type-variable cell. *)
and tvar =
  | Unbound of int * int  (** unresolved variable: id, level *)
  | Link of ty  (** variable that has been resolved to the given type *)
  | Generic of int  (** generalized variable, carrying its id *)
(** Abstract syntax of expressions handled by the type checker. *)
type exp =
  | EVar of string  (** variable reference by name *)
  | EAbs of string * exp  (** abstraction: parameter name, body *)
  | EApp of exp * exp  (** application: function, argument *)
  | ELet of string * exp * exp  (** let-binding: name, bound expression, body *)
  | EUnit
  | ETrue
  | EFalse
  | EIf of exp * exp * exp  (** conditional: condition, then-branch, else-branch *)
  | EInt of int  (** integer literal *)
  | EAdd of exp * exp  (** addition of two expressions *)
  | EIsZero of exp  (** zero test on an expression *)
  | ETuple of exp list  (** tuple construction, one expression per component *)
  | ETupleAccess of exp * int * int
      (** tuple projection: tuple expression, component index, expected arity *)
(** [new_tvar level] allocates a fresh unbound type variable at [level].

    Ids come from a counter that is private to this closure. The counter is
    incremented before each allocation, so the first variable handed out has
    id 1 and each later call gets the next integer. *)
let new_tvar =
  let counter = ref 0 in
  fun level ->
    incr counter;
    TVar (ref (Unbound (!counter, level)))
(** [occursin id ty] is [true] when an unbound type variable with identifier
    [id] appears somewhere inside [ty].

    Links are followed to the type they point at; generic variables and the
    base types never count as an occurrence. *)
let rec occursin id ty =
  match ty with
  | TUnit | TBool | TInt -> false
  | TArrow (dom, cod) -> occursin id dom || occursin id cod
  | TTuple elems -> List.exists (occursin id) elems
  | TVar cell ->
    (match !cell with
     | Unbound (other, _) -> id = other
     | Link target -> occursin id target
     | Generic _ -> false)
(** [adjustlevel level ty] walks [ty] and lowers the level of every unbound
    type variable whose current level is strictly greater than [level],
    updating the variable's cell in place.

    Links are followed; generic variables and base types are left alone. *)
let rec adjustlevel level ty =
  match ty with
  | TUnit | TBool | TInt -> ()
  | TArrow (dom, cod) ->
    adjustlevel level dom;
    adjustlevel level cod
  | TTuple elems -> List.iter (adjustlevel level) elems
  | TVar cell ->
    (match !cell with
     | Unbound (id, current) ->
       (* Only ever lower a level, never raise it. *)
       if level < current then cell := Unbound (id, level)
     | Link target -> adjustlevel level target
     | Generic _ -> ())
(** [unify t1 t2] makes [t1] and [t2] equal by destructively linking unbound
    type variables.

    - Structurally equal types unify trivially.
    - Arrows unify parameter with parameter and result with result.
    - A linked variable is replaced by the type it links to.
    - An unbound variable is linked to the other type, after an occurs check
      and after lowering the levels inside that type to the variable's level.
    - Tuples unify component-wise and must have the same number of components.

    @raise Failure if the occurs check fails, or if the two types have
    incompatible shapes (including tuples of different arity). *)
let rec unify t1 t2 = match t1, t2 with
  | _, _ when t1 = t2 -> ()
  | TArrow(tparam1, tret1), TArrow(tparam2, tret2) ->
    unify tparam1 tparam2; unify tret1 tret2
  | TVar{contents=Link t1}, t2 | t1, TVar{contents=Link t2} -> unify t1 t2
  | TVar({contents=Unbound(id,level)} as tvar), ty | ty, TVar({contents=Unbound(id,level)} as tvar) ->
    if occursin id ty then failwith "Unification failed due to occurs check";
    adjustlevel level ty;
    tvar := Link ty
  | TTuple ts1, TTuple ts2 ->
    (* Check arity up front: List.iter2 would otherwise unify (and mutate)
       the common prefix and then raise Invalid_argument instead of the
       Failure used for every other mismatch. *)
    if List.compare_lengths ts1 ts2 <> 0 then failwith "Cannot unify types";
    List.iter2 unify ts1 ts2
  | _ -> failwith "Cannot unify types"
(** [match_fun_ty tfunc targ] treats [tfunc] as the type of a function applied
    to an argument of type [targ] and returns the result type.

    - For an arrow, the parameter type is unified with [targ].
    - A linked variable is followed.
    - An unbound variable is linked to a fresh arrow whose parameter and
      result variables are created at the variable's level; the fresh
      parameter is then unified with [targ].

    @raise Failure if [tfunc] is a generic variable or a non-arrow type;
    failures raised by [unify] propagate. *)
let rec match_fun_ty tfunc targ =
  match tfunc with
  | TUnit | TBool | TInt | TTuple _ ->
    failwith "Non-arrow type found in function position"
  | TArrow (dom, cod) ->
    unify dom targ;
    cod
  | TVar cell ->
    (match !cell with
     | Link resolved -> match_fun_ty resolved targ
     | Generic _ ->
       failwith "Generic type encountered, expecting arrow or instantiated variable"
     | Unbound (_, level) ->
       let dom = new_tvar level in
       let cod = new_tvar level in
       cell := Link (TArrow (dom, cod));
       unify dom targ;
       cod)
(** [match_tuple_ty ttup i n] treats [ttup] as the type of a tuple with [n]
    components and returns the type of component [i] (zero-based).

    - For a tuple type, the arity must equal [n].
    - A linked variable is followed.
    - An unbound variable is linked to a tuple of [n] fresh variables created
      at the variable's level.

    The index is validated before any type variable is mutated, so an
    out-of-range access leaves the type untouched.

    @raise Failure if the arity differs from [n], if [i] is not in
    [0 .. n-1], or if [ttup] is a generic variable or a non-tuple type. *)
let rec match_tuple_ty ttup i n = match ttup with
  | TTuple ts ->
    (* Structural [<>], not physical [!=], is the right comparison here. *)
    if List.length ts <> n then failwith "Incorrect tuple arity";
    if i < 0 || i >= n then failwith "Tuple index out of bounds";
    List.nth ts i
  | TVar {contents=Link ty} -> match_tuple_ty ty i n
  | TVar ({contents=Unbound(_,level)} as tvar) ->
    (* Reject bad indices (and non-positive arities) before linking. *)
    if i < 0 || i >= n then failwith "Tuple index out of bounds";
    let ts = List.init n (fun _ -> new_tvar level) in
    tvar := Link(TTuple ts);
    List.nth ts i
  | TVar {contents=Generic _} -> failwith "Generic type encountered, expecting tuple or instantiated variable"
  | TUnit | TBool | TInt | TArrow _ -> failwith "Non-tuple type found in tuple position"
(** [generalize level ty] returns a version of [ty] in which every unbound
    type variable whose level is strictly greater than [level] is replaced by
    a generic variable with the same id.

    Links are followed, so the result contains the linked-to types directly.
    Unbound variables at or below [level], generic variables and base types
    are returned unchanged. *)
let rec generalize level ty =
  match ty with
  | TUnit | TBool | TInt -> ty
  | TArrow (dom, cod) -> TArrow (generalize level dom, generalize level cod)
  | TTuple elems -> TTuple (List.map (generalize level) elems)
  | TVar cell ->
    (match !cell with
     | Unbound (id, var_level) ->
       if level < var_level then TVar (ref (Generic id)) else ty
     | Link target -> generalize level target
     | Generic _ -> ty)
(** [instantiate level ty] returns a copy of [ty] in which each generic
    variable is replaced by a fresh unbound variable created at [level].

    Within one call, all occurrences of the same generic id map to the same
    fresh variable. Links are followed; unbound variables and base types are
    returned as they are. *)
let instantiate level ty =
  (* Generic id -> the fresh variable chosen for it during this call. *)
  let fresh_by_id = Hashtbl.create 10 in
  let fresh_for id =
    match Hashtbl.find_opt fresh_by_id id with
    | Some var -> var
    | None ->
      let var = new_tvar level in
      Hashtbl.add fresh_by_id id var;
      var
  in
  let rec copy ty =
    match ty with
    | TUnit | TBool | TInt -> ty
    | TArrow (dom, cod) -> TArrow (copy dom, copy cod)
    | TTuple elems -> TTuple (List.map copy elems)
    | TVar cell ->
      (match !cell with
       | Generic id -> fresh_for id
       | Unbound _ -> ty
       | Link target -> copy target)
  in
  copy ty
(** [typeof env level expr] infers the type of [expr].

    [env] is an association list from variable names to their types; [level]
    is the current let-nesting level handed to fresh type variables.

    - A variable's type is looked up in [env] and instantiated at [level].
    - A let-bound expression is typed at [level + 1] and generalized at
      [level] before the body is typed.
    - Conditions must unify with [TBool], both branches of an [EIf] must
      unify with each other, and operands of [EAdd] / [EIsZero] must unify
      with [TInt].

    @raise Not_found if a variable is not present in [env].
    Failures raised by [unify], [match_fun_ty] and [match_tuple_ty]
    propagate to the caller. *)
let rec typeof env level expr =
  match expr with
  | EUnit -> TUnit
  | ETrue | EFalse -> TBool
  | EInt _ -> TInt
  | EVar name -> instantiate level (List.assoc name env)
  | EAbs (param, body) ->
    let tparam = new_tvar level in
    let tbody = typeof ((param, tparam) :: env) level body in
    TArrow (tparam, tbody)
  | EApp (fn, arg) ->
    let tfn = typeof env level fn in
    let targ = typeof env level arg in
    match_fun_ty tfn targ
  | ELet (name, bound, body) ->
    (* Type the bound expression one level deeper so that variables created
       for it can be told apart when generalizing at the current level. *)
    let tbound = typeof env (level + 1) bound in
    let tgeneral = generalize level tbound in
    typeof ((name, tgeneral) :: env) level body
  | EIf (cond, e_then, e_else) ->
    let tcond = typeof env level cond in
    unify tcond TBool;
    let tthen = typeof env level e_then in
    let telse = typeof env level e_else in
    unify tthen telse;
    tthen
  | EAdd (lhs, rhs) ->
    let tlhs = typeof env level lhs in
    unify tlhs TInt;
    let trhs = typeof env level rhs in
    unify trhs TInt;
    TInt
  | EIsZero operand ->
    let toperand = typeof env level operand in
    unify toperand TInt;
    TBool
  | ETuple es -> TTuple (List.map (typeof env level) es)
  | ETupleAccess (tuple, i, n) ->
    let ttuple = typeof env level tuple in
    match_tuple_ty ttuple i n
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment