Created
February 13, 2018 20:01
-
-
Save Garciat/221d38117f7346613fa0acfecd6e8efb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TupleSections #-} | |
import Control.Monad.State (State, evalState, get, gets, modify, put) | |
import Data.Function (on) | |
import Data.List (unionBy) | |
import Data.Map (Map) | |
import Data.Maybe (fromMaybe) | |
import qualified Data.Map as Map | |
-- | Identifiers used for program variables and for type variables.
type Name = String

-- | Literal values that can appear directly in source expressions.
data VLiteral
  = LInteger Integer -- ^ An integer literal.
  | LBoolean Bool    -- ^ A boolean literal.
  deriving (Show, Eq)

-- | Untyped lambda-calculus expressions accepted by the inferencer.
data Expr
  = ELam Name Expr  -- ^ Lambda abstraction: binds the name inside the body.
  | EApp Expr Expr  -- ^ Application of a function expression to an argument.
  | ELit VLiteral   -- ^ A literal value.
  | EVar Name       -- ^ Reference to a variable in scope.
  deriving (Show, Eq)
-- | Build an integer-literal expression.
intLit :: Integer -> Expr
intLit = ELit . LInteger
-- | Built-in base types.
data TLiteral
  = TInt  -- ^ Type of integer literals.
  | TBool -- ^ Type of boolean literals.
  deriving (Eq)

-- | Renders base types using their surface names.
instance Show TLiteral where
  show lit =
    case lit of
      TInt -> "Int"
      TBool -> "Bool"
-- | Types produced by inference.
data Type
  = TLam Type Type -- ^ Function type: argument type, then result type.
  | TVar Name      -- ^ Type variable.
  | TLit TLiteral  -- ^ Base type.
  deriving (Eq)

-- | Pretty-prints types with a right-associative arrow, so parentheses
-- appear only around a function type in argument position,
-- e.g. @(a -> Int) -> Bool@ but @a -> b -> c@.
instance Show Type where
  show ty =
    case ty of
      TLam arg res -> showArg arg ++ " -> " ++ show res
      TLit lit -> show lit
      TVar name -> name
    where
      -- "->" associates to the right, so only a nested arrow on the left
      -- needs explicit parentheses.
      showArg arg@TLam{} = "(" ++ show arg ++ ")"
      showArg arg = show arg
-- | The integer base type.
intTy :: Type
intTy = TLit TInt

-- | The boolean base type.
boolTy :: Type
boolTy = TLit TBool
-- | An equation @Constraint a b@ requiring types @a@ and @b@ to be equal;
-- 'solve' discharges it by unifying the two sides.
data Constraint
  = Constraint Type Type

-- | Apply a type transformation to both sides of a constraint.
cmap :: (Type -> Type) -> Constraint -> Constraint
cmap f (Constraint a b) = Constraint (f a) (f b)

instance Show Constraint where
  show (Constraint a b) = show a ++ " <=> " ++ show b
-- | State threaded through constraint generation.
data ConState = ConState
  { conFreshId :: Int
  -- ^ Number used to name the next fresh type variable (see 'freshTVar').
  , conEnv :: Map Name Type
  -- ^ Types of the names currently in scope.
  }
  deriving (Show)

-- | The constraint-generation monad: 'State' over 'ConState'.
type Con = State ConState
-- | Allocate a type variable named @a<n>@, where @n@ is the current
-- counter, and advance the counter so the next call gets a new name.
freshTVar :: Con Type
freshTVar = do
  st <- get
  let n = conFreshId st
  put st { conFreshId = n + 1 }
  pure (TVar ('a' : show n))
-- | Look up the type bound to a name in the current environment.
--
-- An unbound name produces a lazy 'error' ("name not found: ...") that is
-- raised only when the returned type is actually inspected.
lookupType :: Name -> Con Type
lookupType name = do
  env <- gets conEnv
  pure (Map.findWithDefault (error ("name not found: " ++ name)) name env)
-- | Bind (or rebind) a name to a type in the environment.
insertType :: Name -> Type -> Con ()
insertType name ty =
  modify $ \st -> st { conEnv = Map.insert name ty (conEnv st) }
-- | The base type of a literal value.
literalType :: VLiteral -> Type
literalType lit =
  case lit of
    LInteger _ -> intTy
    LBoolean _ -> boolTy
-- | Run an action, then restore the state that was current before it ran,
-- keeping only the action's result.
--
-- Note: the WHOLE state is rolled back, including any counters it holds,
-- not just an environment.
scoped :: State s a -> State s a
scoped action = do
  saved <- get
  result <- action
  put saved
  pure result
-- | Walk an expression and assign it a type, which may contain fresh type
-- variables. Also collect the equality constraints those variables must
-- satisfy.
--
-- A lambda binder is visible only inside its body: the environment is
-- restored afterwards. The fresh-variable counter is deliberately NOT
-- restored, so type variables created inside one lambda are never reused
-- elsewhere in the same expression.
constrain :: Expr -> Con (Type, [Constraint])
constrain (ELit lit) = pure (literalType lit, [])
constrain (EVar name) = (, []) <$> lookupType name
constrain (ELam var body) = do
  tvar <- freshTVar
  outerEnv <- gets conEnv
  insertType var tvar
  (tbody, ctrs) <- constrain body
  -- Restore only the environment. Rolling back the whole state (as 'scoped'
  -- does) would also reset conFreshId, and later fresh variables would
  -- collide with ones already used in this body's constraints.
  modify (\s -> s { conEnv = outerEnv })
  pure (TLam tvar tbody, ctrs)
constrain (EApp left right) = do
  (lty, lctrs) <- constrain left
  (rty, rctrs) <- constrain right
  tvar <- freshTVar
  -- The function's type must be "argument type -> result type".
  pure (tvar, lctrs ++ rctrs ++ [Constraint lty (TLam rty tvar)])
-- | A binding @Substitution name ty@ that replaces type variable @name@
-- with @ty@.
data Substitution
  = Substitution Name Type

-- | Transform the replacement type of a substitution, keeping its variable.
subMap :: (Type -> Type) -> Substitution -> Substitution
subMap f (Substitution name ty) = Substitution name (f ty)

-- | The type variable a substitution replaces.
subName :: Substitution -> Name
subName (Substitution name _) = name

instance Show Substitution where
  show (Substitution name ty) = name ++ " => " ++ show ty
-- | Replace every occurrence of the substitution's variable in a type.
applySub :: Substitution -> Type -> Type
applySub (Substitution target replacement) = substitute
  where
    substitute ty =
      case ty of
        TVar name | name == target -> replacement
        TLam arg res -> TLam (substitute arg) (substitute res)
        other -> other
-- | Apply a list of substitutions one at a time, starting with the LAST
-- element of the list and finishing with the first.
applySubs :: [Substitution] -> Type -> Type
applySubs [] ty = ty
applySubs (sub : rest) ty = applySub sub (applySubs rest ty)
-- | Compose substitutions: @left@ is applied to the replacement types of
-- @right@, then the bindings of @left@ are added. Where both bind the same
-- variable, the (rewritten) binding from @right@ wins.
combine :: [Substitution] -> [Substitution] -> [Substitution]
combine left right = unionBy sameVar (map rewrite right) left
  where
    rewrite = subMap (applySubs left)
    sameVar = (==) `on` subName
-- | Compute substitutions that make two types equal, or explain why none
-- exist.
--
-- Binding a variable to a type that contains that same variable (for
-- example @a0 := a0 -> a1@, as arises from @\\x -> x x@) would describe an
-- infinite type. That case is rejected by the occurs check instead of
-- producing a bogus substitution.
unify :: Type -> Type -> Either String [Substitution]
unify tyA tyB
  | tyA == tyB = Right []
  | otherwise =
      case (tyA, tyB) of
        (TLam argA bodyA, TLam argB bodyB) -> do
          argSub <- unify argA argB
          -- Unify the results under what the arguments already forced.
          bodySub <- unify (applySubs argSub bodyA) (applySubs argSub bodyB)
          Right (combine bodySub argSub)
        (TVar name, ty) -> bindVar name ty
        (ty, TVar name) -> bindVar name ty
        _ -> Left ("Cannot unify `" ++ show tyA ++ "` with `" ++ show tyB ++ "`")
  where
    -- Reached only when the two types differ, so a variable is never
    -- compared against itself here.
    bindVar name ty
      | name `occursIn` ty =
          Left ("Cannot construct infinite type `" ++ name ++ "` = `" ++ show ty ++ "`")
      | otherwise = Right [Substitution name ty]

    occursIn name (TVar other) = name == other
    occursIn name (TLam arg res) = occursIn name arg || occursIn name res
    occursIn _ (TLit _) = False
-- | Solve constraints left to right. Each constraint is unified, the
-- resulting substitution is composed onto the accumulated one, and it is
-- also applied to the constraints still pending. The first failure aborts.
solve :: [Constraint] -> Either String [Substitution]
solve constraints = loop [] constraints
  where
    loop acc [] = Right acc
    loop acc (Constraint lhs rhs : pending) =
      unify lhs rhs >>= \sub ->
        loop (combine sub acc) (map (cmap (applySubs sub)) pending)
-- | Types of the built-in names; this is the initial environment used by
-- 'emptyConState'.
--
-- NOTE: "if" and "fix" mention the fixed type variables "a" and "b".
-- 'lookupType' returns these types verbatim, without instantiating fresh
-- variables, so every use of "if" within one expression shares the same
-- "a". The same holds for "fix" and "b".
stdLib :: Map Name Type
stdLib =
  Map.fromList
    [ ("add", TLam intTy (TLam intTy intTy))
    , ("gt", TLam intTy (TLam intTy boolTy))
    , ("if", TLam boolTy (TLam (TVar "a") (TLam (TVar "a") (TVar "a"))))
    , ("fix", TLam (TLam (TVar "b") (TVar "b")) (TVar "b"))
    ]
-- | Initial constraint-generation state: the fresh-variable counter starts
-- at 0 and the environment holds the built-ins from 'stdLib'.
emptyConState :: ConState
emptyConState = ConState 0 stdLib
-- | Infer the type of an expression. Constraints are generated starting
-- from 'emptyConState' and solved; the solution is then applied to the
-- expression's raw type. Unification failures are returned as 'Left'.
infer :: Expr -> Either String Type
infer expr = fmap (`applySubs` rawTy) (solve constraints)
  where
    (rawTy, constraints) = evalState (constrain expr) emptyConState
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment