Skip to content

Instantly share code, notes, and snippets.

@jozefg
Last active February 9, 2022 11:59
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jozefg/719c7e12ed71791b4100 to your computer and use it in GitHub Desktop.
Save jozefg/719c7e12ed71791b4100 to your computer and use it in GitHub Desktop.
A demonstration of type inference in SML
signature TYPEINFER =
sig
type tvar = int
datatype monotype = TBool
| TArr of monotype * monotype
| TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype exp = True
| False
| Var of int
| App of exp * exp
| Let of exp * exp
| Fn of exp
| If of exp * exp * exp
val infer : exp -> polytype
end
(* An example of type inference for a tiny ML-like language.
* Our language includes
*
* - Functions
* - Booleans
* - If expressions
* - Let abstraction
*
* This means the types are function types [t -> t] and booleans
* [bool] as well as type variables for polymorphic types
*
* To simplify the process of representing variables, we use something
* called DeBruijn indices (pronounced "DeBrown"). With these a
* variable is a number that counts how many binders it is away from
* where it was bound
*
* So if we had
*
* __________
* | |
* fn => fn => 1 + 0
* |________|
*
* The 0 refers to the inner binder and the 1 to the binder 1 binder
* away. Here's the wikipedia article:
* http://www.wikiwand.com/en/De_Bruijn_index
*
* For types we'll just use numbers to indicate binding. Since we
* never have nested scopes (all the binders for a type are fully
* shifted to the left) we don't need a clever scheme like DeBruijn
* indices. We don't allow the user to annotate types: we infer
* absolutely everything.
*)
structure TypeInfer :> TYPEINFER =
struct
fun dedup [] = []
| dedup (x :: xs) =
if List.exists (fn y => x = y) xs
then dedup xs
else x :: dedup xs
(* As mentioned, type variables are just integers: They're globally unique *)
type tvar = int
(* Cough, cough, ignore this bit. It lets us generate fresh type
* variables but uses some features of SML we haven't talked about yet *)
local val freshSource = ref 0 in
fun fresh () : tvar =
!freshSource before freshSource := !freshSource + 1
end
(* A normal (not polymorphic) type *)
datatype monotype = TBool
| TArr of monotype * monotype
| TVar of tvar
(* A polymorphic type which has binds a list of type variables *)
datatype polytype = PolyType of int list * monotype
(* Our definition of expressions in our language *)
datatype exp = True
| False
| Var of int
| App of exp * exp
| Let of exp * exp
| Fn of exp
| If of exp * exp * exp
(* A piece of information we know about a particular variable. We
either know it has a polytype or a monotype *)
datatype info = PolyTypeVar of polytype
| MonoTypeVar of monotype
(* A list of information. The ith entry is about the DeBruijn variable i *)
type context = info list
(* Substitute some type var for another type *)
fun subst ty' var ty =
case ty of
TVar var' => if var = var' then ty' else TVar var'
| TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
| TBool => TBool
(* Collect a list of all the variables in a type *)
fun freeVars t =
case t of
TVar v => [v]
| TArr (l, r) => freeVars l @ freeVars r
| TBool => []
(* Any polytype can construct a new monotype with the polymorphic
* variables replaced by fresh weakly polymorphic variables *)
fun mintNewMonoType (PolyType (ls, ty)) =
foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls
fun generalizeMonoType ctx ty =
let fun notMem xs x = List.all (fn y => x <> y) xs
fun free (MonoTypeVar m) = freeVars m
| free (PolyTypeVar (PolyType (bs, m))) =
List.filter (notMem bs) (freeVars m)
val ctxVars = List.concat (List.map free ctx)
val polyVars = List.filter (notMem ctxVars) (freeVars ty)
in PolyType (dedup polyVars, ty) end
exception UnboundVar of int
(* Lookup a variable in a list. If we don't find the variable then we throw
* an UnboundVar exception *)
fun lookupVar var ctx =
case List.nth (ctx, var) handle Subscript => raise UnboundVar var of
PolyTypeVar pty => mintNewMonoType pty
| MonoTypeVar mty => mty
(* A big part of the type checker is the generation of
* "constraints". These are assertions that some type is equivalent to
* another. These are represented with just a tuple of types which is to be
* read "The left component unifies with the right component".
*
* Once we solve these constrains we get a [sol] which is a list of
* variables to the monotypes the become.
*)
type constr = (monotype * monotype)
type sol = (tvar * monotype) list
exception UnificationError of constr
(* Substitute a type for a type variable in a list of constraints *)
fun substConstrs ty var (cs : constr list) : constr list =
map (fn (l, r) => (subst ty var l, subst ty var r)) cs
(* Given a solution and a type, apply the solution by rewriting each
* variable with what the solution says it's equivalent to.
*)
fun applySol sol ty =
foldl (fn ((v, ty), ty') => subst ty v ty') ty sol
fun applySolCxt sol cxt =
let fun applyInfo i =
case i of
PolyTypeVar (PolyType (bs, m)) =>
PolyTypeVar (PolyType (bs, (applySol sol m)))
| MonoTypeVar m => MonoTypeVar (applySol sol m)
in map applyInfo cxt end
(* Given a solution, add a new member to the solution ensuring that
* the rhs of the new part of the solution is normal with respect to the
* rest of the solution.
*)
fun addSol v ty sol = (v, applySol sol ty) :: sol
(* Returns true if a variable occurs in the type *)
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
fun unify ([] : constr list) : sol = []
| unify (c :: constrs) =
case c of
(TBool, TBool) => unify constrs
| (TVar i, TVar j) =>
if i = j
then unify constrs
else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
| ((TVar i, ty) | (ty, TVar i)) =>
if occursIn i ty
then raise UnificationError c
else addSol i ty (unify (substConstrs ty i constrs))
| (TArr (l, r), TArr (l', r')) =>
unify ((l, l') :: (r, r') :: constrs)
| _ => raise UnificationError c
fun <+> (sol1, sol2) =
let fun inSol2 v = List.all (fn (v', _) => v <> v') sol2
val sol1' = List.filter (fn (v, _) => inSol2 v) sol1
in
map (fn (v, ty) => (v, applySol sol1 ty)) sol2 @ sol1'
end
infixr 3 <+>
(* Generate all the constraints and solve them for a given expression *)
fun constrain ctx True = (TBool, [])
| constrain ctx False = (TBool, [])
| constrain ctx (Var i) = (lookupVar i ctx, [])
| constrain ctx (Fn body) =
let val argTy = TVar (fresh ())
val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
in (TArr (applySol sol argTy, rTy), sol) end
| constrain ctx (If (i, t, e)) =
let val (iTy, sol1) = constrain ctx i
val (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
val (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
in
(tTy, sol1 <+> sol2 <+> sol3 <+> unify [(iTy, TBool), (tTy, eTy)])
end
| constrain ctx (App (l, r)) =
let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
val (funTy, sol1) = constrain ctx l
val (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
val sol = unify [(funTy, TArr (domTy, ranTy)), (argTy, domTy)]
<+> sol1
<+> sol2
in (ranTy, sol) end
| constrain ctx (Let (e, body)) =
let val (eTy, sol1) = constrain ctx e
val ctx' = applySolCxt sol1 ctx
val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
in (rTy, sol1 <+> sol2) end
(* Finally, infer and solve all the constraints for a type,
* generating a final, inferred, polytype *)
fun infer e =
let val (ty, sol) = constrain [] e
in generalizeMonoType [] (applySol sol ty) end
end
Library
signature TYPEINFER
structure TypeInfer
is
infer.sig
infer.sml
$/basis.cm

An Explanation of This Code

OK, so now that I've wowed you with my awesome code, let's talk about what's going on. Type inference breaks down into essentially 2 components

  1. Constraint Generation
  2. Unification

We inspect the program we're trying to infer a type for and generate a bunch of statements (constraints) which are of the form

This type is equal to this type

These types have "unification variables" in them. These aren't normal ML type variables. They're generated by the compiler, for the compiler, and will eventually be filled in with either

  1. A rigid polymorphic variable
  2. A normal concrete type

For example, if we're looking at the expression

   f a

We first just say that f : 'f where 'f is one of those unification variables I mentioned. Next we say that a : 'a. Since we're apply f to a we can generate the constraints that

'f ~ 'x -> 'y
'a ~ 'x

We then unify these constraints to produce f : 'a -> 'x and a : 'a. We'd then using the surrounding constraints to produce more information about what exactly 'a and 'x might be.

Now onto some specifics

Constraint Generation

Unification

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment