Skip to content

Instantly share code, notes, and snippets.

@pema99
Created January 30, 2022 19:29
Show Gist options
  • Save pema99/dab60ee4248eef2cff5e74e76d672620 to your computer and use it in GitHub Desktop.
Save pema99/dab60ee4248eef2cff5e74e76d672620 to your computer and use it in GitHub Desktop.
// Hindley-Milner type inference, 2022 Pema Malling
// Resources:
// - https://course.ccs.neu.edu/cs4410sp19/lec_type-inference_notes.html
// - http://dev.stephendiehl.com/fun/006_hindley_milner.html#inference-monad
// AST and types
/// Literal constants that may appear in an expression.
type Lit = Int of int | Bool of bool
/// Built-in binary operators; their type schemes are listed in `ops`.
type BinOp = Add | Sub | Mul | Div | Eq
/// Untyped expression tree consumed by `infer`.
type Expr =
    | Var of string                 // variable reference
    | App of Expr * Expr            // application: f x
    | Lam of string * Expr          // one-parameter lambda: \x -> body
    | Let of string * Expr * Expr   // let x = e1 in e2 (non-recursive; x is generalized)
    | Lit of Lit                    // literal constant
    | If of Expr * Expr * Expr      // if cond then e1 else e2
    | Op of Expr * BinOp * Expr     // binary operator: lhs op rhs
/// Monotypes: type variables, nullary type constructors and function arrows.
type Type =
    | TVar of string
    | TCon of string
    | TArr of Type * Type   // TArr (a, b) is the function type a -> b

/// Primitive types produced by literals and by the equality operator.
let tInt = TCon "int"
let tBool = TCon "bool"
// Schemes and environments
/// A polymorphic type scheme: the quantified variable names plus the body type.
type Scheme = string list * Type

/// Typing context: variable name -> scheme.
type TypeEnv = Map<string, Scheme>

/// Bind `x` to `s` in `env`, shadowing any existing binding for `x`.
let extend env x s = env |> Map.add x s

/// The binding for `x` in `env`, or None when unbound.
let lookup env x = env |> Map.tryFind x

/// `env` with any binding for `x` dropped.
let remove env x = env |> Map.remove x
// Substitution
/// A substitution maps type-variable names to the types that replace them.
type Substitution = Map<string, Type>

/// Free type variables of a monotype (every TVar name it mentions).
let rec ftvType (t: Type) : string Set =
    match t with
    | TCon _ -> Set.empty
    | TVar a -> Set.singleton a
    | TArr (t1, t2) -> Set.union (ftvType t1) (ftvType t2)

/// Apply `s` to `t`, replacing every TVar that `s` maps; unmapped variables
/// are left as they are. This is a single pass: the replacement type is not
/// itself re-substituted.
let rec applyType (s: Substitution) (t: Type) : Type =
    match t with
    | TCon _ -> t
    | TVar a -> Map.tryFind a s |> Option.defaultValue t
    | TArr (t1, t2) -> TArr (applyType s t1, applyType s t2)

/// `compose s1 s2` is the substitution that means "apply s2, then s1":
///   applyType (compose s1 s2) t = applyType s1 (applyType s2 t)
/// On key collisions the (s1-applied) binding from `s2` wins.
/// Defined after `applyType` because it needs it.
let compose (s1: Substitution) (s2: Substitution) : Substitution =
    // Push s1 through the range of s2 first; a bare map merge would leave e.g.
    // {a -> b} composed with {b -> int} resolving `a` only to `b`.
    let s2Resolved = Map.map (fun _ ty -> applyType s1 ty) s2
    Map.fold (fun acc k v -> Map.add k v acc) s1 s2Resolved
/// Free type variables of a scheme: those in the body that are not quantified.
let ftvScheme (sc: Scheme) : string Set =
    let (quantified, body) = sc
    Set.ofList quantified |> Set.difference (ftvType body)
/// Apply `s` to a scheme's body without touching its quantified variables.
/// Bound variables are local to the scheme, so they are removed from the
/// substitution first; only the scheme's free variables get rewritten.
let applyScheme (s: Substitution) (sc: Scheme) : Scheme =
    let (bound, body) = sc
    let restricted = bound |> List.fold (fun acc v -> Map.remove v acc) s
    (bound, applyType restricted body)
/// Union of the free type variables of every scheme in the environment.
let ftvEnv (t: TypeEnv) : Set<string> =
    t
    |> Map.toSeq
    |> Seq.map (snd >> ftvScheme)
    |> Set.unionMany
/// Apply `s` to every scheme bound in the environment.
let applyEnv (s: Substitution) (t: TypeEnv) : TypeEnv =
    t |> Map.map (fun _ sc -> applyScheme s sc)
// Unification
/// Occurs check: does variable `s` appear in `t`? Binding `s` to such a `t`
/// would describe an infinite type.
let occurs (s: string) (t: Type) : bool =
    ftvType t |> Set.contains s
/// Most general unifier of two monotypes: a substitution that makes them equal.
/// Fails (via failwith) on a constructor clash or when binding a variable
/// would violate the occurs check.
let rec unify (t1: Type) (t2: Type) : Substitution =
    let bind v ty = Map.ofList [ (v, ty) ]
    match t1, t2 with
    | TVar a, other when not (occurs a other) -> bind a other
    | other, TVar b when not (occurs b other) -> bind b other
    // Reached only for TVar a vs TVar a (both guards above fail then).
    | TVar a, TVar b when a = b -> Map.empty
    | TCon a, TCon b when a = b -> Map.empty
    | TArr (dom1, cod1), TArr (dom2, cod2) ->
        let sDom = unify dom1 dom2
        // The codomains must be unified under what the domains already forced.
        let sCod = unify (applyType sDom cod1) (applyType sDom cod2)
        compose sCod sDom
    | _ -> failwith (sprintf "Failed to unify types %A and %A" t1 t2)
// Instantiation and generalization
/// Counter behind `fresh`; module-level mutable state that this file never resets.
let mutable freshCount = 0

/// A brand-new type variable: `_t1`, `_t2`, ... (the counter is bumped first).
let fresh () =
    freshCount <- freshCount + 1
    sprintf "_t%A" freshCount |> TVar
/// Replace each quantified variable of a scheme with a fresh type variable,
/// giving a monotype for one particular use site.
let instantiate (sc: Scheme) : Type =
    let (bound, body) = sc
    let renaming =
        bound
        |> List.map (fun v -> v, fresh ())
        |> Map.ofList
    applyType renaming body
/// Quantify `t` over every variable free in `t` but not free in `env`.
let generalize (env: TypeEnv) (t: Type) : Scheme =
    let envVars = ftvEnv env
    let quantified = Set.difference (ftvType t) envVars |> Set.toList
    (quantified, t)
// Type schemes for built in operators
/// Type schemes for the built-in operators.
/// Arithmetic operators get 'a -> 'a -> 'a and equality gets 'a -> 'a -> bool.
/// NOTE(review): because arithmetic is polymorphic, `true + false` type-checks
/// as bool; the test suite's `add` case expects this polymorphism, so it looks intentional.
let ops =
    let sameTypeOp = ([ "a" ], TArr (TVar "a", TArr (TVar "a", TVar "a")))
    let eqOp = ([ "a" ], TArr (TVar "a", TArr (TVar "a", tBool)))
    Map.ofList
        [ (Add, sameTypeOp)
          (Sub, sameTypeOp)
          (Mul, sameTypeOp)
          (Div, sameTypeOp)
          (Eq, eqOp) ]
// Inference
/// Algorithm W: infer the type of `e` in environment `env`.
/// Returns the substitution discovered during inference together with the
/// inferred type. Fails (via failwith) on an unbound variable or when
/// unification fails.
///
/// Discipline followed by every case: each sub-expression is inferred in an
/// environment that already has all substitutions found so far applied, and
/// substitutions are chained as `compose later earlier`.
let rec infer (env: TypeEnv) (e: Expr) : Substitution * Type =
    match e with
    | Lit (Int _) -> (Map.empty, tInt)
    | Lit (Bool _) -> (Map.empty, tBool)
    | Var a ->
        // Parenthesised so the inner arms cannot swallow the outer cases.
        (match lookup env a with
         | Some s -> (Map.empty, instantiate s)
         | None -> failwith <| sprintf "Inference failure, use of unbound variable %A" a)
    | App (f, x) ->
        let tv = fresh ()
        let s1, t1 = infer env f
        let s2, t2 = infer (applyEnv s1 env) x
        let s3 = unify (applyType s2 t1) (TArr (t2, tv))
        (compose s3 (compose s2 s1), applyType s3 tv)
    | Lam (x, body) ->
        let tv = fresh ()
        // Lambda-bound variables are monomorphic: no quantified variables.
        let nenv = extend env x ([], tv)
        let s1, t1 = infer nenv body
        (s1, TArr (applyType s1 tv, t1))
    | Let (x, e1, e2) ->
        let s1, t1 = infer env e1
        let nenv = applyEnv s1 env
        let nt = generalize nenv t1
        // The body must see what was learned from e1, so extend the
        // substituted environment rather than the original one.
        let s2, t2 = infer (extend nenv x nt) e2
        (compose s2 s1, t2)
    | If (cond, tr, fl) ->
        // Thread substitutions cond -> then -> else so that facts learned in
        // one part (e.g. a variable is bool) constrain the following parts.
        let s1, t1 = infer env cond
        let s2 = unify t1 tBool
        let sCond = compose s2 s1
        let s3, t2 = infer (applyEnv sCond env) tr
        let sThen = compose s3 sCond
        let s4, t3 = infer (applyEnv sThen env) fl
        let sElse = compose s4 sThen
        let s5 = unify (applyType s4 t2) t3
        (compose s5 sElse, applyType s5 t3)
    | Op (l, op, r) ->
        let s1, t1 = infer env l
        // The right operand is inferred under what the left one established.
        let s2, t2 = infer (applyEnv s1 env) r
        let tv = fresh ()
        let scheme = Map.find op ops
        let s3 = unify (TArr (applyType s2 t1, TArr (t2, tv))) (instantiate scheme)
        (compose s3 (compose s2 s1), applyType s3 tv)
// Pretty printing
/// Readable name for the i-th renamed type variable: "a".."z" for 0..25,
/// then "t26", "t27", ...
let prettyTypeName (i: int) : string =
    match i with
    | n when n < 26 -> string (char (int 'a' + n))
    | n -> sprintf "t%A" n
/// Rename the type variables of `t` to "a", "b", "c", ... in left-to-right
/// order of first occurrence, so inferred types (which use fresh `_tN` names)
/// can be compared with readable expectations. Every occurrence of the same
/// variable receives the same new name.
let renameFresh (t: Type) : Type =
    // Threads the old-name -> new-name map and the next free index through the traversal.
    let rec cont t subst count =
        match t with
        | TCon _ -> t, subst, count
        | TArr (l, r) ->
            let (r1, subst1, count1) = cont l subst count
            let (r2, subst2, count2) = cont r subst1 count1
            TArr (r1, r2), subst2, count2
        | TVar a ->
            match lookup subst a with
            | Some v -> TVar v, subst, count
            | None ->
                let name = prettyTypeName count
                // Extend the mapping (not replace it) so variables renamed
                // earlier keep their names when they occur again.
                TVar name, Map.add a name subst, count + 1
    let (res, _, _) = cont t Map.empty 0
    res
// Tests
/// Report test `i`: compare expected `e` with actual `a` by structural
/// equality and print both values when they differ.
let checkTest i e a =
    if e <> a then
        printfn "[%A] Fail:" i
        printfn "\tExpected: %A" e
        printfn "\tActual: %A" a
    else
        printfn "[%A] Pass." i
/// Test table: expected type paired with an expression. The inferred type is
/// passed through `renameFresh` before comparison, so expectations use a, b, ...
let cases =
    let tA = TVar "a"
    [ (tInt, Lit (Int 5))
      (tBool, Lit (Bool false))
      (tInt, Op (Lit (Int 5), Add, Lit (Int 6)))
      (tInt, Let ("c", Lit (Int 5), Op (Var "c", Mul, Var "c")))
      (TArr (tA, TArr (tA, tA)),
       Let ("add", Lam ("a", Lam ("b", Op (Var "a", Add, Var "b"))), Var "add"))
      (TArr (tBool, TArr (tInt, tInt)),
       Lam ("a", Lam ("b", If (Var "a", Lit (Int 5), Var "b")))) ]

printfn "Running tests..."
cases
|> List.iteri (fun i (expected, expr) ->
    let _, inferred = infer Map.empty expr
    checkTest i expected (renameFresh inferred))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment