Skip to content

Instantly share code, notes, and snippets.

@lynn
Last active July 2, 2024 09:52
Show Gist options
  • Save lynn/68d76d70b01ed2c3e0ef3daab6d64a0e to your computer and use it in GitHub Desktop.
Save lynn/68d76d70b01ed2c3e0ef3daab6d64a0e to your computer and use it in GitHub Desktop.
A Haskell implementation of Hindley–Milner (including unification) using STRefs
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
import Control.Monad
import Control.Monad.Except
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.ST
import Control.Monad.ST.Class
import Control.Monad.State
import Data.Foldable
import Data.List
import qualified Data.Map as M
import Data.Map (Map)
import Data.STRef
------------------------------------------------------------------------------------
-- Helper functions for MonadST.
-- Read an STRef from any monad that can lift ST actions.
readST :: MonadST m => STRef (World m) a -> m a
readST = liftST . readSTRef
-- Overwrite the contents of an STRef from any monad that can lift ST actions.
writeST :: MonadST m => STRef (World m) a -> a -> m ()
writeST ref = liftST . writeSTRef ref
------------------------------------------------------------------------------------
-- Expressions in our "programming language".
-- Variable names in source programs.
type Name = String

-- Literal constants.
data Literal
  = LInt Integer
  | LBool Bool
  deriving (Eq, Ord, Show)

-- Untyped lambda calculus with literals and let-bindings.
data Expr
  = Lit Literal         -- a literal constant
  | Var Name            -- a variable reference
  | App Expr Expr       -- function application: App function argument
  | Lam Name Expr       -- \name -> body
  | Let Name Expr Expr  -- let name = bound in body
  deriving (Eq, Ord, Show)
------------------------------------------------------------------------------------
-- The types of expressions. The type parameter `i` is the type we use to represent type variables.
--
-- * A fully polymorphic type such as `forall α. α → α` is represented by `Type PVar`.
--
-- * In unification, `i` is `UVar s`, the type of unification/monomorphic type variables.
-- These are STRefs containing `Nothing` (unbound) or `Just t` (bound, to some other `Type (UVar s)`).
--
-- * The outcome of polymorphizing a type is `Type (EitherVar s)`, as it might contain
-- a mix of polymorphic and monomorphic type variables.
--
-- * A `Type String` is one where all the type variables have names, for printing!
--
data Type i
  = TInt
  | TBool
  | TVar i
  | Type i :-> Type i
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)

-- Function arrows associate to the right, as in Haskell itself:
-- `a :-> b :-> c` means `a :-> (b :-> c)`, which is exactly how `prettyType`
-- prints `a -> b -> c` without parentheses. Without this declaration the
-- constructor would get the default fixity `infixl 9`, and a chain of arrows
-- would silently parse as `(a :-> b) :-> c`.
infixr 5 :->
-- Render a `Type String` as Haskell-style source text. Arrows associate to
-- the right, so only an arrow in argument position is parenthesized:
-- "(a -> b) -> c", but "a -> b -> c".
prettyType :: Type String -> String
prettyType = render False
  where
    render _ TInt = "Int"
    render _ TBool = "Bool"
    render _ (TVar name) = name
    render inArgPos (dom :-> cod)
      | inArgPos = "(" ++ arrow ++ ")"
      | otherwise = arrow
      where
        arrow = render True dom ++ " -> " ++ render False cod
-- Everything that can go wrong in the realm of Hindley-Milner:
--   * UnknownVariable: a variable was referenced that is not in the environment,
--   * Couldn'tUnify:   two types with different shapes had to be made equal,
--   * InfiniteType:    the occurs check failed (a type would contain itself).
data InferenceError
  = UnknownVariable
  | Couldn'tUnify
  | InfiniteType
  deriving (Eq, Ord, Show)
------------------------------------------------------------------------------------
-- Unification variables are STRefs that we modify during unification.
-- They can hold Nothing (indicating fresh, unbound variables) or Just a type with
-- other unification variables in it (indicating it was unified with this type).
--
-- "I call them unification variables because in the algorithm the way they get
-- filled in with a concrete expression is by unification.
-- & I called them monomorphic type variables because they're unknown but they
-- can only be instantiated a single way"
-- - dianne
--
-- `s` is just the STRef state thread. Sometimes it shows up as `World m`.
-- A unification variable: a mutable cell that holds Nothing while unbound and
-- Just the type it was unified with once bound.
newtype UVar s = UVar (STRef s (Maybe (Type (UVar s))))
  deriving Eq -- equality of the underlying STRefs
-- Allocate a fresh unification variable whose cell starts out empty (unbound).
newUVar :: MonadST m => m (UVar (World m))
newUVar = UVar <$> liftST (newSTRef Nothing)
-- Polymorphic type variables are far simpler than UVars: they are plain Ints
-- with no mutable state behind them. `polymorphize` hands them out as
-- 0, 1, 2, ... in order of first appearance, and `nameType` turns them into
-- printable names by indexing into `typeNames`.
type PVar = Int
-- An EitherVar is exactly what it sounds like. As mentioned above, it is used in the outcome
-- of `polymorphize`. For example, in the expression
--
-- (\x -> let f = (\a -> x) in f),
--
-- the result of `polymorphize [uvar_of_x] type_of_f` would be
--
-- TVar (Polymorphic 0) :-> TVar (Monomorphic uvar_of_x).
--
-- A type variable that is either still a unification variable (left alone by
-- generalization) or an already-generalized polymorphic variable.
data EitherVar s
  = Monomorphic (UVar s)
  | Polymorphic PVar
  deriving Eq
-- The type environment maps each variable in scope to its type.
-- Lambda-bound variables are entered with all of their type variables marked
-- `Monomorphic` (see the `Lam` case of `typeof`); let-bound variables are
-- entered with a `polymorphize`d type, which may also contain `Polymorphic`
-- variables that `monomorphize` instantiates afresh at each use.
type TypeEnv s = Map Name (Type (EitherVar s))
------------------------------------------------------------------------------------
-- Does the unification variable `u` occur anywhere in `t` once every bound
-- variable in `t` has been replaced by what it is bound to? (Uses the
-- Foldable instance of `Type` to scan the variables.)
occurs :: MonadST m => UVar (World m) -> Type (UVar (World m)) -> m Bool
occurs u t = elem u <$> expandBoundTypes t
-- Unify two types, binding unification variables so that both sides become
-- equal.
--
-- Throws Couldn'tUnify when the shapes differ (e.g. Int vs Bool, or Int vs an
-- arrow), and InfiniteType when a variable would have to contain itself
-- (e.g. a ~ a -> b).
--
-- Unifying an unbound variable with a type that, after following bindings,
-- *is* that same variable succeeds without doing anything. Without this check
-- the occurs check reports a spurious InfiniteType. For example, when
-- inferring \f -> \x -> f (f (f x)), a chain of bindings r2 := r1, r1 := x
-- eventually asks to unify x with r1, which is just x again.
unify :: (MonadST m, MonadError InferenceError m)
      => Type (UVar (World m)) -> Type (UVar (World m)) -> m ()
unify TInt TInt = return ()
unify TBool TBool = return ()
unify (a :-> b) (c :-> d) = unify a c >> unify b d
unify (TVar u@(UVar ref)) t =
  readST ref >>= \case
    -- `u` is already bound: unify what it is bound to instead.
    Just t' -> unify t t'
    Nothing -> expandBoundTypes t >>= \case
      -- `t` is (or is bound to) `u` itself: already equal, nothing to record.
      TVar v | v == u -> return ()
      expanded -> occurs u expanded >>= \case
        True -> throwError InfiniteType
        False -> writeST ref (Just t)
unify t (TVar u) = unify (TVar u) t
unify _ _ = throwError Couldn'tUnify
------------------------------------------------------------------------------------
-- Substitute away every bound unification variable: each `TVar (UVar ref)`
-- whose `ref` holds `Just t` is replaced, recursively, by the expansion of
-- `t`. Unbound variables, Int and Bool are returned unchanged.
expandBoundTypes :: MonadST m => Type (UVar (World m)) -> m (Type (UVar (World m)))
expandBoundTypes ty = case ty of
  TVar (UVar ref) -> readST ref >>= maybe (return ty) expandBoundTypes
  dom :-> cod     -> (:->) <$> expandBoundTypes dom <*> expandBoundTypes cod
  TInt            -> return TInt
  TBool           -> return TBool
------------------------------------------------------------------------------------
-- `relabeler mk` is a function that relabels equal `a`s into equal `b`s, keeping
-- track of a cache of labels :: [(a, b)]. If the given `a` is not found in the cache,
-- it will be freshly created using the monadic action `mk`.
--
-- Then, of course, `traverse (relabeler mk)` can be used to relabel any tree-like structure:
--
--       x                      0      <- generated by mk
--      / \                    / \
--     y   x     ------->     1   0    <- cache hit!
--    / \                    / \
--   z   y                  2   1      <- generated by mk
--
relabeler :: (Eq a, MonadState [(a, b)] m) => m b -> (a -> m b)
relabeler mk a = gets (lookup a) >>= maybe fresh return
  where
    -- Cache miss: make a new label with `mk` and remember it for next time.
    fresh = do
      b <- mk
      modify ((a, b) :)
      return b
-- (In practice, we'll only want to relabel stuff conditionally, so we won't *quite*
-- write `traverse (relabeler f)`. In `monomorphize` and `polymorphize` below, we write
-- `traverse blah`, and `blah` either calls `relabeler` or something else.)
------------------------------------------------------------------------------------
-- Instantiate a polymorphic type. Every polymorphic variable is replaced by a
-- fresh unification variable, and within one call the same PVar always maps
-- to the same fresh UVar. Monomorphic variables pass through unchanged.
monomorphize :: MonadST m => Type (EitherVar (World m)) -> m (Type (UVar (World m)))
monomorphize t = flip evalStateT [] (traverse instantiate t)
  where
    instantiate :: (MonadState [(PVar, UVar (World m))] m, MonadST m)
                => EitherVar (World m) -> m (UVar (World m))
    instantiate var = case var of
      Polymorphic pvar -> relabeler newUVar pvar
      Monomorphic uvar -> return uvar
-- Generalize over type variables not already in the type environment, making them polymorphic.
--
-- This is the raison d'être of `let`! It makes `let id = (\x -> x) in id id 3` typecheck:
-- we want `id` to have a polymorphic type, which we obtain by generalizing over the
-- unification variable that we assigned to `x`.
--
-- But, for example, if we're handling `(\y -> let q = y in q) 2`, we don't want to
-- generalize over the UVar for `y`, or we'll incorrectly infer the type of that expression
-- as `a` instead of `Int`. To deal with this case, `polymorphize` accepts an `inEnv`
-- argument, and leaves all the UVars in that list alone.
--
-- I've used constraints on `uvar` and `eithervar` as makeshift aliases. See you all in hell
--
polymorphize :: forall m uvar eithervar.
                (MonadST m, uvar ~ UVar (World m), eithervar ~ EitherVar (World m))
             => [uvar] -> Type uvar -> m (Type eithervar)
polymorphize inEnv t =
  expandBoundTypes t >>= \expanded -> evalStateT (traverse generalizeVar expanded) []
  where
    -- UVars listed in `inEnv` stay monomorphic. Every other UVar is relabeled
    -- with a PVar. The label maker is `gets length`: with a cache of n
    -- entries, the next new label is n, so labels come out as 0, 1, 2, ...
    -- in order of first appearance.
    generalizeVar :: uvar -> StateT [(uvar, PVar)] m eithervar
    generalizeVar uvar
      | uvar `notElem` inEnv = Polymorphic <$> relabeler (gets length) uvar
      | otherwise            = return (Monomorphic uvar)
-- Return the free (unbound) UVars reachable from a type environment.
-- Monomorphic variables stored in the environment may have been bound by
-- unification since they were inserted, so each one is expanded first and the
-- variables left in its expansion are collected. Duplicates are not removed.
freeVarsInEnv :: MonadST m => TypeEnv (World m) -> m [UVar (World m)]
freeVarsInEnv env = foldMap toList <$> traverse (expandBoundTypes . TVar) monoVars
  where
    monoVars = [u | ty <- M.elems env, Monomorphic u <- toList ty]
-- The Monomorphic (unification) variables mentioned in a type, in traversal
-- order, with duplicates kept. Polymorphic variables are skipped. Bound UVars
-- are not followed, so the type is expected to be expanded already.
freeVarsInExpandedTerm :: Type (EitherVar s) -> [UVar s]
freeVarsInExpandedTerm = foldr keepMono []
  where
    keepMono (Monomorphic uvar) acc = uvar : acc
    keepMono (Polymorphic _) acc = acc
-- An infinite supply of type-variable names: "a" .. "z", then "a1" .. "z1",
-- then "a2" .. "z2", and so on.
typeNames :: [String]
typeNames = concatMap withSuffix suffixes
  where
    suffixes = "" : map show [(1 :: Integer) ..]
    withSuffix suf = map (: suf) ['a' .. 'z']
-- Give every PVar a printable name. PVars are Ints, so PVar n is simply the
-- n-th entry of `typeNames` (0 -> "a", 1 -> "b", ..., 26 -> "a1", ...).
-- PVars are assumed to be non-negative, as `polymorphize` produces them.
nameType :: Type PVar -> Type String
nameType t = nameOf <$> t
  where
    nameOf n = typeNames !! n
-- What type checking/inference needs from its monad:
--   * MonadError:  it can throw InferenceErrors,
--   * MonadReader: it has access to (and can locally extend) the type environment,
--   * MonadST:     it can read and write the STRefs used for unification.
type Typechecking m =
  ( MonadError InferenceError m
  , MonadReader (TypeEnv (World m)) m
  , MonadST m
  )
-- Here we go! Infer the type of an expression. Type variables in the result
-- are unification variables that later unification may still bind.
typeof :: Typechecking m => Expr -> m (Type (UVar (World m)))
typeof expr = case expr of
  Lit (LInt _) -> return TInt
  Lit (LBool _) -> return TBool
  -- Look the variable up and instantiate its (possibly polymorphic) type.
  Var name -> asks (M.lookup name) >>= maybe (throwError UnknownVariable) monomorphize
  -- The function's type must be an arrow from the argument's type to some
  -- fresh result type.
  App fun arg -> do
    funTy <- typeof fun
    argTy <- typeof arg
    resTy <- TVar <$> newUVar
    unify funTy (argTy :-> resTy)
    return resTy
  -- The parameter gets a fresh, monomorphic type while checking the body.
  Lam param body -> do
    paramTy <- TVar <$> newUVar
    bodyTy <- local (M.insert param (Monomorphic <$> paramTy)) (typeof body)
    return (paramTy :-> bodyTy)
  -- Generalize the bound expression's type over every variable that is not
  -- still free in the environment, then check the body with it in scope.
  Let name bound body -> do
    boundTy <- typeof bound
    envFree <- freeVarsInEnv =<< ask
    scheme <- polymorphize envFree boundTy
    local (M.insert name scheme) (typeof body)
-- Run a typechecking computation in the concrete transformer stack
-- ReaderT (TypeEnv s) (ExceptT InferenceError (ST s)), starting from an empty
-- type environment (a "prelude" of predefined names could be supplied here
-- instead).
runTypechecking :: (forall s. ReaderT (TypeEnv s) (ExceptT InferenceError (ST s)) a) -> Either InferenceError a
runTypechecking tc = runST (runExceptT (runReaderT tc M.empty))
-- Helper function to test our implementation: typecheck an expression,
-- generalize its type over every remaining unification variable, assign names
-- to the resulting PVars and pretty-print the result.
--
-- >>> test composition
-- Right "(a -> b) -> (b -> c) -> a -> c"
test :: Expr -> Either InferenceError String
test x = prettyType . nameType <$> runTypechecking action
  where
    action :: (forall s. ReaderT (TypeEnv s) (ExceptT InferenceError (ST s)) (Type PVar))
    action = do
      t <- typeof x
      t' <- polymorphize [] t
      return (asPolymorphicType t')

    -- With an empty environment, `polymorphize` generalizes every variable,
    -- so a Monomorphic one can only show up through an internal bug. Say so
    -- explicitly instead of relying on an incomplete lambda pattern.
    asPolymorphicType :: Type (EitherVar s) -> Type PVar
    asPolymorphicType = fmap $ \case
      Polymorphic var -> var
      Monomorphic _ -> error "test: polymorphize [] left a monomorphic type variable"
-- The omega combinator.
-- Type inference should fail with an InfiniteType error.
ω :: Expr
ω = Lam "x" selfApplication
  where
    selfApplication = App (Var "x") (Var "x")
-- Non-unifiable: applying an int to a bool.
-- Type inference should fail with a Couldn'tUnify error.
nonunifiable :: Expr
nonunifiable = App three false
  where
    three = Lit (LInt 3)
    false = Lit (LBool False)
-- Unknown variable.
-- Type inference should fail with, well, an UnknownVariable error.
unknownVariable :: Expr
unknownVariable = Lam "x" $ Var "y"
-- Function composition: `compose f g` applies f, then g.
-- Type inference should give `(a -> b) -> (b -> c) -> a -> c`.
composition :: Expr
composition = foldr Lam body ["f", "g", "x"]
  where
    -- g (f x), under the binders \f -> \g -> \x -> ...
    body = App (Var "g") (App (Var "f") (Var "x"))
-- The examples from the `polymorphize` comment.
-- This is `let id = (\x -> x) in id id 3`, and it should have type `Int` (not fail).
-- `(id id) 3`: here `id` is used both at (Int -> Int) -> (Int -> Int) and at
-- Int -> Int, which only typechecks if `let` generalized its type. Note that
-- `id (id 3)` would typecheck even with a monomorphic `id`, so it would not
-- exercise let-polymorphism at all.
poly1 :: Expr
poly1 = Let "id" (Lam "x" $ Var "x") (App (App (Var "id") (Var "id")) (Lit (LInt 3)))
-- This is `(\y -> let q = y in q) 2`, and it should have type `Int` (not `a`).
poly2 :: Expr
poly2 = App (Lam "y" body) two
  where
    body = Let "q" (Var "y") (Var "q")
    two = Lit (LInt 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment