Skip to content

Instantly share code, notes, and snippets.

@aradarbel10
Last active July 3, 2023 15:12
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aradarbel10/837aa65d2f06ac6710c6fbe479909b4c to your computer and use it in GitHub Desktop.
Save aradarbel10/837aa65d2f06ac6710c6fbe479909b4c to your computer and use it in GitHub Desktop.
minimal STLC type inference with mutable metavars
(* language definition *)

(* Names of variables and of metavariables. *)
type nom = string

(* Binary integer operators. *)
type bop = Add | Sub | Mul

(* Types of the simply-typed lambda calculus, extended with metavariables
   (unification variables) used during inference. *)
type typ =
| Int | Arrow of typ * typ | Meta of meta
(* A metavariable is a mutable cell: solving it overwrites the cell in place,
   so every occurrence of the same [Meta] shares the solution. *)
and meta = meta_state ref
and meta_state =
| Solved of typ
| Unsolved of nom (* keep name for pretty printing *)

(* Source expressions. *)
type expr =
| Num of int
| Bop of bop * expr * expr
| Var of nom
| Lam of nom * expr (* unannotated lambda: the parameter type is inferred *)
| App of expr * expr
| Let of nom * expr * expr (* [Let (x, e, body)] stands for: let x = e in body *)
(* pretty printing *)

(** Render a type. Arrows are always wrapped in parentheses; a solved
    metavariable is printed as its solution and an unsolved one as [?name]. *)
let rec string_of_typ (ty : typ) : string =
  match ty with
  | Int -> "int"
  | Arrow (dom, cod) ->
    String.concat "" [ "("; string_of_typ dom; " -> "; string_of_typ cod; ")" ]
  | Meta cell -> string_of_meta !cell

(** Render the current contents of a metavariable cell. *)
and string_of_meta (state : meta_state) : string =
  match state with
  | Unsolved name -> "?" ^ name
  | Solved ty -> string_of_typ ty
(** Render an expression in fully parenthesised concrete syntax.
    [Let (x, e, body)] is printed as [(let x = e in body)]. *)
let rec string_of_expr : expr -> string = function
  | Num n -> string_of_int n
  | Bop (op, e0, e1) ->
    "(" ^ string_of_expr e0 ^ " " ^ string_of_op op ^ " " ^ string_of_expr e1 ^ ")"
  | Var x -> x
  | Lam (x, e) -> "(𝜆" ^ x ^ ". " ^ string_of_expr e ^ ")"
  | App (e0, e1) -> "(" ^ string_of_expr e0 ^ " " ^ string_of_expr e1 ^ ")"
  | Let (x, e, e') ->
    (* name, then its definition, then the body it scopes over *)
    "(let " ^ x ^ " = " ^ string_of_expr e ^ " in " ^ string_of_expr e' ^ ")"

(** Surface syntax of a binary operator. *)
and string_of_op : bop -> string = function
  | Add -> "+"
  | Sub -> "-"
  | Mul -> "*"
(* some exceptions *)

(* Raised by [infer] when a variable is not bound in the context. *)
exception UndefinedVar of nom
(* Raised by [unify] with the pair of (forced) types that cannot be made equal. *)
exception UnUnifiable of typ * typ
(* Raised by [occurs] when a metavariable would be solved with a type that
   contains that same metavariable. *)
exception OccursFailure
(* fresh name supply *)
module Fresh : sig
  val freshi : int ref
  (** Counter holding the next index to hand out (starts at 0). *)

  val nexti : unit -> int
  (** Return the current counter value, then advance the counter by one. *)

  val fresh : unit -> typ
  (** Allocate a new unsolved metavariable named ["x<i>"], where [i] comes
      from {!nexti}. *)
end = struct
  let freshi = ref 0

  let nexti () =
    let issued = !freshi in
    incr freshi;
    issued

  let fresh () =
    let idx = nexti () in
    let name = "x" ^ string_of_int idx in
    Meta (ref (Unsolved name))
end
open Fresh
(* metavar forcing:
   follow chains of solved metavariables so callers can pattern-match on the
   underlying type constructor. The result is never a solved [Meta]; it is
   [Int], an [Arrow], or an unsolved [Meta]. *)
let rec force (ty : typ) : typ =
  match ty with
  | Meta { contents = Solved solution } -> force solution
  | Int | Arrow _ | Meta { contents = Unsolved _ } -> ty
(* unification *)

(** [occurs x t] raises [OccursFailure] if the unsolved metavariable named
    [x] appears anywhere in [t], looking through solved metavariables.
    Otherwise it returns [()]. *)
let rec occurs (x : nom) (t : typ) : unit =
  match force t with
  | Int -> ()
  | Arrow (dom, cod) ->
    occurs x dom;
    occurs x cod
  | Meta cell ->
    (match !cell with
     | Unsolved y when String.equal x y -> raise OccursFailure
     | Unsolved _ | Solved _ -> ())
(** [unify (t0, t1)] makes [t0] and [t1] equal by solving unsolved
    metavariables in place.
    @raise UnUnifiable with the first clashing pair of (forced) sub-types,
      in the same (left, right) orientation as the arguments.
    @raise OccursFailure if solving a metavariable would make a type contain
      itself. *)
let rec unify (t0, t1 : typ * typ) : unit =
  match force t0, force t1 with
  (* The very same metavariable: nothing to solve. This case must come first,
     because solving a meta with itself would trip the occurs check. *)
  | Meta m0, Meta m1 when m0 == m1 -> ()
  | Meta m, t | t, Meta m ->
    begin match !m with
    | Unsolved x -> occurs x t; m := Solved t
    | Solved _ -> failwith "absurd!" (* impossible case, since `force` never returns a solved meta *)
    end
  | Int, Int -> ()
  | Arrow (dom0, cod0), Arrow (dom1, cod1) ->
    (* domain with domain, codomain with codomain *)
    unify (dom0, dom1);
    unify (cod0, cod1)
  | t0, t1 -> raise (UnUnifiable (t0, t1))
(* type inference itself *)

(** Typing context. New bindings are consed on the front, so
    [List.assoc_opt] finds the innermost (shadowing) binding first. *)
type ctx = (nom * typ) list

(** [infer ctx e] synthesises a type for [e], solving metavariables as a
    side effect.
    @raise UndefinedVar if [e] mentions a variable not bound in [ctx].
    @raise Failure if a sub-expression's type clashes with what is expected
      (see [check]).
    @raise OccursFailure if unification would build a cyclic type. *)
let rec infer (ctx : ctx) : expr -> typ = function
  | Num _ -> Int
  | Bop (_, e0, e1) ->
    (* every operator is int -> int -> int *)
    check ctx e0 Int;
    check ctx e1 Int;
    Int
  | Var x ->
    begin match List.assoc_opt x ctx with
    | Some t -> t
    | None -> raise (UndefinedVar x)
    end
  | Lam (x, e) ->
    (* the parameter type is unknown until the body constrains it *)
    let t0 = fresh () in
    let t1 = infer ((x, t0) :: ctx) e in
    Arrow (t0, t1)
  | App (e0, e1) ->
    let arg_typ = infer ctx e1 in
    let ret_typ = fresh () in
    check ctx e0 (Arrow (arg_typ, ret_typ));
    ret_typ
  | Let (x, e, e') ->
    (* no generalisation: [x] gets a single monomorphic type *)
    let t = infer ctx e in
    infer ((x, t) :: ctx) e'

(** [check ctx e t] infers a type for [e] and unifies it with the expected
    type [t].
    @raise Failure with an "expected type ... but received ..." message built
      from the clashing pair reported by [unify]. *)
and check (ctx : ctx) (e : expr) (t : typ) : unit =
  let t' = infer ctx e in
  try unify (t, t') with
  | UnUnifiable (expected, actual) ->
    failwith ("expected type " ^ string_of_typ expected ^ " but received " ^ string_of_typ actual)
(* Program entry point: it only prints a greeting; no expression is
   type-checked here. *)
let () = print_endline "hello STLC!"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment