-- Created February 14, 2019 10:36
-- Simple Type Inference System in Haskell
import qualified Data.Map.Strict as Map
import Data.List
-- This implementation is based on
-- | Abstract syntax of the expressions whose types are inferred.
data Expression
  = EInt Int -- ^ integer literal: its value
  | EVar String -- ^ variable reference: its name
  | EFunc String Expression -- ^ function: parameter name, then body
  | EFuncCall Expression Expression -- ^ application: function, then argument
-- | Types that can be assigned to an 'Expression'.
data Type
  = TNamed String -- ^ named type, by name (e.g. String/Int)
  | TVar String -- ^ type variable, by name (e.g. T1)
  | TFunc Type Type -- ^ function type: input type, then output type
  deriving (Show)
-- | Typing environment: maps the name of each variable in scope to its type.
type Env = Map.Map String Type
-- | Failures that inference and unification can report.
data Err
  = ErrUnboundVar String -- ^ a variable name that is not in the environment
  | ErrTypeMisMatch Type Type -- ^ two types that could not be unified
  | ErrTVarSelfReference String Type -- ^ a type variable and a type that contains it
  deriving (Show)
-- | Infer the type of an expression.
--
-- The 'Context' supplies the types of the variables in scope and the counter
-- used to mint fresh type variables.  On success the result is the inferred
-- type together with the substitution (type variable name -> type) that was
-- discovered on the way; otherwise the first error encountered is returned.
infer :: Context -> Expression -> Either Err (Type, Substitution)
infer ctx expr =
  fmap (\(_, inferredType, subst) -> (inferredType, subst)) (inferThreaded ctx expr)

-- | Worker for 'infer'.
--
-- Besides the type and substitution it returns the next unused fresh-variable
-- counter, so that sibling sub-expressions never mint the same type variable.
inferThreaded :: Context -> Expression -> Either Err (Int, Type, Substitution)
inferThreaded ctx@(Context next env) expr =
  case expr of
    EInt{} ->
      Right (next, TNamed "Int", emptySubstitution)
    EVar name ->
      case Map.lookup name env of
        Just t ->
          Right (next, t, emptySubstitution)
        Nothing ->
          Left (ErrUnboundVar name)
    EFunc paramName body -> do
      -- The parameter starts out as a fresh type variable; whatever the body
      -- learns about it is read back through the body's substitution.
      let (ctx1, paramType) = createNewTVar ctx
      let ctx2 = addNewBindingToContext ctx1 paramName paramType
      (next', bodyType, subst) <- inferThreaded ctx2 body
      Right (next', TFunc (applySubstitutionToType subst paramType) bodyType, subst)
    EFuncCall func arg -> do
      (next1, funcType, s1) <- inferThreaded ctx func
      -- The argument is inferred under what was learned from the function,
      -- and with the advanced counter so fresh names cannot collide.
      let argEnv = Map.map (applySubstitutionToType s1) env
      (next2, argType, s2) <- inferThreaded (Context next1 argEnv) arg
      -- composeSubst applies its first argument to the bindings of its
      -- second, so the later substitution always goes first.
      let s12 = composeSubst s2 s1
      let (Context next3 _, freshType) = createNewTVar (Context next2 argEnv)
      case applySubstitutionToType s2 funcType of
        TFunc inputType outputType -> do
          s3 <- unify inputType argType
          Right (next2, applySubstitutionToType s3 outputType, composeSubst s3 s12)
        TVar name -> do
          -- Callee of still-unknown type: constrain it to be a function
          -- from the argument type to a fresh result type.
          s3 <- unify (TVar name) (TFunc argType freshType)
          Right (next3, applySubstitutionToType s3 freshType, composeSubst s3 s12)
        notAFunction ->
          Left (ErrTypeMisMatch (TFunc argType freshType) notAFunction)
-- | State carried through inference.
--
-- Every use of the constructor in this file matches or builds it with two
-- fields: the fresh-variable counter and the environment.
data Context
  = Context
      Int -- next: number used to name the next fresh type variable
      Env -- env: types of the variables currently in scope
-- | Maps the names of type variables to the types assigned to them.
type Substitution = Map.Map String Type
-- | The substitution that assigns nothing to any type variable.
emptySubstitution :: Substitution
emptySubstitution = mempty
-- | Replace the type variables of a type that are bound in the given
-- substitution with the types they are bound to.
--
-- Named types and unbound type variables are returned unchanged; function
-- types are rebuilt from their substituted input and output types.  The
-- replacement is a single pass: a substituted-in type is not substituted
-- again.
--
-- e.g. applying the substitution {"a": Bool, "b": Int}
-- to the type (a -> b) gives the type (Bool -> Int)
applySubstitutionToType :: Substitution -> Type -> Type
applySubstitutionToType subst type' =
  case type' of
    TNamed{} ->
      type'
    TVar name ->
      -- an unbound variable stands for itself
      Map.findWithDefault type' name subst
    TFunc from to ->
      TFunc
        (applySubstitutionToType subst from)
        (applySubstitutionToType subst to)
-- | Bring a variable into scope with the given type.
--
-- The fresh-variable counter is carried over untouched; an existing binding
-- of the same name is replaced.
addNewBindingToContext :: Context -> String -> Type -> Context
addNewBindingToContext (Context counter scope) varname type' =
  Context counter extendedScope
  where
    extendedScope = Map.insert varname type' scope
-- | Mint a type variable named after the context's counter.
--
-- Returns the context with its counter advanced by one, paired with the new
-- type variable; the environment is left as it was.
createNewTVar :: Context -> (Context, Type)
createNewTVar (Context nextInt env) = (advanced, fresh)
  where
    fresh = TVar ("T$" ++ show nextInt)
    advanced = Context (nextInt + 1) env
-- | Find a substitution that makes the two types equal, or report why none
-- exists.  Dispatches on the shape of the first type.
unify :: Type -> Type -> Either Err Substitution
unify (TNamed name1) t2 = unifyTNamed name1 t2
unify (TVar name) t2 = unifyTVar name t2
unify (TFunc argType bodyType) t2 = unifyTFunc argType bodyType t2
-- | Unify the function type @argType -> bodyType@ with @t2@.
--
-- Against another function type the argument types are unified first, and
-- the resulting substitution is applied to both result types before those
-- are unified.  Against a type variable the work is delegated to
-- 'unifyTVar'; a named type can never match a function type.
unifyTFunc :: Type -> Type -> Type -> Either Err Substitution
unifyTFunc argType bodyType t2 =
  case t2 of
    TFunc argType' bodyType' -> do
      subst1 <- unify argType argType'
      subst2 <- unify (applySubstitutionToType subst1 bodyType) (applySubstitutionToType subst1 bodyType')
      -- composeSubst applies its first argument to the bindings of its
      -- second, so the later substitution goes first; otherwise variables
      -- resolved by subst2 would survive inside subst1's bindings.
      Right (composeSubst subst2 subst1)
    TVar name ->
      unifyTVar name (TFunc argType bodyType)
    TNamed _ ->
      Left (ErrTypeMisMatch (TFunc argType bodyType) t2)
{-
  Apply substitution s1 to the types bound in s2, then merge the two.
  For example if
    s1 = {t1 => Int}
    s2 = {t2 => t1}
  Then the result will be
    s3 = {
      t1 => Int
      t2 => Int
    }
-}
-- | Combine two substitutions.
--
-- Every type bound in @s2@ has @s1@ applied to it, and the result is merged
-- with @s1@.  When both bind the same variable, the binding from @s1@ wins
-- ('Map.union' is left-biased).
composeSubst :: Substitution -> Substitution -> Substitution
composeSubst s1 s2 =
  Map.union s1 (Map.map (applySubstitutionToType s1) s2)
-- | Unify the type variable called @name1@ with @t2@.
--
-- A variable unifies with itself without binding anything.  Otherwise it is
-- bound to @t2@, unless @t2@ is a non-variable type that mentions the
-- variable itself, which is reported as 'ErrTVarSelfReference'.
unifyTVar :: String -> Type -> Either Err Substitution
unifyTVar name1 t2 =
  case t2 of
    TVar name2
      | name1 == name2 -> Right emptySubstitution
      | otherwise -> Right bindToT2
    _
      | t2 `contains` name1 -> Left (ErrTVarSelfReference name1 t2)
      | otherwise -> Right bindToT2
  where
    bindToT2 = Map.insert name1 t2 emptySubstitution
-- | Does the type variable called @tvarname@ occur anywhere inside the type?
contains :: Type -> String -> Bool
t `contains` tvarname =
  case t of
    TVar name ->
      name == tvarname
    TFunc inType outType ->
      inType `contains` tvarname || outType `contains` tvarname
    -- a named type has no type variables inside it
    TNamed _ ->
      False
-- | Unify the named type called @name1@ with @t2@.
--
-- Two named types unify only when their names are equal, and then bind
-- nothing.  Against a type variable the work is delegated to 'unifyTVar';
-- anything else is an 'ErrTypeMisMatch'.
unifyTNamed :: String -> Type -> Either Err Substitution
unifyTNamed name1 t2 =
  case t2 of
    TNamed name2
      | name1 == name2 -> Right emptySubstitution
      | otherwise -> Left err
    TVar name ->
      unifyTVar name (TNamed name1)
    _ ->
      Left err
  where
    err = ErrTypeMisMatch (TNamed name1) t2
-- | Demo: infer and print the type of @(\x -> (+ x) 5) 99@ in an
-- environment where @+@ has type @Int -> Int -> Int@.
main :: IO ()
main = print (infer ctx expr)
  where
    plusType = TFunc (TNamed "Int") (TFunc (TNamed "Int") (TNamed "Int"))
    initialEnv = Map.insert "+" plusType (Map.empty :: Env)
    ctx = Context 0 initialEnv
    addFive = EFunc "x" (EFuncCall (EFuncCall (EVar "+") (EVar "x")) (EInt 5))
    expr = EFuncCall addFive (EInt 99)
