Skip to content

Instantly share code, notes, and snippets.

@nathaniel-may
Created March 25, 2023 23:08
Show Gist options
  • Save nathaniel-may/05f2663b27a9f784a4b7e2724efae37c to your computer and use it in GitHub Desktop.
Save nathaniel-may/05f2663b27a9f784a4b7e2724efae37c to your computer and use it in GitHub Desktop.
Simple Type Unification Example
{-# LANGUAGE FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, GeneralizedNewtypeDeriving #-}
module Main where
import Control.Monad(when)
import Control.Monad.Except
import Control.Monad.State
import Data.Foldable (traverse_, foldrM)
import Data.Functor (($>))
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Data.Set as Set
-- | Demo entry point: type-checks a small nested list literal and prints the
-- resulting 'Either' (an error message or the typed expression).
main :: IO ()
main = print (runTcM (typecheck input))
  where
    input = ListU [ListU [], ListU [IntU 1], ListU []]
-- | The expression form consumed by 'typecheck' (the input language).
data ExprU
  = -- | an integer literal
    IntU Int
  | -- | a list literal; its elements are themselves 'ExprU' expressions
    ListU [ExprU]
  deriving (Eq, Show, Read)
-- | The expression form produced by 'typecheck' (the checked language).
data ExprT
  = -- | an integer literal
    IntT Int
  | -- | a list literal; its elements are themselves 'ExprT' expressions
    ListT [ExprT]
  deriving (Eq, Show, Read)
-- | The types the checker works with.
data Type
  = -- | the type of integers
    IntTag
  | -- | a list type, carrying the type of its elements
    ListTag Type
  | -- | a type variable, standing for a type that may not be known yet
    Hole TypeVar
  deriving (Eq, Show, Read)

-- | Names of type variables (generated by 'nextVar').
type TypeVar = String
-- | Type-checking state.
data Unifier
  = Unifier
      (UnionFind TypeVar) -- ^ equivalence classes of type variables
      (Map TypeVar Type) -- ^ known types; 'findType' looks a variable up by its union-find root
-- | The concrete checking monad: 'Unifier' state threaded through a
-- computation that may fail with a 'String' error message.
newtype TcM a = TcM (StateT Unifier (Either String) a)
  deriving (Functor, Applicative, Monad, MonadState Unifier, MonadError String)
-- | Run a checking computation starting from an empty state (no type
-- variables, no known types), returning only its result.
runTcM :: TcM a -> Either String a
runTcM (TcM action) = evalStateT action initial
  where
    initial = Unifier (UnionFind M.empty) M.empty
-- | Top-level bidirectional type checking.
--
-- Integers are accepted directly. For a list, the elements are checked first
-- (bottom-up). The type of the resulting typed list is then inferred, and
-- the list is checked against that inferred type.
typecheck :: (MonadError String m, MonadState Unifier m) => ExprU -> m ExprT
typecheck expr = case expr of
  IntU n -> pure (IntT n)
  ListU children -> do
    typedChildren <- traverse typecheck children
    let typed = ListT typedChildren
    inferred <- inferType typed
    checkType inferred typed
    pure typed
-- | "Left to right" type checking: check an expression against an expected
-- type.
--
-- * Integers check against 'IntTag'.
-- * Lists check element-wise against the expected element type.
-- * A type variable is replaced by its currently known type before checking.
--   If nothing is known about it yet, checking fails.
-- * Any other combination is a constructor mismatch. It is reported using the
--   expression's inferred type.
checkType :: (MonadError String m, MonadState Unifier m) => Type -> ExprT -> m ()
checkType IntTag (IntT _) = pure ()
checkType (ListTag t) (ListT xs) = traverse_ (checkType t) xs
checkType (Hole var) expr =
  -- TODO should I throw here or ignore with `pure ()`
  findType var
    >>= maybe
      -- the separating space keeps the variable name readable in the message
      (throwError $ "type unknown for type variable " <> show var)
      (`checkType` expr)
-- if it didn't match the above cases, it's a type mismatch
checkType t x = do
  t' <- inferType x
  when (t /= t') (throwError $ "Type mismatch. Expected " <> show t <> " but got " <> show t')
-- | "Right to left" type inference: compute an expression's type from the
-- types of its nested elements.
--
-- * An empty list gets a fresh type variable as its element type.
-- * For a non-empty list, every adjacent pair of element types is unified,
--   which forces all elements to share one type.
-- * The head element's type is then used as the element type of the list.
--
-- Unification failures are raised as errors in the monad.
inferType :: (MonadError String m, MonadState Unifier m) => ExprT -> m Type
inferType (IntT _) = pure IntTag
inferType (ListT xs) = do
  -- infer the type of all the elements
  elemTypes <- traverse inferType xs
  elemType <- case elemTypes of
    -- the list is empty so we do not know the type of its elements:
    -- assign it a new type variable
    [] -> Hole <$> nextVar
    ts@(t : rest) -> do
      -- unify neighbouring element types left to right, so the reported
      -- error (if any) is for the leftmost pair that fails to unify.
      -- Matching on (t : rest) avoids the partial `tail`.
      sequence_ (zipWith unify ts rest)
      -- if the head's type is a type variable, prefer its known type
      case t of
        Hole var -> fromMaybe t <$> findType var
        _ -> pure t
  -- this is inferred to be a list whose elements are of type elemType
  pure (ListTag elemType)
-- | Allocate a type variable name that does not yet appear in the union-find
-- table (neither as a key nor as a parent). The new name is registered as its
-- own singleton class, and then returned.
nextVar :: MonadState Unifier m => m TypeVar
nextVar = do
  Unifier (UnionFind parents) _ <- get
  let taken = M.keysSet parents `Set.union` Set.fromList (M.elems parents)
      -- candidates is infinite and taken is finite, so a free name always exists
      fresh = head [name | name <- candidates, not (name `Set.member` taken)]
  modify (\(Unifier uf known) -> Unifier (ununified fresh uf) known)
  pure fresh
  where
    -- Candidate names in order: "_a0", "_a1", "_a2", ...
    -- The digit generator is the inner, infinite one, so in practice letters
    -- after 'a' are never reached.
    candidates = ['_' : letter : show n | letter <- ['a' .. 'z'], n <- [0 :: Integer ..]]
-- | Look up the currently known type of a type variable, if any. The lookup
-- uses the root of the variable's union-find class.
findType :: MonadState Unifier m => TypeVar -> m (Maybe Type)
findType var = gets lookupRoot
  where
    lookupRoot (Unifier uf known) = M.lookup (find var uf) known
-- | Unify two types: record what is learned about type variables in the
-- 'Unifier' state, or raise an error within the monad if the types cannot be
-- made equal.
--
-- Known types are always stored under the union-find root of a variable,
-- because 'findType' only ever looks a variable up by its root. When a
-- variable already has a known type, that type is unified structurally with
-- the other side, rather than merely compared for equality. This means
-- @ListTag (Hole a)@ against @ListTag IntTag@ succeeds and binds @a@.
unify :: (MonadError String m, MonadState Unifier m) => Type -> Type -> m ()
-- Two type variables: merge their classes. If both already had a known type,
-- those types must unify too.
unify (Hole x) (Hole y) = do
  Unifier uf m <- get
  let rootx = find x uf
      rooty = find y uf
  when (rootx /= rooty) $ do
    let uf' = Main.union x y uf
        -- root of the merged class: the surviving known type is stored here
        root = find y uf'
        settle t = put (Unifier uf' (M.insert root t m))
    case (M.lookup rootx m, M.lookup rooty m) of
      -- classes are merged before recursing, so unifying the two known types
      -- cannot revisit this same pair of variables forever
      (Just t0, Just t1) -> settle t0 >> unify t0 t1
      (Nothing, Just t1) -> settle t1
      (Just t0, Nothing) -> settle t0
      (Nothing, Nothing) -> put (Unifier uf' m)
-- A type variable against a non-variable type: bind the variable (at its
-- root) if its type is unknown, otherwise unify its known type with the other.
unify (Hole x) t1 = do
  Unifier uf m <- get
  let rootx = find x uf
  case M.lookup rootx m of
    Just t0 -> unify t0 t1
    Nothing -> put (Unifier uf (M.insert rootx t1 m))
-- Mirror image of the previous case, keeping the argument order of any error.
unify t0 (Hole y) = do
  Unifier uf m <- get
  let rooty = find y uf
  case M.lookup rooty m of
    Just t1 -> unify t0 t1
    Nothing -> put (Unifier uf (M.insert rooty t0 m))
-- polymorphic types unify their inner types
unify (ListTag x) (ListTag y) = unify x y
-- if both types are known, throw if they don't match
unify t0 t1 = when (t0 /= t1) (throwError $ "Type unification failed (2). Cannot unify " <> show t0 <> " and " <> show t1)
-- | A union-find (disjoint-set) structure, stored as a map from each element
-- (key) to its parent (value). An element that is its own parent, or that is
-- absent from the map, is the root of its class.
newtype UnionFind a = UnionFind (M.Map a a)
-- | Make @x@ its own parent, i.e. register it as a singleton class. Any
-- parent link @x@ previously had is overwritten.
ununified :: Ord a => a -> UnionFind a -> UnionFind a
ununified x (UnionFind parents) = UnionFind parents'
  where
    parents' = M.insert x x parents
-- | Merge the classes of @x@ and @y@ by pointing @x@'s root at @y@'s root.
-- If they already share a root, the structure is returned unchanged.
-- Unoptimized: no union-by-rank and no path compression.
union :: Ord a => a -> a -> UnionFind a -> UnionFind a
union x y uf@(UnionFind parents)
  | rootX == rootY = uf
  | otherwise = UnionFind (M.insert rootX rootY parents)
  where
    rootX = find x uf
    rootY = find y uf
-- | Find the root (representative) of the class containing @x@, by following
-- parent links until reaching an element that is its own parent.
--
-- Terminates as long as the parent links form a forest. 'union' maintains
-- this, since it only ever links two distinct roots.
find :: Ord a => a -> UnionFind a -> a
find x uf
  | px == x = x
  -- step to the parent; the original recursed on x itself and never terminated
  | otherwise = find px uf
  where
    px = parent x uf
-- | The parent of @x@ in the union-find structure. An element absent from the
-- map is treated as its own parent.
parent :: Ord a => a -> UnionFind a -> a
parent x (UnionFind links) = M.findWithDefault x x links
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment