Skip to content

Instantly share code, notes, and snippets.

@bspaans
Created July 18, 2010 14:01
Show Gist options
  • Save bspaans/480418 to your computer and use it in GitHub Desktop.
Save bspaans/480418 to your computer and use it in GitHub Desktop.
-- | A simply typed lambda calculus with type reconstruction,
-- extended with integers and let bindings.
-- Uses the new Evaluator monad to generate fresh type
-- variables for us.
-- No alpha conversion: bound variables are never renamed.
--
module Lambda where
import Evaluator4
import Control.Applicative
import Control.Arrow
import Control.Monad.Writer
import Control.Monad.Reader
import qualified Data.Map as M
import Data.Char
import Control.Monad ((>=>), liftM2)
-- | Abstract syntax of the expression language.
data Expr
  = Value Int               -- ^ integer literal
  | Variable String         -- ^ variable reference
  | Lambda String Type Expr -- ^ abstraction; (only) lambdas can have a type annotation
  | App Expr Expr           -- ^ function application
  | Let String Expr Expr    -- ^ @let s = e1 in e2@
  | Add Expr Expr           -- ^ @e1 + e2@
  | Sub Expr Expr           -- ^ @e1 - e2@
  | Mul Expr Expr           -- ^ @e1 * e2@
  | Div Expr Expr           -- ^ @e1 / e2@ (evaluated with integer 'div')
  deriving (Eq)
-- | Lets integer literals and @+@, @-@, @*@ build 'Expr' terms directly,
-- e.g. @x + 1@ is @Add (Variable "x") (Value 1)@.
--
-- 'abs' and 'signum' have no counterpart in the expression language, so
-- they fail with a descriptive message rather than a bare 'undefined'.
-- 'negate' is left to the class default (@0 - e@, i.e. a 'Sub' node).
instance Num Expr where
  fromInteger = Value . fromInteger
  (+) = Add
  (-) = Sub
  (*) = Mul
  abs _ = error "Num Expr: abs is not supported by the expression language"
  signum _ = error "Num Expr: signum is not supported by the expression language"
-- | Pretty-printer for expressions.
--
-- 'showsPrec' tracks the precedence of the surrounding context so that
-- sub-expressions are parenthesised exactly where the flat rendering would
-- otherwise be ambiguous:
--
--   * application (precedence 10) is left-associative, so an application
--     or any lower-precedence form in argument position is parenthesised;
--   * @*@ and @/@ (precedence 7) bind tighter than @+@ and @-@ (6), and all
--     four are left-associative;
--   * lambdas and @let@ extend as far to the right as possible, so they are
--     parenthesised whenever they appear inside another construct.
--
-- Terms whose flat rendering was already unambiguous print as before, e.g.
-- @(\\x . x) y@ and @let s = e1\\nin e2@.
instance Show Expr where
  showsPrec prec expr =
    case expr of
      Value i -> showsPrec prec i
      Variable s -> showString s
      Lambda s _ e ->
        showParen (prec > 0) $
          showString "\\" . showString s . showString " . " . showsPrec 0 e
      App e1 e2 ->
        showParen (prec > 10) $
          showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
      Let s e1 e2 ->
        showParen (prec > 0) $
          showString "let " . showString s . showString " = " . showsPrec 0 e1
            . showString "\nin " . showsPrec 0 e2
      Add e1 e2 -> infixOp 6 " + " e1 e2
      Sub e1 e2 -> infixOp 6 " - " e1 e2
      Mul e1 e2 -> infixOp 7 " * " e1 e2
      Div e1 e2 -> infixOp 7 " / " e1 e2
    where
      -- Left-associative binary operator at precedence opPrec: the left
      -- operand may be another operator of the same level, the right may not.
      infixOp opPrec op l r =
        showParen (prec > opPrec) $
          showsPrec opPrec l . showString op . showsPrec (opPrec + 1) r
-- | Types of the expression language.
data Type
  = TInt           -- ^ the integer type, printed as @int@
  | TVar String    -- ^ type variable, solved by unification
  | TFun Type Type -- ^ function type @t1 -> t2@
  deriving (Eq)
-- | Renders types with @->@ as a right-associative arrow: only a function
-- type in domain position is parenthesised.
instance Show Type where
  show ty =
    case ty of
      TInt -> "int"
      TVar name -> name
      TFun dom cod -> domain dom ++ " -> " ++ show cod
    where
      domain t@(TFun _ _) = "(" ++ show t ++ ")"
      domain t = show t
-- | Shorthand type variables for writing type annotations in examples.
a, b, c, d :: Type
a = TVar "a"
b = TVar "b"
c = TVar "c"
d = TVar "d"

-- | Shorthand term variables for writing example expressions.
x, y, z :: Expr
x = Variable "x"
y = Variable "y"
z = Variable "z"

-- | The integer type.
int :: Type
int = TInt

-- | Function-type constructor: @t1 ~> t2@ is @TFun t1 t2@.
-- Right-associative, so @a ~> b ~> c@ means @a ~> (b ~> c)@.
(~>) :: Type -> Type -> Type
t1 ~> t2 = TFun t1 t2
infixr 9 ~>
-- | A substitution @(v, t)@ replaces @v@ by @t@: [t1 -> t2].
type Substitution = (Type, Type)

-- | A list of substitutions.
type Substitutions = [Substitution]

-- | A constraint @(t1, t2)@ states that @t1 = t2@.
type Constraint = (Type, Type)

-- | A list of constraints.
type Constraints = [Constraint]
-- * Constraints
--
-- Hindley-Milner type-inference works as follows:
--
-- 1. Take the untyped lambda calculus and introduce
-- a fresh type variable for the parameter of every lambda abstraction.
-- 2. Walk the AST as if we were type checking, but
-- instead of type checking, we generate constraints
-- (and still return types as well)
-- For example, given the application
--
--    t1:X t2:Y
--
-- instead of checking that X is a function type whose
-- domain is Y and returning its codomain, we generate
-- a fresh type variable Z, add the constraint
-- {X = Y -> Z}, and return Z as the type of the whole term.
--
-- We create a constraint for every function application:
-- t1 t2 generates the constraint
--    typeOf(t1) = typeOf(t2) -> X
-- where X is a fresh type variable.
--
-- For binary functions `t1 + t2`, `t1 - t2`, etc. which
-- take two terms of type int, we generate two constraints:
-- int = typeOf(t1)
-- int = typeOf(t2)
--
-- 3. Use unification to convert the constraints
-- into a list of type substitutions.
--
-- 4. Use the substitutions to reconstruct the type
-- returned in step 2.
--
-- This implementation combines steps 2, 3 and 4.
--
-- | The type-inference computation: reads a typing environment, carries a
-- @[String]@ message log, and yields an inferred type together with the
-- constraints (or solved substitutions) that justify it.
type TypeEval = Eval (Env Type) [String] (Type, [Constraint])
-- | Infer the type of an expression, returning the type (with all solved
-- substitutions applied) and the constraints accumulated so far.
--
-- * Literals are 'int'; variables take their type from the environment.
-- * Arithmetic is delegated to 'binOp'.
-- * A lambda's annotation is used as the parameter type.
-- * An application introduces a fresh result variable @r@ and the
--   constraint @typeOf(arg) -> r = typeOf(fun)@.
-- * @let@ binds the inferred type of the bound expression as-is; there
--   is no generalisation, so let bindings are monomorphic.
inferType :: Expr -> TypeEval
inferType expr =
  case expr of
    Value _ -> succeeds (int, [])
    Variable name -> (\ty -> (ty, [])) <$> envLookup name
    Add l r -> binOp l r
    Sub l r -> binOp l r
    Mul l r -> binOp l r
    Div l r -> binOp l r
    Lambda param paramTy body ->
      local (M.insert param paramTy) (inferType body)
        >>= \(bodyTy, bodyCs) -> reconstruct (paramTy ~> bodyTy) bodyCs
    App fun arg -> do
      (funTy, funCs) <- inferType fun
      (argTy, argCs) <- inferType arg
      result <- TVar <$> newFreeVar
      reconstruct result (funCs ++ argCs ++ [(argTy ~> result, funTy)])
    Let name bound body -> do
      (boundTy, boundCs) <- inferType bound
      (bodyTy, bodyCs) <- local (M.insert name boundTy) (inferType body)
      reconstruct bodyTy (boundCs ++ bodyCs)
-- | Infer the type of an integer binary operation (@+@, @-@, @*@, @/@).
--
-- Both operands are inferred left to right.  Every operand type that is
-- not already 'int' yields a constraint @int = t@.  The combined
-- constraints are unified once so that an inconsistency is reported here.
-- The result type is always 'int'.  The unsolved constraint list is passed
-- upward so that enclosing terms unify it together with their own.
binOp :: Expr -> Expr -> TypeEval
binOp e1 e2 = do
  (t1, c1) <- inferType e1
  (t2, c2) <- inferType e2
  let c = c1 ++ c2 ++ [(int, t) | t <- [t1, t2], t /= int]
  -- Only the failure matters here, not the solution, so a single
  -- unification of the whole constraint set is enough.
  _ <- unify c
  succeeds (int, c)
-- | Solve the constraints with 'unify' and apply the resulting
-- substitutions to the given type.
--
-- Returns the resolved type together with the substitutions.  The
-- substitutions are passed upward as constraints, since @[v -> t]@ is
-- equivalent to the constraint @v = t@.
reconstruct :: Type -> Constraints -> TypeEval
reconstruct ty constraints = do
  subs <- unify constraints
  -- 'unify' emits substitutions in the order it discovers them.  A later
  -- one may solve variables that an earlier one introduced, but never the
  -- reverse, so they must be applied front to back.  Applying them back to
  -- front (plain 'foldr') leaves variables unresolved.  For example,
  -- (\f:b . f 1) (\x:a . x) would get a type variable instead of int.
  let resolved = foldr substitute ty (reverse subs)
  succeeds (resolved, subs)
-- * Unification
-- Convert constraints into substitutions
-- | Tries to unify the constraints, producing
-- a list of substitutions.
--
unify :: Constraints -> Eval env [String] Substitutions
unify constraints =
  case constraints of
    [] -> succeeds []
    (lhs, rhs) : rest
      -- A trivial constraint (t = t) carries no information; drop it.
      | lhs == rhs -> unify rest
      | otherwise -> unify' lhs rhs rest
-- | Unify one non-trivial constraint @lhs = rhs@, then continue with the
-- remaining constraints.
--
-- A type variable on either side is bound via 'occursCheck'; the left side
-- wins when both are variables.  Two function types are decomposed into
-- domain and codomain constraints, queued after the rest.  Any other pair
-- (e.g. @int@ against a function type) fails.
unify' :: Type -> Type -> Constraints -> Eval env [String] Substitutions
unify' lhs rhs rest =
  case (lhs, rhs) of
    (TVar _, _) -> occursCheck lhs rhs rest
    (_, TVar _) -> occursCheck rhs lhs rest
    (TFun domL codL, TFun domR codR) ->
      unify (rest ++ [(domL, domR), (codL, codR)])
    _ -> failsWith $ "Failed to unify constraints. " ++ show lhs ++ " = " ++ show rhs
-- | Occurs check:
-- In the constraint X = t2, check that X does not
-- occur among the free variables of t2; otherwise
-- the solution would be an infinite (recursive) type.
--
occursCheck :: Type -> Type -> Constraints -> Eval env [String] Substitutions
occursCheck v@(TVar v') c cs
  | v' `elem` freeVar c =
      failsWith ("Failed occurs check: " ++ v' ++ " = " ++ show c)
  -- Record [v -> c], eliminate v from the remaining constraints, and keep
  -- solving; the new substitution goes first so 'reconstruct' applies it
  -- before the ones it enables.
  | otherwise =
      unify (substituteC (v, c) cs) >>= \ss -> succeeds ((v, c) : ss)
-- 'unify'' only ever passes a type variable as the first argument.
-- Report a violation as a unification failure instead of crashing on an
-- incomplete pattern match.
occursCheck v c _ =
  failsWith ("occursCheck: expected a type variable, got " ++ show v ++ " = " ++ show c)
-- | Names of the type variables occurring in a type, left to right.
-- Duplicates are kept.
freeVar :: Type -> [String]
freeVar ty =
  case ty of
    TVar name -> [name]
    TFun dom cod -> freeVar dom ++ freeVar cod
    TInt -> []
-- * Type Substitutions
-- Substitute type variables.
--
-- | Apply a single substitution to both sides of every constraint,
-- preserving their order.
substituteC :: (Type, Type) -> Constraints -> Constraints
substituteC s = map (\(lhs, rhs) -> (substitute s lhs, substitute s rhs))
-- | Apply the substitution @(from, to)@ to a type, replacing every
-- occurrence of the variable @from@ with @to@.
-- 'TInt' is left alone and function types are rewritten structurally.
substitute :: (Type, Type) -> Type -> Type
substitute sub@(from, to) ty =
  case ty of
    TInt -> TInt
    TFun dom cod -> substitute sub dom ~> substitute sub cod
    _
      | ty == from -> to
      | otherwise -> ty
-- * Evaluation
-- Again, nothing changes here, because all the types can
-- be erased, leaving the untyped lambda calculus behind.
--
-- | The evaluation computation: reads an environment of already-evaluated
-- expressions, carries a @[String]@ message log, and yields an @a@.
type EV a = Eval (Env Expr) [String] a
-- | Evaluate an expression (call-by-value).
--
-- Literals and lambdas evaluate to themselves, and variables are looked up
-- in the environment.  Application and @let@ evaluate the bound expression,
-- then the body with the result bound (see 'updatedEnv').  Applying a
-- non-function fails with an error message, as does division by zero.
--
-- NOTE(review): a lambda evaluates to itself without capturing the current
-- environment.  Free variables in its body are therefore resolved in the
-- environment at the point of application (dynamic scoping).
evalExpr :: Expr -> EV Expr
evalExpr v@(Value _) = succeeds v
evalExpr (Variable s) = envLookup s
evalExpr l@(Lambda _ _ _) = succeeds l
evalExpr (App e1 e2) = evalExpr e1 >>= apply
  where
    -- Only a lambda can be applied.  Anything else is reported through the
    -- evaluator instead of crashing on a refuted lambda pattern.
    apply (Lambda var _ body) = updatedEnv var e2 body
    apply other = failsWith ("Error: cannot apply non-function " ++ show other)
evalExpr (Let v e1 e2) = updatedEnv v e1 e2
evalExpr (Add e1 e2) = binOp' (+) e1 e2
evalExpr (Sub e1 e2) = binOp' (-) e1 e2
evalExpr (Mul e1 e2) = binOp' (*) e1 e2
evalExpr (Div e1 e2) = getValue e1 >>= \v -> getValue e2 >>= f v
  where
    f _ 0 = failsWith "Error: division by zero"
    f v w = succeeds (Value $ div v w)
-- | Evaluate both operands to integers (left first) and combine them with
-- the given operator, wrapping the result back into a 'Value'.
binOp' :: (Int -> Int -> Int) -> Expr -> Expr -> EV Expr
binOp' op e1 e2 = combine <$> getValue e1 <*> getValue e2
  where
    combine l r = Value (l `op` r)
-- | Evaluate an expression that must produce an integer.
-- A non-integer result is reported through the evaluator instead of
-- crashing on a refuted lambda pattern.
getValue :: Expr -> EV Int
getValue e = evalExpr e >>= extract
  where
    extract (Value i) = return i
    extract other = failsWith ("Error: expected an integer, got " ++ show other)
-- | Evaluate @e1@ in the current environment, then evaluate @e2@ with
-- @var@ bound to that result.  Shared by application and @let@.
updatedEnv :: String -> Expr -> Expr -> EV Expr
updatedEnv var e1 e2 = do
  bound <- evalExpr e1
  local (M.insert var bound) (evalExpr e2)
-- | Type-check and then evaluate a closed expression.
--
-- Evaluation only runs if type inference succeeds.  On success, the result
-- is the value paired with its inferred type.  The message logs of both
-- phases are concatenated, type inference first.
eval :: Expr -> (Maybe (Expr, Type), [String])
eval expr =
  case evalEval (inferType expr) M.empty of
    (Nothing, typeLog) -> (Nothing, typeLog)
    (Just (ty, _), typeLog) -> attach ty typeLog (evalEval (evalExpr expr) M.empty)
  where
    attach _ typeLog (Nothing, evalLog) = (Nothing, typeLog ++ evalLog)
    attach ty typeLog (Just ex, evalLog) = (Just (ex, ty), typeLog ++ evalLog)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment