Last active
November 13, 2021 10:38
-
-
Save melwyn95/6484684ace51832f3b0d6ac688670ea5 to your computer and use it in GitHub Desktop.
Algorithm W Type Inference (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.65.7733&rep=rep1&type=pdf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Implementation of the Algorithm W tutorial in OCaml.
   Reference: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.65.7733&rep=rep1&type=pdf
*)
(* Small utilities: maps and sets keyed by strings (type-variable and
   program-variable names). *)
module SMap = Map.Make (String)
module SSet = Set.Make (String)
(* [zip xs ys] pairs up corresponding elements of [xs] and [ys], in order.
   Raises [Failure] when the lists have different lengths. *)
let zip xs ys =
  let rec go acc = function
    | [], [] -> List.rev acc
    | x :: xs', y :: ys' -> go ((x, y) :: acc) (xs', ys')
    | _ -> failwith "zip: list lengths not equal"
  in
  go [] (xs, ys)
(* Build an [SMap] from an association list. Bindings are added from the
   end of the list backwards, so when a key occurs more than once its
   first occurrence wins. *)
let map_of_list xs =
  List.fold_right (fun (key, value) acc -> SMap.add key value acc) xs SMap.empty
(* Literal constants of the object language. *)
type lit =
  | LInt of int
  | LBool of bool

(* Expressions of the object language: lambda calculus with [let].
   Constructor arguments are single tuples (note the parentheses). *)
type exp =
  | EVar of string               (* variable reference *)
  | ELit of lit                  (* literal constant *)
  | EApp of (exp * exp)          (* application: function, argument *)
  | EAbs of (string * exp)       (* lambda: parameter name, body *)
  | ELet of (string * exp * exp) (* let name = bound expression in body *)
(* Core modules: monotypes, substitutions, type schemes and typing
   environments. They are mutually recursive because [Type.apply] takes a
   [Subst.t] while [Subst.t] is a map to [Type.t]. *)
module rec Type : sig
  (* Monotypes. *)
  type t =
    | TVar of string
    | TInt
    | TBool
    | TFun of (t * t)

  (* Free type variables of a type. *)
  val ftv : t -> SSet.t

  (* Apply a substitution to a type in a single pass; variables the
     substitution does not mention are left untouched. *)
  val apply : Subst.t -> t -> t
end = struct
  type t =
    | TVar of string
    | TInt
    | TBool
    | TFun of (t * t)

  let rec ftv ty =
    match ty with
    | TInt | TBool -> SSet.empty
    | TVar name -> SSet.singleton name
    | TFun (arg, res) -> SSet.union (ftv arg) (ftv res)

  let rec apply subst ty =
    match ty with
    | TInt | TBool -> ty
    | TVar name -> (try SMap.find name subst with Not_found -> ty)
    | TFun (arg, res) -> TFun (apply subst arg, apply subst res)
end

and Subst : sig
  type t = Type.t SMap.t

  (* The empty substitution. *)
  val null : t

  (* [compose s1 s2] applies [s1] to every type in the range of [s2] and then
     adds the bindings of [s1] for variables that [s2] does not bind. *)
  val compose : t -> t -> t
end = struct
  type t = Type.t SMap.t

  let null = SMap.empty

  let compose s1 s2 =
    let s2_updated = SMap.map (Type.apply s1) s2 in
    SMap.union (fun _name from_s2 _from_s1 -> Some from_s2) s2_updated s1
end

and Scheme : sig
  (* A type scheme: the quantified variable names and the body type. *)
  type t = (string list * Type.t)

  (* Free variables of the body that are not quantified. *)
  val ftv : t -> SSet.t

  (* Apply a substitution to the body, skipping quantified variables. *)
  val apply : Subst.t -> t -> t
end = struct
  type t = (string list * Type.t)

  let ftv (bound, body) =
    let bound_set = SSet.of_list bound in
    SSet.diff (Type.ftv body) bound_set

  let apply subst (bound, body) =
    let visible = List.fold_left (fun acc name -> SMap.remove name acc) subst bound in
    (bound, Type.apply visible body)
end

and TypeEnv : sig
  (* Typing environment: program variable name -> type scheme. *)
  type t = Scheme.t SMap.t

  val remove : t -> string -> t

  (* Union of the free type variables of every scheme in the environment. *)
  val ftv : t -> SSet.t

  val apply : Subst.t -> t -> t
  val empty : unit -> t
end = struct
  type t = Scheme.t SMap.t

  let remove env name = SMap.remove name env

  let ftv env =
    SMap.fold (fun _name scheme acc -> SSet.union acc (Scheme.ftv scheme)) env SSet.empty

  let apply subst env = SMap.map (fun scheme -> Scheme.apply subst scheme) env

  let empty () = SMap.empty
end
(* Pretty-printers for types, literals and expressions. *)
module PP = struct
  open Type

  (* Render a type. [->] associates to the right, so a function-typed
     argument must be parenthesised: "(a -> b) -> c" and "a -> b -> c" are
     different types. *)
  let rec pp_typ t =
    match t with
    | TVar v -> v
    | TInt -> "int"
    | TBool -> "bool"
    | TFun ((TFun _ as binder), body) ->
      "(" ^ pp_typ binder ^ ") -> " ^ pp_typ body
    | TFun (binder, body) -> pp_typ binder ^ " -> " ^ pp_typ body

  (* Render a literal as OCaml source would spell it. *)
  let pp_lit l =
    match l with
    | LInt i -> string_of_int i
    | LBool b -> string_of_bool b

  (* Render an expression in an OCaml-like concrete syntax; applications
     are printed with [@@] and fully parenthesised. *)
  let rec pp_exp e =
    match e with
    | EVar v -> v
    | ELit l -> pp_lit l
    | EApp (f, e) -> "(" ^ pp_exp f ^ " @@ " ^ pp_exp e ^ ")"
    | EAbs (p, e) -> "(fun " ^ p ^ " -> " ^ pp_exp e ^ ")"
    | ELet (b, r, l) -> "let " ^ b ^ " = " ^ pp_exp r ^ " in \n\t" ^ pp_exp l
end
(* Turn a monotype into a type scheme by quantifying over the type variables
   that are free in [typ] but not free anywhere in [env]. Quantified names
   come out in ascending order ([SSet.elements]). *)
let generalize env typ =
  let env_vars = TypeEnv.ftv env in
  let free_in_typ = Type.ftv typ in
  let quantified = SSet.elements (SSet.diff free_in_typ env_vars) in
  (quantified, typ)
(* Global counter that gives every fresh type variable a unique suffix. *)
let ctr = ref 0

(* [new_ty_var prefix] bumps [ctr] and returns a type variable named
   [prefix] followed by the new counter value, e.g. "a1", "a2", ... *)
let new_ty_var s =
  ctr := !ctr + 1;
  Type.TVar (Printf.sprintf "%s%d" s !ctr)
(* Instantiate a type scheme: replace each quantified variable by a fresh
   type variable (named after the original, with a new numeric suffix) and
   return the resulting monotype. *)
let instantiate (vars, typ) =
  let fresh = List.map new_ty_var vars in
  let renaming = map_of_list (zip vars fresh) in
  Type.apply renaming typ
(* Bind type variable [v] to [t] as a one-element substitution.
   - Binding a variable to itself yields the empty substitution.
   - Binding it to a type that mentions [v] raises [Failure] (occurs check). *)
let var_bind v t =
  match t with
  | Type.TVar u when String.equal u v -> Subst.null
  | _ ->
    if not (SSet.mem v (Type.ftv t)) then SMap.singleton v t
    else failwith ("occur check fails: " ^ v ^ " vs. " ^ PP.pp_typ t)
(* [mgu t1 t2] computes a most general unifier of [t1] and [t2]: an
   idempotent substitution [s] such that
   [Type.apply s t1 = Type.apply s t2].
   Raises [Failure] if the types do not unify or the occurs check fails. *)
let rec mgu t1 t2 =
  match t1, t2 with
  | Type.TFun (p1, b1), Type.TFun (p2, b2) ->
    let s1 = mgu p1 p2 in
    let s2 = mgu (Type.apply s1 b1) (Type.apply s1 b2) in
    (* s2 was computed on s1-substituted types, so the combined unifier is
       "s1, then s2": s2 must also be applied to the range of s1. Composing
       the other way round leaves variables bound by s2 unresolved inside
       s1's range; e.g. a -> b ~ b -> int would yield {a := b; b := int}. *)
    Subst.compose s2 s1
  | Type.TInt, Type.TInt | Type.TBool, Type.TBool -> Subst.null
  | Type.TVar v, t | t, Type.TVar v -> var_bind v t
  | _ ->
    failwith ("types do not unify: " ^ PP.pp_typ t1 ^ " vs. " ^ PP.pp_typ t2)
(* Literals have a fixed type and need no substitution. *)
let infer_lit lit =
  let ty =
    match lit with
    | LInt _ -> Type.TInt
    | LBool _ -> Type.TBool
  in
  (Subst.null, ty)
(* Algorithm W. [infer env expr] returns a pair [(s, t)]:
   - [s] is the substitution accumulated while typing [expr];
   - [t] is the inferred type of [expr] under the environment [env].
   Raises [Failure] on an unbound variable, or when [mgu] fails
   (unification error or occurs check). *)
let rec infer env expr =
  match expr with
  | EVar v ->
    (match SMap.find_opt v env with
     | Some s -> (Subst.null, instantiate s)
     | None -> failwith ("unbound variable: " ^ v))
  | ELit l -> infer_lit l
  | EApp (f, args) ->
    (* The result variable is allocated before inferring the operands, which
       fixes the numbering of fresh variables in the printed types. *)
    let tv = new_ty_var "a" in
    let s1, t1 = infer env f in
    let s2, t2 = infer (TypeEnv.apply s1 env) args in
    let s3 = mgu (Type.apply s2 t1) (Type.TFun (t2, tv)) in
    (Subst.compose s3 (Subst.compose s2 s1), Type.apply s3 tv)
  | EAbs (p, b) ->
    let tv = new_ty_var "a" in
    (* [SMap.add] replaces any outer binding of [p], i.e. the parameter
       shadows it inside the body. *)
    let env' = SMap.add p ([], tv) env in
    let s1, t1 = infer env' b in
    (s1, Type.TFun (Type.apply s1 tv, t1))
  | ELet (x, e, b) ->
    let s1, t1 = infer env e in
    let t' = generalize (TypeEnv.apply s1 env) t1 in
    let env' = SMap.add x t' env in
    let s2, t2 = infer (TypeEnv.apply s1 env') b in
    (* The body was typed under s1, so its substitution s2 comes after s1 and
       must also be applied to s1's range. The reverse order loses the body's
       refinements; e.g. [fun f -> let y = f 1 in y true] would get
       (int -> a) -> b instead of (int -> bool -> b) -> b. *)
    (Subst.compose s2 s1, t2)
(* Infer the type of [expr] under [env] and return it with the final
   substitution applied. Raises [Failure] on an unbound variable, a
   unification error or an occurs-check failure. *)
let type_inference env expr =
  infer env expr |> fun (subst, ty) -> Type.apply subst ty
(* Sample programs plus a helper that prints each program and its type. *)
module Test = struct
  (* let id = fun x -> x in id            -- polymorphic identity: t -> t *)
  let e0 = ELet ("id", EAbs ("x", EVar "x"), EVar "id")

  (* let id = fun x -> x in id id         -- needs let-polymorphism: t -> t *)
  let e1 = ELet ("id", EAbs ("x", EVar "x"), EApp (EVar "id", EVar "id"))

  (* let id = fun x -> let y = x in y in id id   -- nested let: t -> t *)
  let e2 =
    ELet
      ( "id",
        EAbs ("x", ELet ("y", EVar "x", EVar "y")),
        EApp (EVar "id", EVar "id") )

  (* let id = fun x -> let y = x in y in (id id) 2   -- int *)
  let e3 =
    ELet
      ( "id",
        EAbs ("x", ELet ("y", EVar "x", EVar "y")),
        EApp (EApp (EVar "id", EVar "id"), ELit (LInt 2)) )

  (* let id = fun x -> x x in id          -- fails the occurs check *)
  let e4 = ELet ("id", EAbs ("x", EApp (EVar "x", EVar "x")), EVar "id")

  (* fun m -> let y = m in let x = y true in x   -- (bool -> t) -> t *)
  let e5 =
    EAbs
      ( "m",
        ELet
          ( "y",
            EVar "m",
            ELet ("x", EApp (EVar "y", ELit (LBool true)), EVar "x") ) )

  (* Print [exp], then its inferred type (in the empty environment), then a
     blank line. Any [Failure] from inference propagates to the caller after
     the expression has been printed. *)
  let run_test exp =
    print_endline (PP.pp_exp exp);
    let t = type_inference (TypeEnv.empty ()) exp in
    print_endline (PP.pp_typ t);
    print_newline ()
end;;
(* Run every sample. [Test.e4] fails the occurs check; its [Failure] is
   reported and the remaining samples still run, instead of the program
   aborting with an uncaught exception. The final inline sample lets the
   body of a [let] refine a type variable introduced by its bound
   expression; its principal type is (int -> bool -> t) -> t. *)
let () =
  let run exp =
    match Test.run_test exp with
    | () -> ()
    | exception Failure msg ->
      print_endline ("type error: " ^ msg);
      print_newline ()
  in
  List.iter run
    [ Test.e0;
      Test.e1;
      Test.e2;
      Test.e3;
      Test.e5;
      Test.e4;
      (* fun f -> let y = f 1 in y true *)
      EAbs
        ( "f",
          ELet
            ( "y",
              EApp (EVar "f", ELit (LInt 1)),
              EApp (EVar "y", ELit (LBool true)) ) ) ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment