@aradarbel10
Created December 6, 2022 15:40
A minimalistic example of bidirectional type checking for system F
{-# LANGUAGE StrictData, DerivingVia, OverloadedRecordDot #-}
{-
(compiled with GHC 9.4.2)
-}
{-
HEADS UP
this is an example implementation of a non-trivial type system using bidirectional type checking.
it is...
- naive: a real implementation would use more advanced techniques
  such as NbE, elaboration to core, de Bruijn representation, and various optimizations
- minimal: there is no parser or user interface, because I wanted to focus on the algorithm.
  and I couldn't be bothered.
- beginner friendly... sort of: you should have some familiarity with polymorphic lambda calculus,
  and I assume you know roughly what the goal of type checking in a compiler is.
  of course, basic Haskell understanding is required to read this, though I didn't use any super advanced features.
- impractical: the lack of features such as metavariables, toplevel scope, recursive definitions, ADTs, etc
  makes this unusable for real applications, but it is a good starting point for your own type checkers in actual compilers.
-}
--- The language we want to typecheck: system F extended with lists, booleans, and products
type Name = String
data Tm -- terms
  = Var Name                  -- x
  | Ann Tm Tp                 -- e : t
  | Lam Name (Maybe Tp) Tm    -- λx:t. e          possibly unannotated
  | App Tm Tm                 -- e1 e2
  | Tlam Name Tm              -- ΛT. e
  | Inst Tm Tp                -- e [t]
  | Let Name (Maybe Tp) Tm Tm -- let x : t = e in e'   possibly unannotated
  | Pair Tm Tm                -- (e1, e2)
  | Fst Tm | Snd Tm           -- e.1 | e.2
  | TrueE | FalseE            -- true | false
  | ITE Tm Tm Tm              -- if e then e1 else e2
  | List [Tm]                 -- {e1, e2, ...}
  deriving Show
data Tp -- types
  = Tvar Name      -- a
  | Arrow Tp Tp    -- t1 → t2
  | Prod Tp Tp     -- t1 * t2
  | Forall Name Tp -- ∀a. t
  | BoolT          -- bool
  | ListT Tp       -- list t
  deriving Show
{-
our goal is to take a term (user program) and infer a valid type for it, or reject it with a type error.
however in system F inferring a type for a term is generally undecidable, so the user sometimes needs to provide types.
in this case, we must check the user-provided type is a correct type for the term.
along the way, we also ensure all user-provided types are well-formed.
the algorithm is then built of two mutually recursive functions
  infer :: Ctx -> Tm -> Tp         -- either returns the type of the term or fails
  check :: Ctx -> Tm -> Tp -> ()   -- either returns unit ("success!") or fails
hence 'bi' in bidirectional
`Ctx` is a list of assumptions we know in the current context.
this includes which variables have which types (x : t),
and which type variables are actually defined in scope.
-}
-- context and operations on it
data Ctx = Ctx { vars :: [(Name, Tp)], typs :: [Name] }
emptyctx :: Ctx
emptyctx = Ctx [] []
-- extend with a var assumption (x : t)
assume :: Ctx -> Name -> Tp -> Ctx
assume ctx nm tp = ctx { vars = (nm, tp) : ctx.vars }
-- extend with a tvar assumption (a)
assumeT :: Ctx -> Name -> Ctx
assumeT ctx nm = ctx { typs = nm : ctx.typs }
-- find the type of a bound variable
lookupT :: Ctx -> Name -> Maybe Tp
lookupT ctx nm = Prelude.lookup nm ctx.vars
-- ensure a variable is indeed a tvar
isTvar :: Ctx -> Name -> Bool
isTvar ctx nm = elem nm ctx.typs
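{-
a small sketch of how the context operations above compose.
the name `exampleCtx` is made up for illustration and is not part of the original code.
it models the context we would have inside the body of `ΛA. λx:A. ...`:
one type variable A in scope, and one term variable x of type A.
-}
exampleCtx :: Ctx
exampleCtx = assume (assumeT emptyctx "A") "x" (Tvar "A")
-- lookupT exampleCtx "x"   gives Just (Tvar "A")
-- lookupT exampleCtx "y"   gives Nothing
-- isTvar  exampleCtx "A"   gives True
-- isTvar  exampleCtx "x"   gives False (term variables and type variables live in separate lists)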
{-
there are two more major operations we need to define before actually type checking:
Substitution
============
in certain cases we want to take a type T which contains some "free variable" a (variable not bound by a ∀),
and substitute it with some new value S. this is written mathematically as: `T[S/a]` or `T[a ↦ S]`.
for example,
(a → int)[bool/a] == bool → int
we need to keep in mind variable shadowing, which is when two *different* variables share a name. for example
(a → (∀a. a → int))[bool/a] == bool → (∀a. a → int)
the substitution does not touch the inner (bound) `a` because it is not the same `a` as the outer (free) one.
this becomes clear if we apply alpha-renaming to the inner ∀:
a → (∀a. a → int) == a → (∀b. b → int)
Type Equality
=============
type equality can mean different things, depending on our "equational theory".
in more advanced type theories equality of types becomes very complicated to typecheck, but our theory has one simple rule:
alpha equivalence, which means we allow (consistent) renaming of bound variables. eg
∀a. a → a == ∀b. b → b
the process of deciding type equality is often called "conversion checking",
and we say two types are "convertible" to each other.
-}
subs :: Tp -> Name -> Tp -> Tp
-- substitute all free occurrences of `nm` in `typ` with `new`. we assume `typ` and `new` are well-formed types.
-- note: this naive version is not capture-avoiding. if `new` mentions a free variable
-- that has the same name as a Forall binder inside `typ`, that variable gets captured.
subs typ nm new = case typ of
  Tvar nm' -> if nm == nm' -- only substitute if it's the same variable
    then new
    else typ
  Arrow ltyp rtyp -> Arrow (subs ltyp nm new) (subs rtyp nm new)
  Prod ltyp rtyp -> Prod (subs ltyp nm new) (subs rtyp nm new)
  Forall nm' body -> if nm == nm' -- only subs inside Forall if it doesn't shadow
    then typ
    else Forall nm' (subs body nm new)
  BoolT -> BoolT
  ListT t -> ListT (subs t nm new)
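{-
the two substitution examples from the explanation above, written against `subs`.
this is only an illustration: the names `subsFree` and `subsShadowed` are made up,
and since the language has no `int`, `bool` plays its role and `list bool` is the type we substitute in.
-}
-- (a → bool)[list bool / a]
subsFree :: Tp
subsFree = subs (Arrow (Tvar "a") BoolT) "a" (ListT BoolT)
-- gives Arrow (ListT BoolT) BoolT

-- (a → (∀a. a → bool))[list bool / a], the inner (bound) `a` must stay untouched
subsShadowed :: Tp
subsShadowed = subs (Arrow (Tvar "a") (Forall "a" (Arrow (Tvar "a") BoolT))) "a" (ListT BoolT)
-- gives Arrow (ListT BoolT) (Forall "a" (Arrow (Tvar "a") BoolT))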
-- we define a mostly-structural equality on types
-- caveat: the Forall case is naive. it renames nm2 to nm1 inside body2 without checking whether
-- nm1 already occurs free in body2, so it can wrongly equate e.g. `∀a. a → a` and `∀b. a → b`.
instance Eq Tp where
  typ1 == typ2 = case (typ1, typ2) of
    (Tvar nm1, Tvar nm2) -> nm1 == nm2
    (Arrow ltyp1 rtyp1, Arrow ltyp2 rtyp2) -> ltyp1 == ltyp2 && rtyp1 == rtyp2
    (Prod ltyp1 rtyp1, Prod ltyp2 rtyp2) -> ltyp1 == ltyp2 && rtyp1 == rtyp2
    (BoolT, BoolT) -> True
    (ListT elem1, ListT elem2) -> elem1 == elem2
    (Forall nm1 body1, Forall nm2 body2) -> subs body2 nm2 (Tvar nm1) == body1 -- remember alpha equivalence!
    (_, _) -> False
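{-
a hedged sketch of an alternative way to decide alpha equivalence, not the author's method.
the `Eq` instance above renames one binder into the other with `subs`, which gives a false positive
when the first binder's name already occurs free in the second body:
  Forall "a" (Arrow (Tvar "a") (Tvar "a")) == Forall "b" (Arrow (Tvar "a") (Tvar "b"))   evaluates to True
instead of renaming, we can carry a list of paired binders and compare variables through it.
`alphaEq` and `sameVar` are hypothetical names introduced here for illustration.
-}
alphaEq :: [(Name, Name)] -> Tp -> Tp -> Bool
alphaEq env typ1 typ2 = case (typ1, typ2) of
  (Tvar nm1, Tvar nm2) -> sameVar env nm1 nm2
  (Arrow l1 r1, Arrow l2 r2) -> alphaEq env l1 l2 && alphaEq env r1 r2
  (Prod l1 r1, Prod l2 r2) -> alphaEq env l1 l2 && alphaEq env r1 r2
  (BoolT, BoolT) -> True
  (ListT e1, ListT e2) -> alphaEq env e1 e2
  -- going under two binders: remember that they correspond to each other
  (Forall nm1 body1, Forall nm2 body2) -> alphaEq ((nm1, nm2) : env) body1 body2
  (_, _) -> False

-- two variables match if the innermost binder pair mentioning either of them pairs exactly these two,
-- or, when neither is bound, if they are the same free variable
sameVar :: [(Name, Name)] -> Name -> Name -> Bool
sameVar [] nm1 nm2 = nm1 == nm2
sameVar ((b1, b2) : rest) nm1 nm2
  | nm1 == b1 || nm2 == b2 = nm1 == b1 && nm2 == b2
  | otherwise = sameVar rest nm1 nm2
-- alphaEq [] (Forall "a" (Arrow (Tvar "a") (Tvar "a"))) (Forall "b" (Arrow (Tvar "b") (Tvar "b")))   gives True
-- alphaEq [] (Forall "a" (Arrow (Tvar "a") (Tvar "a"))) (Forall "b" (Arrow (Tvar "a") (Tvar "b")))   gives False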
-- same as equality but throws an error if unequal
conv :: Tp -> Tp -> ()
conv typ1 typ2 = if typ1 == typ2
  then ()
  else error $ "unequal types: " ++ show typ1 ++ " =/= " ++ show typ2
{-
at last, type checking. I did my best to explain for each case why I chose it to be in infer or in check.
remember we prefer to infer whatever's possible. the following defines three functions: infer, check, wellFormed.
"the bidi recipe":
- infer eliminations (application, tuple projection, if/then/else)
- check introductions (lambdas, pairs)
the bidi recipe is an informal guide to help us choose what to infer and what to check,
although in general we just try to infer whatever we can and check the rest.
in practice bidi allows type info to "flow through the expression", so most required annotations are in the "toplevel".
the recipe is useful because that way we only need to annotate "redexes", which are expressions that are "immediately reducible" eg
- `(λx.e1) e2` application directly on a lambda
- `(e1, e2).1` projection directly from a pair
redexes rarely appear in real code, hence we rarely need annotations besides the toplevel ones.
basic assumptions/invariants for type checking:
all types in the context are well formed.
all types returned from `infer` are well formed.
all types passed to `check` are well formed.
never blindly trust a user-supplied type.
-}
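{-
to make the remark about redexes concrete, here is a small sketch using the term `(λx. x) true`.
the names `redexUnannotated` and `redexAnnotated` are made up for this illustration.
application infers its left side, and an unannotated lambda is not inferrable,
so the bare redex is rejected until we annotate the lambda.
-}
-- (λx. x) true
redexUnannotated :: Tm
redexUnannotated = App (Lam "x" Nothing (Var "x")) TrueE
-- infer emptyctx redexUnannotated   fails with the "non-inferrable case" error

-- ((λx. x) : bool → bool) true
redexAnnotated :: Tm
redexAnnotated = App (Ann (Lam "x" Nothing (Var "x")) (Arrow BoolT BoolT)) TrueE
-- infer emptyctx redexAnnotated     gives BoolT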
infer :: Ctx -> Tm -> Tp
-- variables are easy to infer, just look up in the context
infer ctx (Var nm) = case lookupT ctx nm of
  Just typ -> typ
  Nothing -> error $ "type error! undefined variable " ++ nm
-- when a lambda is annotated we can try to infer its body
infer ctx (Lam nm (Just intyp) body) =
  wellFormed ctx intyp `seq`
  let outtyp = infer (assume ctx nm intyp) body in
  Arrow intyp outtyp
-- application is inferrable as long as we can infer the left term and make sure it's a function type
infer ctx (App tm1 tm2) =
  let tp1 = infer ctx tm1 in
  case tp1 of
    Arrow ltyp1 rtyp1 -> check ctx tm2 ltyp1 `seq` rtyp1
    _ -> error $ "type error! can't apply on a term not of function type: " ++ show tm1
-- type lambdas and type application (instantiation) are somewhat parallel to their term-level counterparts
infer ctx (Tlam nm body) =
  let bodytyp = infer (assumeT ctx nm) body in
  Forall nm bodytyp
infer ctx (Inst tm typ) =
  wellFormed ctx typ `seq`
  let tmtyp = infer ctx tm in
  case tmtyp of
    Forall nm body -> subs body nm typ -- instantiate a type scheme (forall) with the given type parameter
    _ -> error $ "type error! can't instantiate a term not of forall type: " ++ show tm
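-- see `checkLet` below for an explanation
infer ctx (Let nm typ body rest) =
  let ctx' = checkLet ctx nm typ body in
  -- force the extended context first, otherwise laziness lets an ill-typed definition
  -- slip through whenever `rest` never looks at the context (eg `let x : bool = (λy. y) in true`)
  ctx' `seq` infer ctx' rest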
-- we infer pairs if possible
infer ctx (Pair tm1 tm2) = Prod (infer ctx tm1) (infer ctx tm2)
-- projections, like other elimination constructs, should be inferrable
infer ctx (Fst tm) =
  let tmtyp = infer ctx tm in
  case tmtyp of
    Prod typ1 _ -> typ1
    _ -> error $ "type error! can't project from non-pair: " ++ show tm
infer ctx (Snd tm) =
  let tmtyp = infer ctx tm in
  case tmtyp of
    Prod _ typ2 -> typ2
    _ -> error $ "type error! can't project from non-pair: " ++ show tm
-- literals are super easy to infer
infer ctx TrueE = BoolT
infer ctx FalseE = BoolT
-- when the user gives an annotation we can infer the type that they gave,
-- but we must make sure that type is indeed correct.
-- this is a "direction change" from infer to check.
infer ctx (Ann tm typ) =
  check ctx tm typ `seq` typ
infer _ tm = error $ "type error! non-inferrable case: " ++ show tm ++ "\n\t(hint: you might need to add annotations?)"
check :: Ctx -> Tm -> Tp -> ()
-- we can't easily infer unannotated lambdas, but we can check them
check ctx (Lam nm Nothing body) (Arrow ltyp rtyp) =
  -- assume the parameter has the given input type, and ensure the lambda's body has the correct output type
  check (assume ctx nm ltyp) body rtyp
-- we can check type lambdas
check ctx (Tlam nm body) (Forall nm' typ) =
  let typ' = subs typ nm' (Tvar nm) in
  check (assumeT ctx nm) body typ'
-- see `checkLet` below for an explanation
check ctx (Let nm typ body rest) bodytyp =
  let ctx' = checkLet ctx nm typ body in
  -- force the extended context so the definition is checked even if `rest` never uses it
  ctx' `seq` check ctx' rest bodytyp
-- pairs are very easy to check
check ctx (Pair tm1 tm2) (Prod typ1 typ2) =
  check ctx tm1 typ1 `seq`
  check ctx tm2 typ2
-- I decided to just check if/then/else expressions because it's very easy.
-- but sometimes they can be inferred. try thinking:
-- what if only the first branch can be inferred? or only the second one? can you extend the algorithm to handle these?
check ctx (ITE cond tcase fcase) typ =
  check ctx cond BoolT `seq`
  check ctx tcase typ `seq`
  check ctx fcase typ
-- lists are similar to ITE. sometimes they can be inferred, but it's more complicated.
-- think what to do if just one element in the list is inferrable? what if the list is empty?
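check ctx (List tms) (ListT typ) =
  -- chain the element checks with `seq` so that every one of them actually runs.
  -- (`map ... `seq` ()` would only force the outermost list constructor and never run a single check)
  foldr (\tm rest -> check ctx tm typ `seq` rest) () tms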
-- the final catch-all case of checking
-- if we can infer something, we can also check it by testing conversion with the inferred type!
-- (thus checkable terms are a superset of inferrable terms)
check ctx tm expected =
  let actual = infer ctx tm in
  conv expected actual
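{-
one possible answer to the if/then/else exercise above, as a sketch only.
`inferITE` is a hypothetical helper, not part of the original algorithm:
it handles the case where the first branch is inferrable by inferring it,
and then checking the second branch against the inferred type.
to use it you would add a clause `infer ctx (ITE c t f) = inferITE ctx c t f` to `infer`.
the symmetric case (only the second branch inferrable) is left open here.
-}
inferITE :: Ctx -> Tm -> Tm -> Tm -> Tp
inferITE ctx cond tcase fcase =
  check ctx cond BoolT `seq`
  let typ = infer ctx tcase in
  -- force the inferred type before using it, then check the other branch against it
  typ `seq` check ctx fcase typ `seq` typ
-- inferITE emptyctx TrueE FalseE TrueE   gives BoolT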
-- re-usable function to check let bindings.
-- the idea is, `let x = e in e'` is inferrable if and only if `e'` is inferrable,
-- and checkable iff e' is checkable. with this function we abstract the "main logic" away.
-- it returns the context extended with the new definition.
checkLet :: Ctx -> Name -> Maybe Tp -> Tm -> Ctx
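checkLet ctx nm Nothing body =
  let bodytyp = infer ctx body in
  -- force the inferred type so a type error in `body` is reported even if the variable is never used
  bodytyp `seq` assume ctx nm bodytyp
checkLet ctx nm (Just typ) body =
  wellFormed ctx typ `seq`
  check ctx body typ `seq`
  assume ctx nm typ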
-- ensure that a type is well formed, ie all type variables are indeed bound/defined
wellFormed :: Ctx -> Tp -> ()
wellFormed ctx typ = case typ of
  Tvar nm -> if isTvar ctx nm
    then ()
    else error $ "type error! undefined type variable " ++ nm
  Arrow ltyp rtyp -> wellFormed ctx ltyp `seq` wellFormed ctx rtyp
  Prod ltyp rtyp -> wellFormed ctx ltyp `seq` wellFormed ctx rtyp
  Forall nm body -> wellFormed (assumeT ctx nm) body
  BoolT -> ()
  ListT t -> wellFormed ctx t
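{-
a quick illustration of what `wellFormed` accepts and rejects.
the names `closedTp` and `openTp` are made up for this example.
-}
closedTp, openTp :: Tp
closedTp = Forall "a" (Arrow (Tvar "a") (Tvar "a")) -- ∀a. a → a, every variable is bound
openTp = Arrow (Tvar "a") BoolT                     -- a → bool, `a` is free
-- wellFormed emptyctx closedTp               gives ()
-- wellFormed emptyctx openTp                 fails with: type error! undefined type variable a
-- wellFormed (assumeT emptyctx "a") openTp   gives (), because now `a` is in scope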
-- if you got here, congrats! this is the entire algorithm.
-- let's run some examples:
-- Λλ→∀
main :: IO ()
main = do
  -- check ΛA. λx:A. x against ∀A. A → A
  print $ check emptyctx
    (Tlam "A" (Lam "x" (Just $ Tvar "A") (Var "x")))
    (Forall "A" (Arrow (Tvar "A") (Tvar "A")))
  -- unannotated version
  print $ check emptyctx
    (Tlam "A" (Lam "x" Nothing (Var "x")))
    (Forall "A" (Arrow (Tvar "A") (Tvar "A")))
  -- infer annotated version
  print $ infer emptyctx
    (Tlam "A" (Lam "x" (Just $ Tvar "A") (Var "x")))
  -- higher rank parameter: λid:(∀A. A → A). (id [bool] true, id [list bool] {})
  print $ infer emptyctx
    (Lam "id" (Just $ Forall "A" (Arrow (Tvar "A") (Tvar "A")))
      (Pair
        (App (Inst (Var "id") BoolT) TrueE)
        (App (Inst (Var "id") (ListT BoolT)) (List []))
      )
    )
  -- conditionals
  -- let boolAnd : bool → bool → bool
  --     = λb1. λb2. if b1 then b2 else false
  -- in boolAnd true false
  print $ infer emptyctx
    (Let "boolAnd" (Just $ Arrow BoolT (Arrow BoolT BoolT))
      (Lam "b1" Nothing (Lam "b2" Nothing
        (ITE (Var "b1") (Var "b2") FalseE)))
      (App (App (Var "boolAnd") TrueE) FalseE)
    )