Skip to content

Instantly share code, notes, and snippets.

@zehnpaard
Created October 31, 2022 03:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zehnpaard/865e8736b8793a47a59e2996c7a5bc91 to your computer and use it in GitHub Desktop.
Save zehnpaard/865e8736b8793a47a59e2996c7a5bc91 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference with Unit and Bool types
(* Types of the object language.
   [TVar] wraps a mutable cell so that unification can bind a variable in
   place; [TArrow (param, ret)] is the type of a function from [param] to
   [ret]. *)
type ty =
| TVar of tvar ref
| TArrow of ty * ty
| TUnit
| TBool
(* Contents of a type-variable cell. *)
and tvar =
(* (id, level): id is drawn from the counter in new_tvar; level is the
   let-nesting level the variable belongs to (adjustlevel may lower it). *)
| Unbound of int * int
(* The variable has been bound by unification to this type. *)
| Link of ty
(* A quantified variable created by generalize; instantiate replaces every
   occurrence of the same id with one fresh Unbound variable. *)
| Generic of int
(* Expressions of the object language, as consumed by typeof. *)
type exp =
| EVar of string                 (* variable reference *)
| EAbs of string * exp           (* (parameter name, body) *)
| EApp of exp * exp              (* (function, argument) *)
| ELet of string * exp * exp     (* (name, bound expression, body); the binding is generalized *)
| EUnit
| ETrue
| EFalse
| EIf of exp * exp * exp         (* (condition, then branch, else branch) *)
(* [new_tvar level] returns a fresh unbound type variable at [level].
   Ids come from a private counter shared by every call, so each variable
   created during a run gets a distinct id (the first one is 1). *)
let new_tvar =
  let counter = ref 0 in
  fun level ->
    incr counter;
    TVar (ref (Unbound (!counter, level)))
(* [occursin id ty] is [true] when the unbound variable with identifier
   [id] appears somewhere in [ty]. Link indirections are followed;
   Generic variables never count as an occurrence. *)
let rec occursin id ty =
  match ty with
  | TUnit | TBool -> false
  | TArrow (dom, cod) -> occursin id dom || occursin id cod
  | TVar cell ->
    (match !cell with
     | Unbound (other, _) -> other = id
     | Link target -> occursin id target
     | Generic _ -> false)
(* [adjustlevel level ty] lowers to [level] the level of every unbound
   variable in [ty] that currently sits at a deeper level, following links.
   unify calls this before binding a variable at [level] to [ty]. As a
   result, generalize will not quantify those variables at a let that
   encloses the variable being bound. *)
let rec adjustlevel level ty =
  match ty with
  | TUnit | TBool -> ()
  | TArrow (dom, cod) ->
    adjustlevel level dom;
    adjustlevel level cod
  | TVar cell ->
    (match !cell with
     | Unbound (id, current) ->
       if current > level then cell := Unbound (id, level)
     | Link target -> adjustlevel level target
     | Generic _ -> ())
(** [unify t1 t2] makes [t1] and [t2] equal by destructively binding
    unbound type variables (their cell is overwritten with [Link]).

    Before a variable is bound, the occurs check rejects cyclic types.
    [adjustlevel] then lowers the variables in the target type to the bound
    variable's level.

    @raise Failure ["Unification failed due to occurs check"] if binding a
    variable would produce a cyclic type.
    @raise Failure ["Cannot unify types"] on a constructor mismatch
    (e.g. [TUnit] against [TBool] or against an arrow). *)
let rec unify t1 t2 = match t1, t2 with
  (* Identity shortcut. This used to be polymorphic [t1 = t2], which walks
     both types in full at every recursive step and can make unification
     quadratic in the size of the types. Types that are structurally equal
     but physically distinct are still unified correctly by the cases
     below. *)
  | _, _ when t1 == t2 -> ()
  (* The same variable cell seen through two different TVar wrappers. *)
  | TVar r1, TVar r2 when r1 == r2 -> ()
  | TUnit, TUnit | TBool, TBool -> ()
  | TArrow(tparam1, tret1), TArrow(tparam2, tret2) ->
    unify tparam1 tparam2; unify tret1 tret2
  (* Resolve bound variables before comparing further. *)
  | TVar{contents=Link t1}, t2 | t1, TVar{contents=Link t2} -> unify t1 t2
  | TVar({contents=Unbound(id,level)} as tvar), ty
  | ty, TVar({contents=Unbound(id,level)} as tvar) ->
    if occursin id ty then failwith "Unification failed due to occurs check";
    adjustlevel level ty;
    tvar := Link ty
  | _ -> failwith "Cannot unify types"
(* [match_fun_ty tfunc targ] checks that a value of type [tfunc] can be
   applied to an argument of type [targ] and returns the result type.
   If [tfunc] is an unbound variable, it is bound to a fresh arrow whose
   parameter and result variables share its level; the parameter is then
   unified with [targ].
   Raises Failure if [tfunc] is a non-arrow base type or a Generic variable,
   and propagates any Failure raised by unify. *)
let rec match_fun_ty tfunc targ =
  match tfunc with
  | TUnit | TBool -> failwith "Non-arrow type found in function position"
  | TArrow (tparam, tret) ->
    unify tparam targ;
    tret
  | TVar cell ->
    (match !cell with
     | Link target -> match_fun_ty target targ
     | Generic _ ->
       failwith "Generic type encountered, expecting arrow or instantiated variable"
     | Unbound (_, lvl) ->
       (* Parameter variable is created before the result variable. *)
       let tparam = new_tvar lvl in
       let tret = new_tvar lvl in
       cell := Link (TArrow (tparam, tret));
       unify tparam targ;
       tret)
(* [generalize level ty] returns a copy of [ty] in which every unbound
   variable whose level is greater than [level] is replaced by a new
   Generic cell carrying the same id. Links are followed rather than
   copied. Unbound variables at or below [level], Generic cells, and base
   types are returned as-is. The input is not mutated. *)
let rec generalize level ty =
  match ty with
  | TUnit | TBool -> ty
  | TArrow (dom, cod) -> TArrow (generalize level dom, generalize level cod)
  | TVar cell ->
    (match !cell with
     | Link target -> generalize level target
     | Unbound (id, var_level) ->
       if var_level > level then TVar (ref (Generic id)) else ty
     | Generic _ -> ty)
(* [instantiate level ty] returns a copy of [ty] in which each Generic
   variable is replaced by a fresh unbound variable at [level]. All
   occurrences of the same Generic id within this call share one fresh
   variable. Unbound variables are kept (the same cell), and links are
   followed. *)
let instantiate level ty =
  let fresh_for = Hashtbl.create 10 in
  let rec inst t =
    match t with
    | TUnit | TBool -> t
    | TArrow (dom, cod) -> TArrow (inst dom, inst cod)
    | TVar cell ->
      (match !cell with
       | Link target -> inst target
       | Unbound _ -> t
       | Generic id ->
         (match Hashtbl.find fresh_for id with
          | existing -> existing
          | exception Not_found ->
            let v = new_tvar level in
            Hashtbl.add fresh_for id v;
            v))
  in
  inst ty
(* [typeof env level e] infers the type of [e].
   [env] is an association list from names to (possibly generalized)
   types, and [level] is the current let-nesting level. The right-hand
   side of a let is inferred at [level + 1] and then generalized at
   [level].
   Raises Not_found (from List.assoc) for an unbound variable, and
   propagates Failure from unify / match_fun_ty on a type error. *)
let rec typeof env level e =
  match e with
  | EUnit -> TUnit
  | ETrue | EFalse -> TBool
  | EVar name ->
    let scheme = List.assoc name env in
    instantiate level scheme
  | EAbs (param, body) ->
    let tparam = new_tvar level in
    let tbody = typeof ((param, tparam) :: env) level body in
    TArrow (tparam, tbody)
  | EApp (fn, arg) ->
    let tfn = typeof env level fn in
    let targ = typeof env level arg in
    match_fun_ty tfn targ
  | ELet (name, bound, body) ->
    let tbound = typeof env (level + 1) bound in
    let scheme = generalize level tbound in
    typeof ((name, scheme) :: env) level body
  | EIf (cond, then_branch, else_branch) ->
    let tcond = typeof env level cond in
    unify tcond TBool;
    let tthen = typeof env level then_branch in
    let telse = typeof env level else_branch in
    unify tthen telse;
    tthen
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment