Created
December 13, 2017 20:48
-
-
Save joom/2b13fa7625bc95b8b29f4cd8e12f1148 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference in Idris
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main | |
import Control.Monad.State | |
import Data.SortedMap | |
import Data.SortedSet | |
%default covering | |
%access public export | |
||| Surround a string with round brackets.
paren : String -> String
paren s = concat ["(", s, ")"]

||| Colour a string green with ANSI escapes, restoring the default
||| foreground colour afterwards.
green : String -> String
green s = concat ["\ESC[32m", s, "\ESC[39m"]

||| Colour a string blue with ANSI escapes, restoring the default
||| foreground colour afterwards.
blue : String -> String
blue s = concat ["\ESC[34m", s, "\ESC[39m"]
||| Names of term variables and of type variables.
Name : Type
Name = String

||| Terms: variables, the unit value, application and lambda abstraction.
data Term : Type where
  TmVar  : Name -> Term
  TmUnit : Term
  TmApp  : Term -> Term -> Term
  TmAbs  : String -> Term -> Term

||| Render a term, parenthesising every application and abstraction.
implementation Show Term where
  show tm = case tm of
    TmUnit          => "()"
    TmVar v         => v
    TmAbs v body    => paren ("λ" ++ v ++ "." ++ show body)
    TmApp fun arg   => paren (show fun ++ " " ++ show arg)
||| Monotypes: type variables, the unit type and function types.
data Ty : Type where
  TyVar  : Name -> Ty
  TyUnit : Ty
  TyArr  : Ty -> Ty -> Ty

||| Structural equality on types.
implementation Eq Ty where
  TyUnit        == TyUnit        = True
  (TyVar v)     == (TyVar w)     = v == w
  (TyArr d1 c1) == (TyArr d2 c2) = d1 == d2 && c1 == c2
  _             == _             = False

||| Render a type, parenthesising every arrow.
implementation Show Ty where
  show (TyVar v)   = v
  show TyUnit      = "()"
  show (TyArr d c) = paren (show d ++ " -> " ++ show c)

||| A total order on types: variables < unit < arrows; variables are ordered
||| by name and arrows lexicographically (domain first, then codomain).
implementation Ord Ty where
  compare (TyVar v) (TyVar w) = compare v w
  compare (TyArr d1 c1) (TyArr d2 c2) =
    if d1 == d2 then compare c1 c2 else compare d1 d2
  -- Different constructors (or two units): compare by constructor rank.
  compare l r = compare (rank l) (rank r)
    where
      rank : Ty -> Int
      rank (TyVar _)   = 0
      rank TyUnit      = 1
      rank (TyArr _ _) = 2
||| Our monad stack: an `Int` counter used to name fresh type variables,
||| over `Maybe` so that inference can fail.
||| We are using `Int` as a state instead of `Nat`
||| because of the hacky `intToStr` function.
Infer : Type -> Type
Infer = StateT Int Maybe

||| A hacky way to convert any integer to a string.
||| It picks one of the 26 lowercase letters from `i mod 26` and, when the
||| number does not fit in a single letter, appends `i div 26`:
||| 0 -> "a", 25 -> "z", 26 -> "a1", 27 -> "b1", ...
||| (Previously the modulus was 25, so 'z' was never used.)
||| Does not generate good variable names for negative numbers.
intToStr : Int -> String
intToStr i = assert_total $   -- kept from the original; presumably the
                              -- checker cannot see through Int div/mod
  let letter = chr (ord 'a' + (i `mod` 26))
      suffix = i `div` 26
  in pack [letter] ++ (if suffix == 0 then "" else show suffix)

||| Produce a fresh type variable. The counter is bumped before it is read,
||| so starting from 0 the first variable is named "b".
fresh : Infer Ty
fresh = do
  modify (the (Int -> Int) (+1))
  pure $ TyVar $ intToStr (!get)
||| Type schemes: a monotype under zero or more `Forall` binders.
data Scheme : Type where
  Mono   : Ty -> Scheme
  Forall : Name -> Scheme -> Scheme

||| Substitutions from type-variable names to types.
Subst : Type
Subst = SortedMap Name Ty

||| Values that mention type variables.
interface TypeVars a where
  ||| Every type variable that occurs, including ones bound by `Forall`.
  allVars  : a -> SortedSet Name
  ||| The type variables that are not bound by an enclosing `Forall`.
  freeVars : a -> SortedSet Name
  ||| Apply a substitution to the value.
  subst    : Subst -> a -> a
||| A monotype has no binders, so all of its variables are free.
implementation TypeVars Ty where
  allVars TyUnit      = empty
  allVars (TyVar v)   = insert v empty
  allVars (TyArr d c) = union (allVars d) (allVars c)

  freeVars ty = allVars ty

  -- Variables missing from the substitution are left unchanged.
  subst _   TyUnit      = TyUnit
  subst sub (TyVar v)   = fromMaybe (TyVar v) (lookup v sub)
  subst sub (TyArr d c) = TyArr (subst sub d) (subst sub c)
||| Schemes: `Forall` binds its name, hiding it from `freeVars` and `subst`.
implementation TypeVars Scheme where
  allVars (Mono ty)      = allVars ty
  allVars (Forall v sc)  = insert v (allVars sc)

  freeVars (Mono ty)     = freeVars ty
  freeVars (Forall v sc) = delete v (freeVars sc)

  subst sub (Mono ty)     = Mono (subst sub ty)
  -- The quantified name is shadowed, so its binding is dropped before
  -- descending under the binder.
  subst sub (Forall v sc) = Forall v (subst (delete v sub) sc)
||| Typing contexts from term-variable names to type schemes.
Context : Type
Context = SortedMap Name Scheme

||| Compose two substitutions: `compose outer inner` is meant to act like
||| applying `inner` first and then `outer`. Every binding of `inner` has
||| `outer` applied to it; the bindings of `outer` are then added (assuming
||| `mergeLeft` keeps the left map's binding on a shared key).
compose : Subst -> Subst -> Subst
compose outer inner =
  let updated = map (subst outer) inner
  in mergeLeft updated outer
||| A context's variables are the union over all the schemes it contains;
||| substitution is applied to every scheme.
implementation TypeVars Context where
  allVars ctx  = foldl (\acc, sc => union acc (allVars sc)) empty (values ctx)
  freeVars ctx = foldl (\acc, sc => union acc (freeVars sc)) empty (values ctx)
  subst sub ctx = map (subst sub) ctx
||| Instantiate a scheme. Each variable in `allVars` but not in `freeVars`
||| (i.e. the quantified ones) is mapped to a fresh type variable. The
||| `Forall` binders are then stripped and that substitution is applied to
||| the body.
instantiate : Scheme -> Infer Ty
instantiate sc = do
    sub <- foldlM addFresh empty quantified
    pure (stripForalls sc sub)
  where
    quantified : List Name
    quantified = Data.SortedSet.toList (difference (allVars sc) (freeVars sc))

    addFresh : Subst -> Name -> Infer Subst
    addFresh acc v = do
      tv <- fresh
      pure (insert v tv acc)

    stripForalls : Scheme -> Subst -> Ty
    stripForalls (Mono ty)        sub = subst sub ty
    stripForalls (Forall _ inner) sub = stripForalls inner sub
||| Bind type variable `a` to `t`.
||| * If `t` is `a` itself, the result is the empty substitution.
||| * If `a` occurs free in `t` (occurs check), inference fails.
||| * Otherwise the result is the single binding a := t.
bindVar : Name -> Ty -> Infer Subst
bindVar a t =
  case (t == TyVar a, contains a (freeVars t)) of
    (True,  _)     => pure empty
    (False, True)  => lift Nothing
    (False, False) => pure (insert a t empty)
||| Unify two types, producing a substitution that makes them equal, or
||| failing (via `Maybe`) on a constructor mismatch or occurs-check failure.
||| TODO: prove that it is total with the gas tank idiom
unify : Ty -> Ty -> Infer Subst
-- Must be `pure empty` (the empty substitution). A bare `empty` at type
-- `Infer Subst` is the monad's failure, which made `()` ~ `()` fail.
unify TyUnit TyUnit = pure empty
unify (TyVar a) t = bindVar a t
unify t (TyVar a) = bindVar a t
unify (TyArr s s') (TyArr t t') = do
  s1 <- unify s t
  s2 <- unify (subst s1 s') (subst s1 t')
  -- s2 was computed on s1-substituted types, so s1 acts first and s2 second.
  -- `compose x y` applies `y` first, hence `compose s2 s1`. The old
  -- `compose s1 s2` left s1's range unsubstituted: for example,
  -- \f.\x. f (f x) was typed (d -> e) -> d -> e.
  pure $ compose s2 s1
unify _ _ = lift Nothing
||| Infer the type of a term in a context (algorithm W, without `let`).
||| Returns the accumulated substitution together with the inferred type.
||| Unbound variables and unification failures make the whole result fail.
infer : Context -> Term -> Infer (Subst, Ty)
infer _ TmUnit = pure (empty, TyUnit)
infer ctx (TmVar x) = do
  scheme <- lift (lookup x ctx)
  ty <- instantiate scheme
  pure (empty, ty)
infer ctx (TmAbs x body) = do
  paramTy <- fresh
  -- The parameter shadows any outer binding of the same name.
  let ctx' = insert x (Mono paramTy) (delete x ctx)
  (sBody, bodyTy) <- infer ctx' body
  pure (sBody, TyArr (subst sBody paramTy) bodyTy)
infer ctx (TmApp fun arg) = do
  (sFun, funTy) <- infer ctx fun
  (sArg, argTy) <- infer (subst sFun ctx) arg
  resTy <- fresh
  sUni <- unify (subst sArg funTy) (TyArr argTy resTy)
  -- Substitutions produced later are composed on the left.
  pure (compose sUni (compose sArg sFun), subst sUni resTy)
||| Infer the type of a term in the empty context, with the fresh-name
||| counter starting at 0; the final substitution and counter are discarded.
typeInf : Term -> Maybe Ty
typeInf t = do
  ((_, ty), _) <- runStateT (infer empty t) 0
  pure ty
||| Print a term (green) followed by its inferred type (blue), or a note
||| that it is not typeable, and end the line.
printResult : Term -> IO ()
printResult t = do
  putStr $ "The term " ++ green (show t)
  putStr $ maybe " is not typeable"
                 (\tau => " has the type " ++ blue (show tau))
                 (typeInf t)
  putStr "\n"
||| Print the inference result for a handful of sample terms.
main : IO ()
main = traverse_ printResult samples
  where
    samples : List Term
    samples =
      [ TmAbs "x" (TmAbs "y" (TmApp (TmVar "x") (TmVar "y")))
      , TmAbs "x" (TmAbs "y" (TmApp (TmVar "y") (TmVar "x")))
      , TmAbs "x" (TmApp (TmVar "x") (TmVar "x"))
      , TmAbs "x" (TmApp (TmVar "x") TmUnit)
      ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment