Skip to content

Instantly share code, notes, and snippets.

@melwyn95
Last active November 13, 2021 10:38
Show Gist options
  • Save melwyn95/6484684ace51832f3b0d6ac688670ea5 to your computer and use it in GitHub Desktop.
Save melwyn95/6484684ace51832f3b0d6ac688670ea5 to your computer and use it in GitHub Desktop.
(* Implementation for Algorithm-W tutorial in OCaml
Reference: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.65.7733&rep=rep1&type=pdf
*)
(* Trivial utilities: string-keyed maps and sets, plus two list helpers. *)
module SMap = Map.Make (String)
module SSet = Set.Make (String)

(* [zip xs ys] pairs up corresponding elements of [xs] and [ys].
   @raise Failure if the two lists have different lengths. *)
let zip xs ys =
  let rec go acc left right =
    match left, right with
    | [], [] -> List.rev acc
    | l :: left', r :: right' -> go ((l, r) :: acc) left' right'
    | _ -> failwith "zip: list lengths not equal"
  in
  go [] xs ys

(* [map_of_list kvs] builds an [SMap] from an association list.  The list
   is folded from the right, so when a key occurs several times its FIRST
   binding in [kvs] is the one that ends up in the map. *)
let map_of_list xs =
  List.fold_right (fun (key, value) acc -> SMap.add key value acc) xs SMap.empty
(* Literals of the object language. *)
type lit =
  | LInt of int
  | LBool of bool

(* Object-language expressions: a variable, a literal, an application
   [(f, arg)], a lambda [(param, body)], and [let (x, bound, body)].
   Constructor arguments are single tuples (note the parentheses). *)
type exp =
  | EVar of string
  | ELit of lit
  | EApp of (exp * exp)
  | EAbs of (string * exp)
  | ELet of (string * exp * exp)
(* Types, substitutions, type schemes and typing environments.  They are
   mutually recursive: substitutions map type-variable names to [Type.t],
   and applying a substitution is defined for types, schemes and
   environments. *)
module rec Type : sig
  (** Monotypes: type variables, the two base types, and function types. *)
  type t = TVar of string
         | TInt
         | TBool
         | TFun of (t * t)

  (** [ftv t] is the set of type-variable names occurring in [t]. *)
  val ftv : t -> SSet.t

  (** [apply s t] replaces every variable of [t] bound in [s] by its image.
      This is a single pass: an image is not itself rewritten again. *)
  val apply : Subst.t -> t -> t
end = struct
  type t = TVar of string
         | TInt
         | TBool
         | TFun of (t * t)

  let rec ftv ty =
    match ty with
    | TInt | TBool -> SSet.empty
    | TVar name -> SSet.singleton name
    | TFun (arg, res) -> SSet.union (ftv arg) (ftv res)

  let rec apply subst ty =
    match ty with
    | TInt | TBool -> ty
    | TVar name -> if SMap.mem name subst then SMap.find name subst else ty
    | TFun (arg, res) -> TFun (apply subst arg, apply subst res)
end

and Subst : sig
  (** A substitution: a finite map from type-variable names to types. *)
  type t = Type.t SMap.t

  (** The empty (identity) substitution. *)
  val null : t

  (** [compose s1 s2] behaves like applying [s2] first and then [s1]:
      every image of [s2] is rewritten by [s1], and bindings of [s1] for
      variables not bound by [s2] are kept. *)
  val compose : t -> t -> t
end = struct
  type t = Type.t SMap.t

  let null = SMap.empty

  let compose s1 s2 =
    let rewritten = SMap.map (Type.apply s1) s2 in
    (* on a clash the binding coming from [s2] wins *)
    SMap.union (fun _name from_s2 _from_s1 -> Some from_s2) rewritten s1
end

and Scheme : sig
  (** A type scheme: quantified variable names and the body type. *)
  type t = (string list * Type.t)

  (** Free variables of the body minus the quantified ones. *)
  val ftv : t -> SSet.t

  (** Applies a substitution to the body, skipping quantified variables. *)
  val apply : Subst.t -> t -> t
end = struct
  type t = (string list * Type.t)

  let ftv (bound, ty) =
    let quantified = SSet.of_list bound in
    SSet.diff (Type.ftv ty) quantified

  let apply subst (bound, ty) =
    (* quantified variables are shielded from the substitution *)
    let visible = List.fold_left (fun acc name -> SMap.remove name acc) subst bound in
    (bound, Type.apply visible ty)
end

and TypeEnv : sig
  (** A typing environment: term-variable names mapped to type schemes. *)
  type t = Scheme.t SMap.t

  (** [remove env x] drops the binding of [x], if any. *)
  val remove : t -> string -> t

  (** Union of the free type variables of every scheme in the environment. *)
  val ftv : t -> SSet.t

  (** Applies a substitution to every scheme of the environment. *)
  val apply : Subst.t -> t -> t

  (** A fresh empty environment. *)
  val empty : unit -> t
end = struct
  type t = Scheme.t SMap.t

  let remove env name = SMap.remove name env

  let ftv env =
    SMap.fold (fun _name scheme acc -> SSet.union acc (Scheme.ftv scheme)) env SSet.empty

  let apply subst env =
    SMap.map (fun scheme -> Scheme.apply subst scheme) env

  let empty () = SMap.empty
end
(* Pretty-printers for types, literals and expressions, used for the test
   output and inside error messages. *)
module PP = struct
  open Type

  (* [pp_typ t] renders [t].  [->] associates to the right, so a function
     type in argument position must be parenthesized:
     [TFun (TFun (a, b), c)] prints as "(a -> b) -> c", whereas
     [TFun (a, TFun (b, c))] prints as "a -> b -> c".  Without this rule
     both printed identically. *)
  let rec pp_typ t =
    match t with
    | TVar v -> v
    | TInt -> "int"
    | TBool -> "bool"
    | TFun ((TFun _ as binder), body) ->
      "(" ^ pp_typ binder ^ ") -> " ^ pp_typ body
    | TFun (binder, body) -> pp_typ binder ^ " -> " ^ pp_typ body

  (* Literals have no sub-terms, so no recursion is needed. *)
  let pp_lit l =
    match l with
    | LInt i -> string_of_int i
    | LBool b -> string_of_bool b

  (* [pp_exp e] renders an expression; applications are shown as
     "(f @@ arg)", lambdas as "(fun x -> body)", and a let starts its
     body on a new tab-indented line. *)
  let rec pp_exp e =
    match e with
    | EVar v -> v
    | ELit l -> pp_lit l
    | EApp (f, arg) -> "(" ^ pp_exp f ^ " @@ " ^ pp_exp arg ^ ")"
    | EAbs (p, body) -> "(fun " ^ p ^ " -> " ^ pp_exp body ^ ")"
    | ELet (b, r, l) -> "let " ^ b ^ " = " ^ pp_exp r ^ " in \n\t" ^ pp_exp l
end
(* [generalize env typ] closes [typ] over the type variables that are free
   in [typ] but not free in [env].  The quantified names come back sorted,
   in [SSet.elements] order. *)
let generalize env typ =
  let env_vars = TypeEnv.ftv env in
  let quantified = SSet.diff (Type.ftv typ) env_vars |> SSet.elements in
  (quantified, typ)
(* Global counter that keeps fresh type-variable names distinct. *)
let ctr = ref 0

(* [new_ty_var prefix] increments [ctr] and returns [TVar (prefix ^ n)],
   where [n] is the counter's new value. *)
let new_ty_var prefix =
  ctr := !ctr + 1;
  Type.TVar (Printf.sprintf "%s%d" prefix !ctr)
(* [instantiate (vars, typ)] replaces each quantified variable of the
   scheme by a fresh type variable, whose name is derived from the original
   name via [new_ty_var].  It returns the resulting monotype. *)
let instantiate scheme =
  let vars, typ = scheme in
  let fresh = List.map new_ty_var vars in
  (* fold from the right so that, as with an association list, the first
     occurrence of a name is the binding that wins *)
  let renaming =
    List.fold_right2 (fun name ty acc -> SMap.add name ty acc) vars fresh SMap.empty
  in
  Type.apply renaming typ
(* [var_bind v t] is the substitution binding type variable [v] to [t].
   It is empty when [t] is [v] itself.
   @raise Failure when [v] occurs inside [t] (the occur check: binding it
   would require an infinite type). *)
let var_bind v t =
  match t with
  | Type.TVar u when String.equal u v -> Subst.null
  | _ ->
    if SSet.mem v (Type.ftv t) then
      failwith ("occur check fails: " ^ v ^ " vs. " ^ PP.pp_typ t)
    else SMap.singleton v t
(* [mgu t1 t2] computes a most general unifier of [t1] and [t2]: a
   substitution [s] such that [Type.apply s t1] equals [Type.apply s t2].
   @raise Failure when the types have incompatible shapes or when the occur
   check fails (see [var_bind]). *)
let rec mgu t1 t2 =
  let open Type in
  match t1, t2 with
  | TFun (p1, b1), TFun (p2, b2) ->
    let s1 = mgu p1 p2 in
    let s2 = mgu (Type.apply s1 b1) (Type.apply s1 b2) in
    (* s1 was applied before computing s2, so s2 goes on the outside:
       [compose s2 s1] means "apply s1, then s2".  The reverse order leaves
       s1's images unresolved.  For [a -> a] vs [b -> int] it would yield
       {a: b, b: int}, which does not unify the two types. *)
    Subst.compose s2 s1
  | TInt, TInt -> Subst.null
  | TBool, TBool -> Subst.null
  | TVar v, t -> var_bind v t
  | t, TVar v -> var_bind v t
  | _ -> failwith ("types do not unify: " ^ PP.pp_typ t1 ^ " vs. " ^ PP.pp_typ t2)
(* Literals have a fixed base type and impose no constraints, so the
   returned substitution is always empty. *)
let infer_lit lit =
  let ty =
    match lit with
    | LInt _ -> Type.TInt
    | LBool _ -> Type.TBool
  in
  (Subst.null, ty)
(* [infer env expr] is Algorithm W.  It returns [(s, t)], where [s] is the
   substitution accumulated while typing [expr] and [t] is its inferred
   type.  Substitutions are chained so that a later one is composed on the
   outside ([compose later earlier]).
   @raise Failure on an unbound variable or a unification error. *)
let rec infer env expr =
  let open Type in
  match expr with
  | EVar v ->
    (match SMap.find_opt v env with
     | Some scheme -> (Subst.null, instantiate scheme)
     | None -> failwith ("unbound variable: " ^ v))
  | ELit l -> infer_lit l
  | EApp (f, arg) ->
    let tv = new_ty_var "a" in
    let s1, t1 = infer env f in
    let s2, t2 = infer (TypeEnv.apply s1 env) arg in
    let s3 = mgu (Type.apply s2 t1) (TFun (t2, tv)) in
    (Subst.compose s3 (Subst.compose s2 s1), Type.apply s3 tv)
  | EAbs (p, body) ->
    let tv = new_ty_var "a" in
    (* SMap.add replaces any outer binding of [p]; the parameter gets a
       monomorphic scheme *)
    let env' = SMap.add p ([], tv) env in
    let s1, t1 = infer env' body in
    (s1, TFun (Type.apply s1 tv, t1))
  | ELet (x, e, body) ->
    let s1, t1 = infer env e in
    let scheme = generalize (TypeEnv.apply s1 env) t1 in
    (* SMap.add replaces any outer binding of [x] *)
    let env' = SMap.add x scheme env in
    let s2, t2 = infer (TypeEnv.apply s1 env') body in
    (* s1 was applied to the environment before inferring [body], so s2
       must be composed on the outside.  With the reverse order,
       [fun f -> fun g -> let u = f g in g 1] would leave the parameter
       type of [f] as an unresolved variable instead of [int -> _]. *)
    (Subst.compose s2 s1, t2)
(* [type_inference env expr] runs [infer] and applies the resulting
   substitution to the inferred type. *)
let type_inference env expr =
  match infer env expr with
  | (subst, ty) -> Type.apply subst ty
(* Example programs, plus a helper that prints each program and its
   inferred type.  (The former [open Type] was unused: the examples only
   use the [exp]/[lit] constructors.) *)
module Test = struct
  (* let id = fun x -> x in id *)
  let e0 = ELet ("id",
                 EAbs ("x", EVar "x"),
                 EVar "id")

  (* let-polymorphism: id applied to itself *)
  let e1 = ELet ("id",
                 EAbs ("x", EVar "x"),
                 EApp (EVar "id", EVar "id"))

  (* same as e1, but the identity goes through an inner let *)
  let e2 = ELet ("id",
                 EAbs ("x", ELet ("y",
                                  EVar "x",
                                  EVar "y")),
                 EApp (EVar "id", EVar "id"))

  (* (id id) 2 *)
  let e3 = ELet ("id",
                 EAbs ("x", ELet ("y",
                                  EVar "x",
                                  EVar "y")),
                 EApp (EApp (EVar "id", EVar "id"), ELit (LInt 2)))

  (* x is applied to itself, so inference is expected to fail the occur
     check and raise [Failure] *)
  let e4 = ELet ("id",
                 EAbs ("x", EApp (EVar "x", EVar "x")),
                 EVar "id")

  (* a lambda-bound m, aliased by a let, then applied to a bool *)
  let e5 = EAbs ("m",
                 ELet ("y",
                       EVar "m",
                       ELet ("x",
                             EApp (EVar "y", ELit (LBool true)),
                             EVar "x")))

  (* [run_test exp] prints [exp], then its inferred type in the empty
     environment, then a blank line.  Inference errors propagate as
     [Failure]. *)
  let run_test exp =
    print_endline (PP.pp_exp exp);
    let t = type_inference (TypeEnv.empty ()) exp in
    print_endline (PP.pp_typ t);
    print_newline ()
end;;
(* Run every example in order.  e4 goes last because it is expected to
   raise on the occur check, which ends the run. *)
let () =
  List.iter Test.run_test
    [ Test.e0; Test.e1; Test.e2; Test.e3; Test.e5; Test.e4 ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment