{-
A toy implementation of Algorithm W for HN readers, modified from
https://github.com/wh5a/Algorithm-W-Step-By-Step/blob/master/AlgorithmW.lhs

Published as GitHub gist kccqzy/fa8a8ae12a198b41c6339e8a5c45978a by @kccqzy
(created July 13, 2024).
-}
import Control.Monad.Except
import Control.Monad.State
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import qualified Data.Map as Map
-- | Abstract syntax of the toy expression language.
data Exp
  = -- | A reference to a term variable.
    EVar String
  | -- | A literal constant.
    ELit Lit
  | -- | Function application: the function, then its argument.
    EApp Exp Exp
  | -- | Lambda abstraction binding a single variable.
    EAbs String Exp
  | -- | @let x = e1 in e2@; the type of @e1@ is generalized before checking @e2@.
    ELet String Exp Exp
  | -- | Integer addition of two operands.
    EPlus Exp Exp
  | -- | The addition primitive itself (typed @Int -> Int -> Int@); inference
    -- of 'EPlus' is expressed as two applications of it.
    EPlusFun
  deriving (Eq, Ord, Show)
-- | Literal constants of the language.
data Lit
  = LInt Integer -- ^ An integer literal; inferred as 'TInt'.
  | LBool Bool -- ^ A boolean literal; inferred as 'TBool'.
  deriving (Eq, Ord, Show)
-- | Monotypes. Type variables are numbered; fresh numbers come from a counter
-- in the inference state.
data Type
  = TVar Int -- ^ A type variable, identified by its number.
  | TInt -- ^ The integer type.
  | TBool -- ^ The boolean type.
  | TFun Type Type -- ^ A function type: argument type, then result type.
  deriving (Eq, Ord, Show)
-- | A polytype (type scheme) is a monotype preceded by universal quantifiers.
-- For example, the identity function has type @forall x. x -> x@, represented
-- by putting @x@ in the quantified set alongside the body @x -> x@.
--
-- Non-empty quantifier sets are produced only by @generalize@ (lambda-bound
-- variables are stored with an empty set), and they are consumed by
-- @instantiate@, which replaces each quantified variable with a fresh one.
--
-- Quantifiers appear only at the top level of a scheme; higher-rank
-- polymorphism is not supported.
data PolyType = PolyType IntSet.IntSet Type
  deriving (Show)
-- | The typing context: each term variable in scope is mapped to its type
-- scheme (a 'PolyType' whose quantifier set may be empty).
newtype TypeEnv = TypeEnv
  { getTypeEnv :: Map.Map String PolyType
  }
-- | A substitution: maps type-variable numbers to the types replacing them.
newtype Subst = Subst (IntMap.IntMap Type) deriving (Show)

-- | Composition. Applying @outer <> inner@ is the same as applying @inner@
-- first and then @outer@: every right-hand side of @inner@ has @outer@ pushed
-- through it, and bindings of @outer@ survive only for variables @inner@ does
-- not already bind (the union is left-biased).
instance Semigroup Subst where
  Subst outer <> Subst inner =
    Subst (IntMap.union (fmap (applySubstToType (Subst outer)) inner) outer)

-- | The empty substitution, which leaves every type unchanged.
instance Monoid Subst where
  mempty = Subst IntMap.empty
-- | @ftvFromType@ collects the type variables occurring in a monotype. A
-- monotype has no binders, so every occurrence is free.
ftvFromType :: Type -> IntSet.IntSet
ftvFromType ty = case ty of
  TVar v -> IntSet.singleton v
  TInt -> IntSet.empty
  TBool -> IntSet.empty
  TFun arg res -> IntSet.union (ftvFromType arg) (ftvFromType res)
-- | @applySubstToType@ replaces each type variable bound by the substitution
-- with its image. Variables the substitution does not mention are kept, and
-- an image is inserted as-is (it is not substituted into again).
applySubstToType :: Subst -> Type -> Type
applySubstToType sub@(Subst table) ty = case ty of
  TVar v -> IntMap.findWithDefault ty v table
  TFun arg res -> TFun (applySubstToType sub arg) (applySubstToType sub res)
  TInt -> ty
  TBool -> ty
-- | Free type variables of a scheme: those of its body, minus the quantified ones.
ftvFromPolyType :: PolyType -> IntSet.IntSet
ftvFromPolyType (PolyType bound body) = IntSet.difference (ftvFromType body) bound
-- | Applies a substitution to a scheme without touching its quantified
-- variables: bindings for those variables are dropped before the body is
-- rewritten.
-- NOTE(review): binders are not renamed, so variable capture is avoided only
-- if the substitution's range never mentions a quantified variable.
applySubstToPolyType :: Subst -> PolyType -> PolyType
applySubstToPolyType (Subst table) (PolyType bound body) =
  let visible = Subst (IntMap.withoutKeys table bound)
   in PolyType bound (applySubstToType visible body)
-- | Free type variables of an environment: the union over all its schemes.
ftvFromTypeEnv :: TypeEnv -> IntSet.IntSet
ftvFromTypeEnv = IntSet.unions . map ftvFromPolyType . Map.elems . getTypeEnv
-- | Applies a substitution to every scheme in the environment.
applySubstToTypeEnv :: Subst -> TypeEnv -> TypeEnv
applySubstToTypeEnv sub = TypeEnv . Map.map (applySubstToPolyType sub) . getTypeEnv
-- | @generalize@ turns a monotype into a scheme. It quantifies every type
-- variable that is free in the type but not free anywhere in the environment.
generalize :: TypeEnv -> Type -> PolyType
generalize env t = PolyType quantified t
  where
    quantified = IntSet.difference (ftvFromType t) (ftvFromTypeEnv env)
-- | @instantiate@ strips the quantifiers off a scheme. Each quantified
-- variable gets one fresh type variable, which replaces it in the body.
instantiate :: PolyType -> TypeCheck Type
instantiate (PolyType vars t) = do
  fresh <- sequenceA (IntMap.fromSet (const newTyVar) vars)
  pure (applySubstToType (Subst fresh) t)
-- | Inference state: the number that the next fresh type variable will get.
newtype TIState = TIState Int
-- | Failures reported by type inference.
data TIError
  = -- | The two types cannot be made equal.
    ErrorTypeUnify Type Type
  | -- | Binding the variable to the type would create an infinite type.
    ErrorOccursCheck Int Type
  | -- | A term variable is not bound in the environment.
    ErrorUnboundVariable String
  | -- | The wrapped error arose while inferring the given enclosing expression.
    ErrorContext Exp TIError
  deriving (Show)
-- | Runs an inference action. Any error it throws is re-thrown wrapped in
-- 'ErrorContext' together with the given expression.
addErrorContext :: Exp -> TypeCheck a -> TypeCheck a
addErrorContext e action = catchError action (throwError . ErrorContext e)
-- | The inference monad: can fail with a 'TIError' and threads the fresh-variable counter.
type TypeCheck = ExceptT TIError (State TIState)
-- | Runs an inference computation with the fresh-variable counter starting at 0.
runTI :: TypeCheck a -> Either TIError a
runTI = flip evalState (TIState 0) . runExceptT
-- | Returns a type variable numbered with the current counter, then bumps the counter.
newTyVar :: TypeCheck Type
newTyVar = state (\(TIState next) -> (TVar next, TIState (next + 1)))
-- | @typeUnify@ computes a unifier of two types: a substitution which, applied
-- to both types, makes them equal.
--
-- Throws 'ErrorTypeUnify' when the two types have incompatible shapes. Throws
-- 'ErrorOccursCheck' (via 'varBind') when a variable would have to contain
-- itself.
typeUnify :: Type -> Type -> TypeCheck Subst
typeUnify (TFun l r) (TFun l' r') = do
  s1 <- typeUnify l l'
  s2 <- typeUnify (applySubstToType s1 r) (applySubstToType s1 r')
  -- s2 was computed on types already rewritten by s1, so the result must
  -- apply s1 first and s2 second. Since @a <> b@ applies @b@ first, that is
  -- @s2 <> s1@. The reversed order leaves s1's right-hand sides untouched by
  -- s2: unifying @a -> b@ with @b -> Int@ would yield {a := b, b := Int}
  -- instead of {a := Int, b := Int}.
  return (s2 <> s1)
typeUnify (TVar u) t = varBind u t
typeUnify t (TVar u) = varBind u t
typeUnify TInt TInt = return mempty
typeUnify TBool TBool = return mempty
typeUnify t1 t2 = throwError (ErrorTypeUnify t1 t2)
-- | @varBind@ builds the substitution mapping a type variable to a type.
-- Binding a variable to itself yields the empty substitution. If the variable
-- occurs inside the type, the binding would be infinite, and
-- 'ErrorOccursCheck' is thrown instead.
varBind :: Int -> Type -> TypeCheck Subst
varBind u t = case t of
  TVar v | v == u -> pure mempty
  _
    | u `IntSet.member` ftvFromType t -> throwError (ErrorOccursCheck u t)
    | otherwise -> pure (Subst (IntMap.singleton u t))
-- | @ti@ infers the type of an expression in an environment (Algorithm W).
-- It returns two things:
--
-- * the substitution accumulated while solving constraints;
-- * the inferred monotype. This is never a 'PolyType': generalization happens
--   only for @let@-bound names.
--
-- Throws 'ErrorUnboundVariable' for unbound term variables and propagates
-- unification errors. Compound expressions wrap errors raised inside them in
-- 'ErrorContext'.
--
-- Substitutions compose right to left: @a <> b@ applies @b@ first. Whenever a
-- step runs on an environment or types already rewritten by an earlier
-- substitution, the later substitution goes on the left.
ti :: TypeEnv -> Exp -> TypeCheck (Subst, Type)
ti (TypeEnv env) (EVar n) =
  case Map.lookup n env of
    Nothing -> throwError (ErrorUnboundVariable n)
    Just sigma -> do
      -- Every use gets fresh copies of the scheme's quantified variables.
      t <- instantiate sigma
      return (mempty, t)
ti _ (ELit (LInt _)) = pure (mempty, TInt)
ti _ (ELit (LBool _)) = pure (mempty, TBool)
ti env e@(EAbs n body) = addErrorContext e $ do
  tv <- newTyVar
  -- The parameter is monomorphic in the body (empty quantifier set).
  -- Map.insert replaces any outer binding of the same name.
  let env' = TypeEnv (Map.insert n (PolyType mempty tv) (getTypeEnv env))
  (s1, t1) <- ti env' body
  return (s1, TFun (applySubstToType s1 tv) t1)
ti env e@(EApp e1 e2) = addErrorContext e $ do
  tv <- newTyVar
  (s1, t1) <- ti env e1
  (s2, t2) <- ti (applySubstToTypeEnv s1 env) e2
  s3 <- typeUnify (applySubstToType s2 t1) (TFun t2 tv)
  return (s3 <> s2 <> s1, applySubstToType s3 tv)
ti env e@(EPlus e1 e2) =
  addErrorContext e $
    -- The built-in plus operator has type Int -> Int -> Int and is
    -- represented by the special value EPlusFun. So @e1 + e2@ is inferred as
    -- the curried application of EPlusFun to e1 and then e2 (two EApps).
    --
    -- Inlining this:
    --
    -- Inner EApp: s1 = mempty and t1 = TFun TInt (TFun TInt TInt). (s2, t2)
    -- is inferred normally for e1. s3 unifies t1 with (TFun t2 tv), which
    -- forces t2 to TInt and binds tv to TFun TInt TInt.
    --
    -- Outer EApp: s1 is the substitution from the inner call and
    -- t1 = TFun TInt TInt. (s2, t2) is inferred normally for e2. s3 unifies
    -- t1 with (TFun t2 tv), which forces t2 to TInt and binds tv to TInt.
    ti env (EApp (EApp EPlusFun e1) e2)
ti _ EPlusFun = pure (mempty, TFun TInt (TFun TInt TInt))
ti env e@(ELet x e1 e2) = addErrorContext e $ do
  (s1, t1) <- ti env e1
  -- Generalize against the environment e1 was checked in (with s1 applied).
  -- Only variables unconstrained by that context are quantified.
  let t' = generalize (applySubstToTypeEnv s1 env) t1
      env' = TypeEnv (Map.insert x t' (getTypeEnv env))
  (s2, t2) <- ti (applySubstToTypeEnv s1 env') e2
  -- s2 was inferred in an environment already rewritten by s1, so s1 is
  -- applied first: @s2 <> s1@. The reversed order left s1's bindings
  -- untouched by s2. For example, @\f -> let a = f 1 in let b = a True in b 2@
  -- was then typed @(Int -> Bool -> t2) -> t3@ instead of
  -- @(Int -> Bool -> Int -> t3) -> t3@.
  return (s2 <> s1, t2)
-- | Infers the type of an expression and applies the final substitution to it.
typeInference :: TypeEnv -> Exp -> TypeCheck Type
typeInference env e = uncurry applySubstToType <$> ti env e
-- | Sample expressions run by 'main'. Some are well typed. Others are
-- deliberately ill typed (self-application, applying an integer literal) to
-- exercise error reporting.
-- NOTE(review): the 2nd and 13th entries are the same expression.
examples :: [Exp]
examples =
  [ lam "x" (var "x"),
    twice,
    ELet "id" identity (var "id"),
    ELet "id" identity (app (var "id") (var "id")),
    ELet "id" identityViaLet (app (var "id") (var "id")),
    ELet "id" identityViaLet (app (app (var "id") (var "id")) (int 2)),
    ELet "wrong" selfApply (var "wrong"),
    ELet
      "wrong2"
      (lam "x" (app (app (var "x") (var "x")) (var "x")))
      (var "wrong2"),
    lam "m" (ELet "y" (var "m") (ELet "x" (app (var "y") (bool True)) (var "x"))),
    app (int 2) (int 2),
    ELet "id" identity (app (var "id") (int 2)),
    ELet "omega" (app selfApply selfApply) (var "omega"),
    twice,
    ELet "plusOne" plusOne (ELet "two" (int 2) (app (var "plusOne") (var "two"))),
    ELet
      "plusplus"
      (lam "x" (lam "y" (EPlus (var "x") (EPlus (var "y") (var "x")))))
      (ELet "two" (int 2) (app (var "plusplus") (var "two"))),
    ELet "church2" twice (app (app (var "church2") plusOne) (int 0))
  ]
  where
    var = EVar
    lam = EAbs
    app = EApp
    int = ELit . LInt
    bool = ELit . LBool
    identity = lam "x" (var "x")
    identityViaLet = lam "x" (ELet "y" (var "x") (var "y"))
    selfApply = lam "x" (app (var "x") (var "x"))
    twice = lam "f" (lam "x" (app (var "f") (app (var "f") (var "x"))))
    plusOne = lam "x" (EPlus (int 1) (var "x"))
-- | Infers the type of an expression in the empty environment and prints
-- either @expr :: type@ or the expression followed by the error.
test :: Exp -> IO ()
test e = putStrLn (show e ++ verdict ++ "\n")
  where
    verdict =
      either
        (\err -> "\n " ++ show err)
        (\t -> " :: " ++ show t)
        (runTI (typeInference (TypeEnv Map.empty) e))
-- | Runs inference on every sample in 'examples', printing each result.
main :: IO ()
main = sequence_ [test e | e <- examples]
-- Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment