Created
February 14, 2019 10:36
-
-
Save wongjiahau/5786df6e7dc6a8346bf3ed8439b82470 to your computer and use it in GitHub Desktop.
Simple Type Inference System in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import qualified Data.Map.Strict as Map | |
import Data.List | |
-- This implementation is based on https://medium.com/@dhruvrajvanshi/type-inference-for-beginners-part-1-3e0a5be98a4b | |
-- | The expression language being type-checked: integer literals,
-- variables, single-parameter functions (lambdas) and function application.
--
-- Derives 'Show' and 'Eq' so expressions can be printed in diagnostics and
-- compared in tests.
data Expression
  = EInt
      Int -- ^ value
  | EVar
      String -- ^ name
  | EFunc
      String -- ^ param name
      Expression -- ^ body
  | EFuncCall
      Expression -- ^ function
      Expression -- ^ argument
  deriving (Show, Eq)
-- | The types the inferencer can assign to an expression.
--
-- 'Eq' is derived in addition to 'Show' so inferred types (and
-- substitutions, which are maps of types) can be compared structurally.
data Type
  = TNamed
      String -- ^ name of the named type (e.g. String/Int)
  | TVar
      String -- ^ name of the type variable (e.g. T1)
  | TFunc
      Type -- ^ input type
      Type -- ^ output type
  deriving (Show, Eq)
-- | Typing environment: for every variable name currently in scope, the
-- type it has been assigned.
type Env = Map.Map String Type
-- | Ways in which type inference can fail.
data Err
  = ErrUnboundVar String
    -- ^ a variable was referenced that is not bound in the environment
  | ErrTypeMisMatch Type Type
    -- ^ the two types could not be unified
  | ErrTVarSelfReference String Type
    -- ^ occurs check failed: the type variable would be bound to a type
    -- that mentions the variable itself
  deriving (Show)
-- | Infer the type of an expression (Algorithm W).
--
-- Returns the inferred type together with the substitution accumulated
-- while solving unification constraints, or the first 'Err' encountered.
-- The returned type already has the returned substitution applied to it.
--
-- The work is done by 'inferThreaded', which additionally threads the
-- fresh-type-variable counter so that sibling sub-expressions (e.g. the
-- function and the argument of a call) never mint the same @T$n@ name.
infer :: Context -> Expression -> Either Err (Type, Substitution)
infer ctx expr = do
  (_, inferredType, subst) <- inferThreaded ctx expr
  Right (inferredType, subst)

-- | Worker for 'infer'. Besides the type and substitution, it returns the
-- next unused type-variable index so the caller can keep minting fresh
-- variables without collisions. Only the counter is handed back, not the
-- environment, so lambda parameters never leak out of their scope.
inferThreaded :: Context -> Expression -> Either Err (Int, Type, Substitution)
inferThreaded ctx@(Context next env) expr =
  case expr of
    EInt{} ->
      Right (next, TNamed "Int", emptySubstitution)
    EVar name ->
      case Map.lookup name env of
        Just t ->
          Right (next, t, emptySubstitution)
        Nothing ->
          Left (ErrUnboundVar name)
    EFunc paramName body -> do
      let (ctx1, paramType) = createNewTVar ctx
          ctx2 = addNewBindingToContext ctx1 paramName paramType
      (next1, bodyType, subst) <- inferThreaded ctx2 body
      Right
        ( next1
        , TFunc (applySubstitutionToType subst paramType) bodyType
        , subst
        )
    EFuncCall func arg -> do
      (next1, funcType, s1) <- inferThreaded ctx func
      -- What was learned while inferring the function must be visible
      -- when inferring the argument.
      let argEnv = Map.map (applySubstitutionToType s1) env
      (next2, argType, s2) <- inferThreaded (Context next1 argEnv) arg
      -- The function's type need not be an arrow yet (e.g. a lambda
      -- parameter is typed as a bare type variable), so unify it with a
      -- fresh arrow @argType -> result@ instead of pattern matching on it.
      let (Context next3 _, resultType) = createNewTVar (Context next2 env)
      s3 <- unify (TFunc argType resultType) (applySubstitutionToType s2 funcType)
      -- Substitutions were found in the order s1, s2, s3; each later one
      -- must be applied after the earlier ones: s3 . s2 . s1.
      Right
        ( next3
        , applySubstitutionToType s3 resultType
        , composeSubst s3 (composeSubst s2 s1)
        )
-- | Inference context: a counter used to name fresh type variables
-- (see 'createNewTVar') and the variables currently in scope.
data Context = Context
  Int -- ^ index of the next fresh type variable
  Env -- ^ variables in scope and their types
-- | A substitution maps names of type variables to the types assigned
-- to them.
type Substitution = Map.Map String Type
-- | The substitution with no bindings; applying it leaves every type
-- unchanged.
emptySubstitution :: Substitution
emptySubstitution = mempty
-- | Rewrite every type variable in a type that is bound by the
-- substitution, leaving unbound variables and named types as they are.
--
-- For example, applying {"a": Bool, "b": Int} to (a -> b) gives
-- (Bool -> Int).
--
-- A looked-up replacement is inserted as-is; it is not itself rewritten
-- again by the same substitution.
applySubstitutionToType :: Substitution -> Type -> Type
applySubstitutionToType subst = go
  where
    go named@(TNamed _) = named
    go var@(TVar name) = Map.findWithDefault var name subst
    go (TFunc from to) = TFunc (go from) (go to)
-- | Bind a variable name to a type in the context's environment, replacing
-- any previous binding of that name. The fresh-variable counter is kept.
addNewBindingToContext :: Context -> String -> Type -> Context
addNewBindingToContext (Context counter bindings) name boundType =
  Context counter extended
  where
    extended = Map.insert name boundType bindings
-- | Mint a fresh type variable named @T$n@, where @n@ is the context's
-- counter, and return the context with the counter incremented.
createNewTVar :: Context -> (Context, Type)
createNewTVar (Context counter env) = (advanced, fresh)
  where
    advanced = Context (counter + 1) env
    fresh = TVar (concat ["T$", show counter])
-- | Compute a substitution that makes the two types equal, dispatching on
-- the shape of the first type to the specialised unifiers.
unify :: Type -> Type -> Either Err Substitution
unify (TNamed name) other = unifyTNamed name other
unify (TVar name) other = unifyTVar name other
unify (TFunc argType bodyType) other = unifyTFunc argType bodyType other
-- | Unify the function type @argType -> bodyType@ with @t2@.
--
-- * Against another function type: unify the input types, then unify the
--   output types after rewriting them with what was just learned.
-- * Against a type variable: delegate to 'unifyTVar' (occurs check).
-- * Against a named type: fail with 'ErrTypeMisMatch'.
unifyTFunc :: Type -> Type -> Type -> Either Err Substitution
unifyTFunc argType bodyType t2 =
  case t2 of
    TFunc argType' bodyType' -> do
      subst1 <- unify argType argType'
      subst2 <-
        unify
          (applySubstitutionToType subst1 bodyType)
          (applySubstitutionToType subst1 bodyType')
      -- subst2 was solved on types already rewritten by subst1, so it must
      -- be applied after subst1 (subst2 . subst1). The reverse order leaves
      -- subst1's right-hand sides stale: unifying (T0 -> T0) with
      -- (T1 -> Int) would keep T0 => T1 even though T1 => Int.
      Right (composeSubst subst2 subst1)
    TVar name ->
      unifyTVar name (TFunc argType bodyType)
    TNamed _ ->
      Left (ErrTypeMisMatch (TFunc argType bodyType) t2)
{- |
Compose two substitutions: applying @composeSubst s1 s2@ to a type has the
same effect as applying @s2@ first and then @s1@.

For example if
  s1 = {t1 => Int}
  s2 = {t2 => t1}
then the result will be
  {t1 => Int, t2 => Int}

Every binding of @s2@ gets @s1@ applied to its right-hand side, and the
bindings of @s1@ are kept for the variables @s2@ does not bind. When both
bind the same variable, @s2@'s (rewritten) binding wins, because @s2@ is
applied first and removes that variable before @s1@ ever sees it.
-}
composeSubst :: Substitution -> Substitution -> Substitution
composeSubst s1 s2 =
  Map.map (applySubstitutionToType s1) s2 `Map.union` s1
-- | Produce the substitution that makes type variable @name1@ equal @t2@.
--
-- A variable unifies with itself trivially. Otherwise it is bound to @t2@,
-- unless @t2@ mentions @name1@ (the occurs check), which is reported as
-- 'ErrTVarSelfReference'.
unifyTVar :: String -> Type -> Either Err Substitution
unifyTVar name1 t2 =
  case t2 of
    TVar name2
      | name2 == name1 -> Right emptySubstitution
    _
      | t2 `contains` name1 -> Left (ErrTVarSelfReference name1 t2)
      | otherwise -> Right (Map.singleton name1 t2)
-- | @t `contains` v@ is True when the type variable named @v@ occurs
-- anywhere inside @t@.
contains :: Type -> String -> Bool
contains (TVar name) tvarname = name == tvarname
contains (TFunc inType outType) tvarname =
  any (`contains` tvarname) [inType, outType]
contains (TNamed _) _ = False
-- | Unify the named type @name1@ with @t2@.
--
-- Identical named types unify with no bindings; a type variable is bound
-- to the named type via 'unifyTVar'; anything else is an
-- 'ErrTypeMisMatch'.
unifyTNamed :: String -> Type -> Either Err Substitution
unifyTNamed name1 t2 =
  case t2 of
    TNamed name2
      | name1 == name2 -> Right emptySubstitution
    TVar name -> unifyTVar name (TNamed name1)
    _ -> Left (ErrTypeMisMatch (TNamed name1) t2)
-- | Demo: infer and print the type of @(\x -> (+) x 5) 99@ in an
-- environment where @+@ has type @Int -> Int -> Int@.
main :: IO ()
main = print (infer ctx expr)
  where
    intType = TNamed "Int"
    initialEnv = Map.fromList [("+", TFunc intType (TFunc intType intType))] :: Env
    ctx = Context 0 initialEnv
    addFive = EFuncCall (EFuncCall (EVar "+") (EVar "x")) (EInt 5)
    expr = EFuncCall (EFunc "x" addFive) (EInt 99)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment