Skip to content

Instantly share code, notes, and snippets.

@joom
Created December 13, 2017 20:48
Show Gist options
  • Save joom/2b13fa7625bc95b8b29f4cd8e12f1148 to your computer and use it in GitHub Desktop.
Save joom/2b13fa7625bc95b8b29f4cd8e12f1148 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference in Idris
module Main
import Control.Monad.State
import Data.SortedMap
import Data.SortedSet
%default covering
%access public export
||| Surrounds a string with a pair of round brackets.
paren : String -> String
paren s = concat ["(", s, ")"]
||| Wraps a string in the ANSI escape codes for green foreground text,
||| resetting to the default foreground colour afterwards.
green : String -> String
green s = concat ["\ESC[32m", s, "\ESC[39m"]
||| Wraps a string in the ANSI escape codes for blue foreground text,
||| resetting to the default foreground colour afterwards.
blue : String -> String
blue s = start ++ s ++ reset
  where
    start : String
    start = "\ESC[34m"
    reset : String
    reset = "\ESC[39m"
||| Identifiers, used both for term variables (`TmVar`) and for type
||| variables (`TyVar`).
Name : Type
Name = String
||| Terms of the untyped lambda calculus extended with a unit constant.
data Term = TmVar Name        -- variable reference
          | TmUnit            -- the unit value `()`
          | TmApp Term Term   -- application: function, then argument
          | TmAbs Name Term   -- abstraction `λx.t`: binder name, then body
-- Pretty-prints a term. Applications and abstractions are always wrapped in
-- parentheses, so the output never depends on precedence rules.
implementation Show Term where
  show TmUnit = "()"
  show (TmVar name) = name
  show (TmAbs binder body) = paren (concat ["λ", binder, ".", show body])
  show (TmApp fun arg) = paren (concat [show fun, " ", show arg])
||| Simple types: type variables, the unit type, and function types.
data Ty = TyVar Name    -- type variable
        | TyUnit        -- the unit type `()`
        | TyArr Ty Ty   -- function type: argument type, then result type
-- Structural equality: two types are equal exactly when they are built from
-- the same constructors with equal names at the leaves.
implementation Eq Ty where
  TyUnit == TyUnit = True
  (TyVar x) == (TyVar y) = x == y
  (TyArr dom cod) == (TyArr dom' cod') =
    if dom == dom' then cod == cod' else False
  _ == _ = False
-- Pretty-prints a type. Every arrow is parenthesised, so nesting is always
-- explicit in the output.
implementation Show Ty where
  show ty = case ty of
    TyVar name => name
    TyUnit => "()"
    TyArr dom cod => paren (show dom ++ " -> " ++ show cod)
-- Total order on types, consistent with the structural `Eq Ty`.
-- Different constructors are ordered TyVar < TyUnit < TyArr. Type variables
-- compare by name. Arrows compare lexicographically, argument type first.
implementation Ord Ty where
  compare (TyVar x) (TyVar y) = compare x y
  compare TyUnit TyUnit = EQ
  compare (TyArr a1 b1) (TyArr a2 b2) =
    case compare a1 a2 of
      EQ => compare b1 b2
      o => o
  compare x y = compare (rank x) (rank y)
    where
      -- position of each constructor in the cross-constructor order
      rank : Ty -> Int
      rank (TyVar _) = 0
      rank TyUnit = 1
      rank (TyArr _ _) = 2
||| The inference monad.
|||
||| The `Int` state is the counter that `fresh` increments to invent new
||| type-variable names. `Maybe` models failure: an unbound term variable,
||| a failed occurs check, or a constructor mismatch during unification.
||| The state is an `Int` rather than a `Nat` because `intToStr` takes an `Int`.
Infer : Type -> Type
Infer = StateT Int Maybe
||| Turns a counter value into a readable type-variable name.
|||
||| The value modulo 26 selects a lowercase letter (`a`..`z`). The quotient,
||| when non-zero, is appended as a decimal suffix. Examples: 0 -> "a",
||| 25 -> "z", 26 -> "a1", 53 -> "b2". Distinct non-negative inputs always
||| give distinct names.
||| Not intended for negative inputs, which produce unhelpful names.
intToStr : Int -> String
intToStr i = assert_total $
  -- 26 letters in the alphabet; using 25 here would never produce `z`
  let d = i `div` 26 in
  let r = i `mod` 26 in
  pack [chr (97 + r)] ++ (if d == 0 then "" else show d)
||| Produces a type variable that has not been handed out before.
||| It increments the counter and names the variable after the new value, so
||| the first call on a counter of 0 yields the name for 1.
fresh : Infer Ty
fresh =
  do n <- get
     let next = the Int (n + 1)
     put next
     pure (TyVar (intToStr next))
||| Type schemes: a type under zero or more `Forall` quantifiers, each of
||| which binds one type-variable name.
data Scheme = Mono Ty              -- an unquantified type
            | Forall Name Scheme   -- `Forall a s` binds the name `a` in `s`
||| A substitution: a finite map from type-variable names to the types that
||| replace them.
Subst : Type
Subst = SortedMap Name Ty
||| Structures that mention type variables and can have a substitution
||| applied to them.
interface TypeVars ty where
  ||| Every type-variable name occurring in the value, including names bound
  ||| by quantifiers.
  allVars : ty -> SortedSet Name
  ||| The type-variable names that are not bound by an enclosing quantifier.
  freeVars : ty -> SortedSet Name
  ||| Applies a substitution to the value.
  subst : Subst -> ty -> ty
-- A plain type has no quantifiers, so every variable in it is free.
-- Substitution replaces each mapped variable and leaves unmapped ones alone.
implementation TypeVars Ty where
  allVars TyUnit = empty
  allVars (TyVar name) = insert name empty
  allVars (TyArr dom cod) = union (allVars dom) (allVars cod)

  freeVars ty = allVars ty

  subst sub (TyArr dom cod) = TyArr (subst sub dom) (subst sub cod)
  subst sub (TyVar name) = maybe (TyVar name) id (lookup name sub)
  subst _ TyUnit = TyUnit
-- A quantifier adds its name to `allVars` but removes it from `freeVars`.
-- Substitution skips the bound name underneath its own `Forall`.
-- NOTE(review): substitution is not capture-avoiding. That is harmless here
-- because schemes are only ever built with `Mono` in this file.
implementation TypeVars Scheme where
  subst sub (Mono ty) = Mono (subst sub ty)
  subst sub (Forall bound inner) = Forall bound (subst (delete bound sub) inner)

  allVars (Mono ty) = allVars ty
  allVars (Forall bound inner) = insert bound (allVars inner)

  freeVars (Mono ty) = freeVars ty
  freeVars (Forall bound inner) = delete bound (freeVars inner)
||| A typing context: maps each term variable in scope to its type scheme.
Context : Type
Context = SortedMap Name Scheme
||| Composes two substitutions. `compose s s'` acts like applying `s'` first
||| and then `s`: the right-hand sides of `s'` get `s` applied to them, and
||| the result is merged with `s`. On a name bound by both, the (updated)
||| binding from `s'` is kept, assuming `mergeLeft` is left-biased as its
||| name suggests.
compose : Subst -> Subst -> Subst
compose outer inner =
  let inner' = map (subst outer) inner in
  mergeLeft inner' outer
-- A context's variables are the union of the variables of all its schemes.
-- Substitution is applied to every scheme in the context.
implementation TypeVars Context where
  allVars ctx = foldl union empty (map allVars (values ctx))
  freeVars ctx = foldl union empty (map freeVars (values ctx))
  subst sub ctx = map (subst sub) ctx
||| Instantiates a type scheme. Each quantified variable (in any variable
||| that is not free) is replaced by a fresh type variable, and the
||| quantifiers are dropped. Fresh variables are allocated in the order the
||| quantified names appear in the sorted set.
instantiate : Scheme -> Infer Ty
instantiate sc =
  do renaming <- foldlM freshenVar empty quantifiedVars
     pure (dropQuantifiers sc renaming)
  where
    quantifiedVars : List Name
    quantifiedVars = Data.SortedSet.toList $ difference (allVars sc) (freeVars sc)

    -- extend the renaming with a fresh variable for one quantified name
    freshenVar : Subst -> Name -> Infer Subst
    freshenVar acc v =
      do tv <- fresh
         pure (insert v tv acc)

    -- peel off every `Forall` and apply the renaming to the inner type
    dropQuantifiers : Scheme -> Subst -> Ty
    dropQuantifiers (Mono ty) s = subst s ty
    dropQuantifiers (Forall _ inner) s = dropQuantifiers inner s
||| Binds type variable `a` to type `t`.
||| - If `t` is the variable `a` itself, the result is the empty substitution.
||| - If `a` occurs free in `t`, it fails; this is the occurs check, which
|||   rules out infinite types.
||| - Otherwise it returns the one-entry substitution `a := t`.
bindVar : Name -> Ty -> Infer Subst
bindVar a t =
  case (t == TyVar a, contains a (freeVars t)) of
    (True, _) => pure empty
    (False, True) => lift Nothing
    (False, False) => pure (insert a t empty)
||| Computes a most general unifier of two types.
||| It fails (via `Maybe`) when the constructors do not match, or when
||| `bindVar`'s occurs check rejects a binding.
|||
||| For arrows, the unifier `s1` of the argument types is found first. The
||| result types are then unified *after* applying `s1`, giving `s2`. The
||| combined substitution is therefore `s2` after `s1`, i.e. `compose s2 s1`.
||| TODO: prove that it is total with the gas tank idiom
unify : Ty -> Ty -> Infer Subst
-- the empty substitution, returned successfully (a bare `empty` here would be
-- the monad's failure value rather than the empty map)
unify TyUnit TyUnit = pure empty
unify (TyVar a) t = bindVar a t
unify t (TyVar a) = bindVar a t
unify (TyArr s s') (TyArr t t') =
  do s1 <- unify s t
     s2 <- unify (subst s1 s') (subst s1 t')
     pure (compose s2 s1)
unify _ _ = lift Nothing
||| Algorithm W: infers the type of a term under a typing context.
|||
||| Returns the substitution built up while solving constraints, together
||| with the inferred type. It fails (via `Maybe`) on an unbound variable or
||| when unification fails.
infer : Context -> Term -> Infer (Subst, Ty)
infer _ TmUnit = pure (empty, TyUnit)
infer ctx (TmVar name) =
  do scheme <- lift (lookup name ctx)
     ty <- instantiate scheme
     pure (empty, ty)
infer ctx (TmAbs binder body) =
  do argTy <- fresh
     -- the binder shadows any outer binding of the same name
     let ctx' = insert binder (Mono argTy) (delete binder ctx)
     (sBody, bodyTy) <- infer ctx' body
     pure (sBody, TyArr (subst sBody argTy) bodyTy)
infer ctx (TmApp fun arg) =
  do (sFun, funTy) <- infer ctx fun
     (sArg, argTy) <- infer (subst sFun ctx) arg
     resTy <- fresh
     -- the function's type must have the shape `argTy -> resTy`
     sUni <- unify (subst sArg funTy) (TyArr argTy resTy)
     let sAll = compose sUni (compose sArg sFun)
     pure (sAll, subst sUni resTy)
||| Infers the type of a closed term (empty context), starting the fresh-name
||| counter at 0. Returns `Nothing` if the term cannot be typed.
typeInf : Term -> Maybe Ty
typeInf tm =
  case runStateT (infer empty tm) 0 of
    Nothing => Nothing
    Just ((_, ty), _) => Just ty
||| Prints one line: the term in green, followed by either its inferred type
||| in blue or a note that it is not typeable.
printResult : Term -> IO ()
printResult tm =
  do putStr ("The term " ++ green (show tm))
     putStr (verdict (typeInf tm))
     putStr "\n"
  where
    verdict : Maybe Ty -> String
    verdict Nothing = " is not typeable"
    verdict (Just ty) = " has the type " ++ blue (show ty)
||| Runs type inference on a fixed list of sample terms and prints each result.
main : IO ()
main = traverse_ printResult examples
  where
    examples : List Term
    examples =
      [ TmAbs "x" (TmAbs "y" (TmApp (TmVar "x") (TmVar "y")))
      , TmAbs "x" (TmAbs "y" (TmApp (TmVar "y") (TmVar "x")))
      , TmAbs "x" (TmApp (TmVar "x") (TmVar "x"))
      , TmAbs "x" (TmApp (TmVar "x") TmUnit)
      ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment