Created
January 30, 2022 19:29
-
-
Save pema99/dab60ee4248eef2cff5e74e76d672620 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Hindley-Milner type inference, 2022 Pema Malling
// Resources:
// - https://course.ccs.neu.edu/cs4410sp19/lec_type-inference_notes.html
// - http://dev.stephendiehl.com/fun/006_hindley_milner.html#inference-monad
// AST and types
/// Literal constants that may appear in an expression.
type Lit =
    /// An integer literal.
    | Int of int
    /// A boolean literal.
    | Bool of bool
/// Built-in binary operators.
type BinOp =
    /// Addition.
    | Add
    /// Subtraction.
    | Sub
    /// Multiplication.
    | Mul
    /// Division.
    | Div
    /// Equality comparison.
    | Eq
/// Expressions of the source language.
type Expr =
    /// Reference to a variable by name.
    | Var of string
    /// Function application: function, then argument.
    | App of Expr * Expr
    /// Lambda abstraction: parameter name, then body.
    | Lam of string * Expr
    /// Let binding: bound name, bound expression, then body.
    | Let of string * Expr * Expr
    /// Literal constant.
    | Lit of Lit
    /// Conditional: condition, then-branch, else-branch.
    | If of Expr * Expr * Expr
    /// Binary operation: left operand, operator, right operand.
    | Op of Expr * BinOp * Expr
/// Monomorphic types.
type Type =
    /// A type variable, identified by name.
    | TVar of string
    /// A named type constant (e.g. "int", "bool").
    | TCon of string
    /// A function type: argument type, then result type.
    | TArr of Type * Type
/// The built-in integer type.
let tInt : Type = TCon "int"

/// The built-in boolean type.
let tBool : Type = TCon "bool"
// Schemes and environments
/// A type scheme: the names of the quantified type variables paired with the type body.
type Scheme = list<string> * Type

/// A typing environment mapping variable names to their type schemes.
type TypeEnv = Map<string, Scheme>
/// Returns `env` with `x` bound to `s` (replacing any existing binding for `x`).
let extend env x s = env |> Map.add x s

/// Looks up `x` in `env`, yielding `None` when it is not bound.
let lookup env x = env |> Map.tryFind x

/// Returns `env` without any binding for `x`.
let remove env x = env |> Map.remove x
// Substitution
/// A substitution maps type-variable names to the types that replace them.
type Substitution = Map<string, Type>

/// Collects the names of all type variables occurring in `t`.
let rec ftvType (t: Type) : string Set =
    match t with
    | TCon _ -> Set.empty
    | TVar a -> Set.singleton a
    | TArr (t1, t2) -> Set.union (ftvType t1) (ftvType t2)

/// Applies substitution `s` to `t`: every type variable bound in `s` is
/// replaced by its image; all other parts of the type are left unchanged.
let rec applyType (s: Substitution) (t: Type) : Type =
    match t with
    | TCon _ -> t
    | TVar a -> Map.tryFind a s |> Option.defaultValue t
    | TArr (t1, t2) -> TArr (applyType s t1, applyType s t2)

/// Composes two substitutions so that applying the result is equivalent to
/// applying `s2` first and then `s1`.
///
/// Defined after `applyType` because it needs it: the images of `s2` must have
/// `s1` applied to them. A plain union of the maps would leave chains such as
/// `{a -> b}` followed by `{b -> int}` unresolved, mapping `a` to `b`.
/// Where both substitutions bind the same variable, the (rewritten) binding
/// from `s2` wins.
let compose (s1: Substitution) (s2: Substitution) : Substitution =
    Map.fold (fun acc k v -> Map.add k (applyType s1 v) acc) s1 s2
/// Free type variables of a scheme: those occurring in the body that are not
/// bound by the scheme's quantifier list.
let ftvScheme (sc: Scheme) : string Set =
    let quantified, body = sc
    Set.ofList quantified |> Set.difference (ftvType body)
/// Applies a substitution to the body of a scheme, leaving the quantified
/// variables alone: bindings in `s` for names in the quantifier list are
/// dropped before the substitution is applied to the body.
let applyScheme (s: Substitution) (sc: Scheme) : Scheme =
    let quantified, body = sc
    let restricted = s |> Map.filter (fun name _ -> not (List.contains name quantified))
    (quantified, applyType restricted body)
/// Free type variables of an environment: the union of the free type
/// variables of every scheme bound in it.
let ftvEnv (t: TypeEnv) : Set<string> =
    Map.fold (fun acc _ scheme -> Set.union acc (ftvScheme scheme)) Set.empty t
/// Applies a substitution to every scheme bound in the environment.
let applyEnv (s: Substitution) (t: TypeEnv) : TypeEnv =
    t |> Map.map (fun _ scheme -> applyScheme s scheme)
// Unification
/// Occurs check: true when the type variable named `s` appears anywhere in `t`.
let occurs (s: string) (t: Type) : bool =
    ftvType t |> Set.contains s
/// Computes a substitution that makes `t1` and `t2` equal.
/// Raises (via `failwith`) when the types cannot be unified, which includes
/// the case where a variable would have to be bound to a type containing itself.
let rec unify (t1: Type) (t2: Type) : Substitution =
    // Substitution consisting of the single binding name -> ty.
    let bind name ty : Substitution = Map.add name ty Map.empty
    match t1, t2 with
    // A variable unified with itself needs no binding.
    | TVar a, TVar b when a = b -> Map.empty
    // Bind a variable to the other side, provided it does not occur in it.
    | TVar a, other when not (occurs a other) -> bind a other
    | other, TVar b when not (occurs b other) -> bind b other
    | TCon a, TCon b when a = b -> Map.empty
    | TArr (l1, r1), TArr (l2, r2) ->
        // Unify the argument types first, then the result types under that substitution.
        let sArg = unify l1 l2
        let sRes = unify (applyType sArg r1) (applyType sArg r2)
        compose sRes sArg
    | _ ->
        failwithf "Failed to unify types %A and %A" t1 t2
// Instantiation and generalization
/// Counter backing `fresh`; incremented on every call.
let mutable freshCount = 0

/// Returns a new type variable named `_t<N>`, where N is the incremented counter.
let fresh () =
    freshCount <- freshCount + 1
    freshCount |> sprintf "_t%A" |> TVar
/// Instantiates a scheme: each quantified variable is replaced by a fresh
/// type variable throughout the body.
let instantiate (sc: Scheme) : Type =
    let quantified, body = sc
    let renaming =
        quantified
        |> List.map (fun name -> name, fresh ())
        |> Map.ofList
    applyType renaming body
/// Generalizes `t` into a scheme by quantifying over the type variables that
/// are free in `t` but not free in `env`.
let generalize (env: TypeEnv) (t: Type) : Scheme =
    let quantified =
        Set.difference (ftvType t) (ftvEnv env)
        |> Set.toList
    (quantified, t)
// Type schemes for built in operators
/// Type schemes of the built-in binary operators, keyed by operator.
let ops =
    // forall a. a -> a -> a  (shared by the arithmetic operators)
    let arithmetic = (["a"], TArr (TVar "a", TArr (TVar "a", TVar "a")))
    // forall a. a -> a -> bool
    let comparison = (["a"], TArr (TVar "a", TArr (TVar "a", tBool)))
    Map.ofList [
        Add, arithmetic
        Sub, arithmetic
        Mul, arithmetic
        Div, arithmetic
        Eq, comparison
    ]
// Inference
/// Infers the type of expression `e` under typing environment `env`.
///
/// Returns the substitution accumulated while inferring together with the
/// inferred type. Raises (via `failwith`) on use of an unbound variable, and
/// propagates the failure raised by `unify` when types do not match.
///
/// Substitutions are threaded left to right: each sub-expression is inferred
/// in the environment updated with everything learned so far, and results are
/// composed with the most recent substitution on the left.
let rec infer (env: TypeEnv) (e: Expr) : Substitution * Type =
    match e with
    | Lit (Int _) -> (Map.empty, tInt)
    | Lit (Bool _) -> (Map.empty, tBool)
    | Var a ->
        match lookup env a with
        | Some s -> (Map.empty, instantiate s)
        | None -> failwith <| sprintf "Inference failure, use of unbound variable %A" a
    | App (f, x) ->
        let tv = fresh()
        let s1, t1 = infer env f
        let s2, t2 = infer (applyEnv s1 env) x
        // The function type must be "argument type -> fresh result type".
        let s3 = unify (applyType s2 t1) (TArr (t2, tv))
        (compose s3 (compose s2 s1), applyType s3 tv)
    | Lam (x, body) ->
        // The parameter gets a fresh monomorphic type (empty quantifier list).
        let tv = fresh()
        let s1, t1 = infer (extend env x ([], tv)) body
        (s1, TArr (applyType s1 tv, t1))
    | Let (x, e1, e2) ->
        let s1, t1 = infer env e1
        let nenv = applyEnv s1 env
        let nt = generalize nenv t1
        // The body is inferred in the substituted environment, so constraints
        // discovered while typing e1 are visible to e2.
        let s2, t2 = infer (extend nenv x nt) e2
        (compose s2 s1, t2)
    | If (cond, tr, fl) ->
        let s1, t1 = infer env cond
        // The condition must be a bool; fold that into what is known so far.
        let s2 = compose (unify t1 tBool) s1
        let s3, t2 = infer (applyEnv s2 env) tr
        let s4 = compose s3 s2
        let s5, t3 = infer (applyEnv s4 env) fl
        // Both branches must have the same type.
        let s6 = unify (applyType s5 t2) t3
        (compose s6 (compose s5 s4), applyType s6 t3)
    | Op (l, op, r) ->
        let s1, t1 = infer env l
        let s2, t2 = infer (applyEnv s1 env) r
        let tv = fresh()
        let scheme = Map.find op ops
        // Treat the operator as a curried function applied to both operands.
        let s3 = unify (TArr (applyType s2 t1, TArr (t2, tv))) (instantiate scheme)
        (compose s3 (compose s2 s1), applyType s3 tv)
// Pretty printing
/// Produces a display name for the i-th type variable: a single letter
/// starting at "a" for indices below 26, and "t<i>" from 26 onwards.
let prettyTypeName (i: int) : string =
    if i >= 26 then
        sprintf "t%A" i
    else
        int 'a' + i |> char |> string
/// Renames the type variables of `t` to pretty names (see `prettyTypeName`),
/// numbered in order of first occurrence from left to right. Every occurrence
/// of the same variable receives the same new name.
let renameFresh (t: Type) : Type =
    // Walks the type carrying the renaming built so far (old name -> new name)
    // and the number of distinct variables already named.
    let rec cont t subst count =
        match t with
        | TCon _ -> t, subst, count
        | TArr (l, r) ->
            let (r1, subst1, count1) = cont l subst count
            let (r2, subst2, count2) = cont r subst1 count1
            TArr (r1, r2), subst2, count2
        | TVar a ->
            match lookup subst a with
            | Some v -> TVar v, subst, count
            | None ->
                let name = prettyTypeName count
                // Extend the renaming rather than replacing it, so variables
                // named earlier keep their names when they occur again.
                TVar name, Map.add a name subst, count + 1
    let (res, _, _) = cont t Map.empty 0
    res
// Tests
/// Prints the outcome of test `i`: a pass line when the expected value `e`
/// equals the actual value `a`, otherwise a failure report showing both.
let checkTest i e a =
    if e <> a then
        printfn "[%A] Fail:" i
        printfn "\tExpected: %A" e
        printfn "\tActual: %A" a
    else
        printfn "[%A] Pass." i
/// Test cases: the expected (pretty-renamed) type paired with the expression to infer.
let cases = [
    tInt, Lit (Int 5)
    tBool, Lit (Bool false)
    tInt, Op (Lit (Int 5), Add, Lit (Int 6))
    tInt, Let ("c", Lit (Int 5), Op (Var "c", Mul, Var "c"))
    TArr (TVar "a", TArr (TVar "a", TVar "a")),
        Let ("add", Lam ("a", Lam ("b", Op (Var "a", Add, Var "b"))), Var ("add"))
    TArr (TCon "bool", TArr (TCon "int", TCon "int")),
        Lam ("a", Lam("b", If (Var "a", Lit (Int 5), Var "b")))
]

printfn "Running tests..."

// Infer each expression in the empty environment, normalize the variable
// names, and compare against the expected type.
for i, (expected, expr) in List.indexed cases do
    let actual = infer Map.empty expr |> snd |> renameFresh
    checkTest i expected actual
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment