Created
March 25, 2023 23:08
-
-
Save nathaniel-may/05f2663b27a9f784a4b7e2724efae37c to your computer and use it in GitHub Desktop.
Simple Type Unification Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, GeneralizedNewtypeDeriving #-} | |
module Main where | |
import Control.Monad(when) | |
import Control.Monad.Except | |
import Control.Monad.State | |
import Data.Foldable (traverse_, foldrM) | |
import Data.Functor (($>)) | |
import Data.Map (Map) | |
import qualified Data.Map as M | |
import Data.Maybe (fromMaybe) | |
import Data.Set as Set | |
-- | Type-check one hard-coded nested-list expression and print the
-- resulting @Either String ExprT@.
main :: IO ()
main = print (runTcM (typecheck sample))
  where
    -- A list of lists: an empty list, a list holding one int, and another
    -- empty list.
    sample = ListU [ListU [], ListU [IntU 1], ListU []]
-- | Expression tree accepted by 'typecheck' (the input side).
data ExprU
  = IntU Int
    -- ^ integer literal
  | ListU [ExprU]
    -- ^ list literal; elements are themselves expressions
  deriving (Eq, Read, Show)
-- | Expression tree returned by 'typecheck' (the output side). It has the
-- same shape as 'ExprU': an integer literal or a list of expressions.
data ExprT
  = IntT Int
    -- ^ integer literal
  | ListT [ExprT]
    -- ^ list literal; elements are themselves expressions
  deriving (Eq, Read, Show)
-- | Types assigned to expressions.
data Type
  = IntTag
    -- ^ the type of integer literals
  | ListTag Type
    -- ^ a list whose elements have the given type
  | Hole TypeVar
    -- ^ a type variable standing in for a not-yet-known type
  deriving (Eq, Read, Show)

-- | Name of a type variable. 'nextVar' generates names of the form
-- @_a0@, @_a1@, ...
type TypeVar = String
-- | State threaded through type checking.
data Unifier
  = Unifier
      (UnionFind TypeVar)
      -- ^ equivalence classes of type variables
      (Map TypeVar Type)
      -- ^ types bound to type variables; 'findType' looks a variable up
      -- here via its 'find' representative
-- | Concrete type-checking monad: 'Unifier' state layered over
-- @Either String@, so a thrown error is a 'String' message.
newtype TcM a = TcM (StateT Unifier (Either String) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadError String
    , MonadState Unifier
    )
-- | Run a 'TcM' computation starting from an empty 'Unifier' (no type
-- variables, no bindings). The final state is discarded; the result is
-- either the thrown error message or the computed value.
runTcM :: TcM a -> Either String a
runTcM (TcM action) = evalStateT action emptyUnifier
  where
    emptyUnifier = Unifier (UnionFind M.empty) M.empty
-- | Top-level bidirectional type checking.
--
-- Converts an 'ExprU' into the corresponding 'ExprT'. For a list, the
-- elements are checked first, then the list's type is inferred with
-- 'inferType' and the list is checked against that type with 'checkType'.
-- Any error thrown by those steps propagates to the caller.
typecheck :: (MonadError String m, MonadState Unifier m) => ExprU -> m ExprT
typecheck expr = case expr of
  IntU n -> pure (IntT n)
  ListU items -> do
    typedList <- ListT <$> traverse typecheck items
    listTy <- inferType typedList
    checkType listTy typedList
    pure typedList
-- | "Left to right" type checking: verify that an expression has the
-- expected type.
--
-- Throws a 'String' error when a type variable has no known type or when
-- the inferred type of the expression differs from the expected one.
checkType :: (MonadError String m, MonadState Unifier m) => Type -> ExprT -> m ()
checkType expected expr = case (expected, expr) of
  (IntTag, IntT _) -> pure ()
  -- every element must have the list's element type
  (ListTag elemTy, ListT items) -> traverse_ (checkType elemTy) items
  -- TODO should I throw here or ignore with `pure ()`
  (Hole var, _) -> do
    known <- findType var
    case known of
      Nothing -> throwError ("type unknown for type variable" <> show var)
      -- the variable is bound: check against the bound type instead
      Just ty -> checkType ty expr
  -- if it didn't match the above cases, it's a type mismatch
  _ -> do
    actual <- inferType expr
    when (expected /= actual) $
      throwError ("Type mismatch. Expected " <> show expected <> " but got " <> show actual)
-- | "Right to left" type checking: infer the type of an expression from
-- its nested elements.
--
-- An integer is 'IntTag'. A list is @ListTag t@ where @t@ comes from its
-- elements: a fresh type variable for an empty list, otherwise the type of
-- the head element after all neighbouring element types have been unified.
-- Errors thrown by 'unify' propagate to the caller.
inferType :: (MonadError String m, MonadState Unifier m) => ExprT -> m Type
inferType expr = case expr of
  IntT _ -> pure IntTag
  -- this is inferred to be a list whose elements share one type
  ListT items -> ListTag <$> (traverse inferType items >>= elementType)
  where
    -- the list is empty so the element type is unknown: use a new variable
    elementType [] = Hole <$> nextVar
    elementType tys@(headTy : restTys) = do
      -- unify each element type with its right neighbour; foldrM runs the
      -- effects starting from the last pair
      foldrM (\(l, r) _ -> unify l r) () (zip tys restTys)
      -- if the head type is a type variable whose type is known, report
      -- that type as the element type; otherwise report the head type
      case headTy of
        Hole var -> fromMaybe headTy <$> findType var
        _ -> pure headTy
-- | Get the next available type variable and register it in the
-- union-find structure as its own, ununified class.
--
-- A name counts as used when it appears as a key or a value of the
-- union-find table.
nextVar :: MonadState Unifier m => m TypeVar
nextVar = do
  Unifier forest bindings <- get
  let UnionFind links = forest
      taken = Set.fromList (M.keys links ++ M.elems links)
      -- candidates is infinite and taken is finite, so a fresh name exists
      fresh = head [name | name <- candidates, not (name `Set.member` taken)]
  put (Unifier (ununified fresh forest) bindings)
  pure fresh
  where
    -- all candidate names: '_', a letter, then a number. The numbers for
    -- each letter never run out, so enumeration stays on the letter 'a'.
    candidates =
      [ '_' : letter : show index
      | letter <- ['a' .. 'z']
      , index <- [0 :: Integer ..]
      ]
-- | Look up the type bound to a type variable, if any. The lookup key is
-- the variable's representative as reported by 'find'; the state is not
-- modified.
findType :: MonadState Unifier m => TypeVar -> m (Maybe Type)
findType var = gets lookupBinding
  where
    lookupBinding (Unifier uf bindings) = M.lookup (find var uf) bindings
-- | Take two types and unify them, or raise an error within the monad.
--
-- What is learned is recorded in the 'Unifier' state: type variables are
-- merged in the union-find structure, and a binding is always stored under
-- the representative of a variable's class (the same key 'findType' uses).
--
-- When a variable that is already bound meets another type, the bound type
-- is unified with that type recursively rather than compared for equality,
-- so bound types that still contain type variables are handled. A genuine
-- mismatch is therefore always reported by the final equation ("(2)").
--
-- NOTE(review): there is no occurs check; a cyclic binding would make the
-- recursion diverge.
unify :: (MonadError String m, MonadState Unifier m) => Type -> Type -> m ()
-- Two type variables: merge their classes and keep one binding.
unify (Hole x) (Hole y) = do
  Unifier uf m <- get
  let rootX = find x uf
      rootY = find y uf
  case (M.lookup rootX m, M.lookup rootY m) of
    -- Both classes are bound: the bound types must unify, then merge.
    (Just t0, Just t1) -> do
      unify t0 t1
      -- re-read the state, since the recursive call may have updated it
      modify (\(Unifier uf' m') -> Unifier (Main.union x y uf') m')
    -- Only x's class is bound. 'Main.union' makes y's representative the
    -- root of the merged class, so the binding has to move there.
    (Just t0, Nothing) -> put (Unifier (Main.union x y uf) (M.insert rootY t0 m))
    -- x's class is unbound: any binding of y's class already sits on the
    -- merged root.
    (Nothing, _) -> put (Unifier (Main.union x y uf) m)
-- A type variable against a non-variable type: bind the variable if it is
-- unknown, otherwise unify its bound type with the other type.
unify (Hole x) t1 = bindOrLookup x t1 >>= maybe (pure ()) (`unify` t1)
unify t0 (Hole y) = bindOrLookup y t0 >>= maybe (pure ()) (unify t0)
-- polymorphic types unify their inner types
unify (ListTag x) (ListTag y) = unify x y
-- if both types are known, throw if they don't match
unify t0 t1 =
  when (t0 /= t1) $
    throwError ("Type unification failed (2). Cannot unify " <> show t0 <> " and " <> show t1)

-- | Helper for 'unify'. If the class of @var@ already has a binding, return
-- it and leave the state unchanged. Otherwise bind the class representative
-- to @ty@ and return 'Nothing'.
bindOrLookup :: MonadState Unifier m => TypeVar -> Type -> m (Maybe Type)
bindOrLookup var ty = do
  Unifier uf m <- get
  let root = find var uf
  case M.lookup root m of
    Just bound -> pure (Just bound)
    Nothing -> do
      put (Unifier uf (M.insert root ty m))
      pure Nothing
-- | Union-find structure stored as a parent table:
-- @\<k, v\> = \<child, parent\>@.
--
-- 'parent' treats an element that is missing from the table as its own
-- parent, and 'find' stops at an element that is its own parent.
newtype UnionFind a
  = UnionFind (Map a a)
-- | Register an element as its own parent. Any parent entry the element
-- already had in the table is overwritten.
ununified :: Ord a => a -> UnionFind a -> UnionFind a
ununified x (UnionFind links) = UnionFind selfLinked
  where
    selfLinked = M.insert x x links
-- | Merge the classes of two elements (unoptimized). The representative
-- of the first element is linked under the representative of the second;
-- the structure is returned unchanged when both already share one
-- representative.
union :: Ord a => a -> a -> UnionFind a -> UnionFind a
union x y uf@(UnionFind links)
  | rootX == rootY = uf
  | otherwise = UnionFind (M.insert rootX rootY links)
  where
    rootX = find x uf
    rootY = find y uf
-- | Find the representative of an element's class by following parent
-- links until an element that is its own parent is reached.
--
-- The walk continues from the parent. Recursing on the element itself, as
-- this function previously did, never terminates for a non-root element.
--
-- NOTE(review): terminates only if the parent links contain no cycle other
-- than the self-loop at a root.
find :: Ord a => a -> UnionFind a -> a
find x uf
  | next == x = x
  | otherwise = find next uf
  where
    next = parent x uf
-- | Parent of an element in the table. An element with no entry is
-- reported as its own parent.
parent :: Ord a => a -> UnionFind a -> a
parent x (UnionFind links) = M.findWithDefault x x links
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment