-- A Haskell implementation of Hindley–Milner type inference (including
-- unification) using STRefs.
--
-- Originally published as GitHub gist lynn/68d76d70b01ed2c3e0ef3daab6d64a0e
-- (last active July 2, 2024 09:52).
--
-- GitHub flagged this file as containing bidirectional Unicode text, which may
-- be interpreted or compiled differently than it appears. To review it, open
-- the file in an editor that reveals hidden Unicode characters.
{-# LANGUAGE ConstraintKinds #-} | |
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeFamilies #-} | |
import Control.Monad | |
import Control.Monad.Except | |
import Control.Monad.Identity | |
import Control.Monad.Reader | |
import Control.Monad.ST | |
import Control.Monad.ST.Class | |
import Control.Monad.State | |
import Data.Foldable | |
import Data.List | |
import qualified Data.Map as M | |
import Data.Map (Map) | |
import Data.STRef | |
------------------------------------------------------------------------------------
-- Helper functions for MonadST.

-- | Read an 'STRef' from any monad that can lift 'ST' actions of its own world.
readST :: MonadST m => STRef (World m) a -> m a
readST = liftST . readSTRef
-- | Overwrite the contents of an 'STRef' from any monad that can lift 'ST'
-- actions of its own world.
writeST :: MonadST m => STRef (World m) a -> a -> m ()
writeST ref = liftST . writeSTRef ref
------------------------------------------------------------------------------------
-- Expressions in our "programming language".

-- | Variable names.
type Name = String

-- | Literal constants.
data Literal
  = LInt Integer   -- ^ an integer literal (inferred as @Int@)
  | LBool Bool     -- ^ a boolean literal (inferred as @Bool@)
  deriving (Show, Eq, Ord)

-- | The lambda calculus extended with literals and a non-recursive @let@.
data Expr
  = Lit Literal          -- ^ a literal constant
  | Var Name             -- ^ a variable reference
  | App Expr Expr        -- ^ function application
  | Lam Name Expr        -- ^ lambda abstraction over one variable
  | Let Name Expr Expr   -- ^ @let x = e1 in e2@
  deriving (Show, Eq, Ord)
------------------------------------------------------------------------------------
-- | The types of expressions. The parameter @i@ is whatever we use to
-- represent type variables:
--
-- * @Type PVar@ is a fully polymorphic type such as @forall α. α → α@.
--
-- * @Type (UVar s)@ is used during unification: its variables are STRefs
--   holding 'Nothing' (unbound) or 'Just' another @Type (UVar s)@ (bound).
--
-- * @Type (EitherVar s)@ is what 'polymorphize' produces; it may mix
--   polymorphic and monomorphic type variables.
--
-- * @Type String@ has a name for every variable, ready for printing.
--
-- The derived 'Foldable' instance visits the variables left to right, which
-- the occurs check and the relabeling code rely on.
data Type i
  = TInt
  | TBool
  | TVar i
  | Type i :-> Type i
  deriving (Functor, Foldable, Traversable, Eq, Ord, Show)
-- | Pretty-print a type whose variables already have names,
-- e.g. @(a -> b) -> a -> b@. Arrows associate to the right, so only an arrow
-- that appears to the left of another arrow is parenthesized.
prettyType :: Type String -> String
prettyType = render False
  where
    render :: Bool -> Type String -> String
    render _ TInt = "Int"
    render _ TBool = "Bool"
    render _ (TVar name) = name
    render needsParens (dom :-> cod)
      | needsParens = "(" ++ arrow ++ ")"
      | otherwise = arrow
      where
        arrow = render True dom ++ " -> " ++ render False cod
-- | Everything that can go wrong in the realm of Hindley-Milner.
data InferenceError
  = UnknownVariable  -- ^ a 'Var' that is not bound in the type environment
  | Couldn'tUnify    -- ^ two types of incompatible shape, e.g. Int vs. Bool
  | InfiniteType     -- ^ the occurs check failed, e.g. for @\x -> x x@
  deriving (Show, Eq, Ord)
------------------------------------------------------------------------------------
-- | Unification variables are STRefs that we modify during unification.
-- A UVar holds 'Nothing' while it is fresh and unbound, or 'Just' a type
-- (possibly mentioning other unification variables) once it has been unified
-- with that type.
--
-- "I call them unification variables because in the algorithm the way they get
-- filled in with a concrete expression is by unification.
-- & I called them monomorphic type variables because they're unknown but they
-- can only be instantiated a single way"
--   - dianne
--
-- @s@ is just the STRef state thread; sometimes it shows up as @World m@.
-- Equality is the 'STRef' instance's pointer equality, so two UVars are equal
-- exactly when they are the same variable.
newtype UVar s = UVar (STRef s (Maybe (Type (UVar s))))
  deriving Eq
-- | Allocate a fresh, unbound unification variable.
newUVar :: MonadST m => m (UVar (World m))
newUVar = liftST $ do
  ref <- newSTRef Nothing
  return (UVar ref)
-- | Polymorphic type variables are far simpler: they are plain 'Int' labels.
-- 'polymorphize' hands them out as 0, 1, 2, ... in order of first appearance.
type PVar = Int
-- | A type variable that is either still monomorphic (a 'UVar') or already
-- generalized (a 'PVar'). This is the variable type of 'polymorphize's result.
-- For example, in the expression
--
--   (\x -> let f = (\a -> x) in f),
--
-- @polymorphize [uvar_of_x] type_of_f@ yields
--
--   TVar (Polymorphic 0) :-> TVar (Monomorphic uvar_of_x)
--
-- because @x@'s variable lives in the environment and must not be generalized.
data EitherVar s
  = Monomorphic (UVar s)
  | Polymorphic PVar
  deriving Eq
-- | The type environment: each variable name in scope maps to its type, which
-- may contain both polymorphic and monomorphic type variables.
type TypeEnv s = Map Name (Type (EitherVar s))
------------------------------------------------------------------------------------
-- | Occurs check: after following every bound UVar, does @u@ appear anywhere in
-- the type? Membership is tested with the 'Foldable' instance of 'Type'.
occurs :: MonadST m => UVar (World m) -> Type (UVar (World m)) -> m Bool
occurs u t = elem u <$> expandBoundTypes t
-- | Unify two types, recording the solution by writing into unbound UVars.
--
-- Throws 'Couldn'tUnify' when the two types have different shapes (e.g. Int
-- vs. Bool, or Int vs. a function type). Throws 'InfiniteType' when a variable
-- would have to contain itself (the occurs check).
--
-- Unifying a variable with itself, directly or through a chain of bound UVars,
-- is a no-op. Without that check the occurs check would find @u@ inside a type
-- that is just @u@ again. For example, @\f -> \x -> f (f (f x))@ used to be
-- rejected with 'InfiniteType' instead of getting @(a -> a) -> a -> a@.
unify :: (MonadST m, MonadError InferenceError m)
      => Type (UVar (World m)) -> Type (UVar (World m)) -> m ()
unify TInt TInt = return ()
unify TBool TBool = return ()
unify (a :-> b) (c :-> d) = unify a c >> unify b d
unify (TVar u@(UVar ref)) t =
  readST ref >>= \case
    -- Already solved: unify its solution with the other side instead.
    Just bound -> unify t bound
    Nothing -> do
      -- Follow bound variables on the other side so we can recognise `u` itself.
      t' <- expandBoundTypes t
      case t' of
        TVar u' | u' == u -> return ()  -- u ~ u: nothing to record.
        -- t' is fully expanded, so plain Foldable membership is the occurs check.
        _ | u `elem` t' -> throwError InfiniteType
          | otherwise   -> writeST ref (Just t')
unify t (TVar u) = unify (TVar u) t
unify _ _ = throwError Couldn'tUnify
------------------------------------------------------------------------------------
-- | Expand the filled UVars in a type. Every @TVar (UVar ref)@ whose @ref@
-- holds @Just t@ is replaced by the expansion of @t@, so the result mentions
-- only unbound UVars. Arrows are expanded left side first, then right side.
expandBoundTypes :: MonadST m => Type (UVar (World m)) -> m (Type (UVar (World m)))
expandBoundTypes = \case
  t@(TVar (UVar ref)) -> readST ref >>= maybe (return t) expandBoundTypes
  dom :-> cod -> (:->) <$> expandBoundTypes dom <*> expandBoundTypes cod
  TInt -> return TInt
  TBool -> return TBool
------------------------------------------------------------------------------------
-- | @relabeler mk@ relabels equal @a@s into equal @b@s. It keeps a cache of
-- labels @[(a, b)]@ in the monad's state. When the given @a@ is not in the
-- cache yet, a new label is created with the monadic action @mk@ and
-- remembered.
--
-- Then, of course, @traverse (relabeler mk)@ relabels any tree-like structure:
--
--        x                  0     <- generated by mk
--       / \                / \
--      y   x   ------->   1   0   <- cache hit!
--     / \                / \
--    z   y              2   1     <- generated by mk
--
relabeler :: (Eq a, MonadState [(a, b)] m) => m b -> (a -> m b)
relabeler mk a = gets (lookup a) >>= maybe fresh return
  where
    fresh = do
      b <- mk
      modify ((a, b) :)  -- remember the new label for later hits
      return b
-- (In practice we only want to relabel some variables, so `monomorphize` and
-- `polymorphize` below don't write `traverse (relabeler mk)` directly. They
-- traverse with a helper that calls `relabeler` for some variables and
-- handles the others itself.)
------------------------------------------------------------------------------------
-- | Monomorphize (instantiate) all the polymorphic type variables in a type.
-- Each distinct PVar gets its own fresh unification variable, and repeated
-- occurrences of the same PVar share it. Monomorphic variables pass through
-- unchanged.
monomorphize :: MonadST m => Type (EitherVar (World m)) -> m (Type (UVar (World m)))
monomorphize t = evalStateT (traverse instantiate t) []
  where
    instantiate :: (MonadState [(PVar, UVar (World n))] n, MonadST n)
                => EitherVar (World n) -> n (UVar (World n))
    instantiate (Polymorphic pvar) = relabeler newUVar pvar
    instantiate (Monomorphic uvar) = return uvar
-- | Generalize over the type variables that are not in the type environment,
-- making them polymorphic.
--
-- This is the raison d'être of `let`! It makes `let id = (\x -> x) in id id 3`
-- typecheck: `id` should get a polymorphic type, which we obtain by
-- generalizing over the unification variable assigned to `x`.
--
-- But when handling `(\y -> let q = y in q) 2`, we must not generalize over
-- the UVar for `y`. Otherwise we would infer `a` instead of `Int`. So the
-- caller passes the environment's UVars as `inEnv`, and those stay
-- monomorphic.
--
-- The equality constraints on `uvar` and `eithervar` serve as makeshift local
-- type aliases.
polymorphize :: forall m uvar eithervar.
                (MonadST m, uvar ~ UVar (World m), eithervar ~ EitherVar (World m))
             => [uvar] -> Type uvar -> m (Type eithervar)
polymorphize inEnv t =
  expandBoundTypes t >>= \expanded ->
    evalStateT (traverse generalizeVar expanded) []
  where
    -- A UVar outside the environment is relabeled with a PVar. `gets length`
    -- is the label maker: with a cache of [(…,1),(…,0)] the next label is 2.
    generalizeVar :: uvar -> StateT [(uvar, PVar)] m eithervar
    generalizeVar uvar
      | uvar `notElem` inEnv = Polymorphic <$> relabeler (gets length) uvar
      | otherwise = return (Monomorphic uvar)
-- | The free (unbound) UVars reachable from a type environment, found after
-- following bound UVars. Duplicates are not removed.
freeVarsInEnv :: MonadST m => TypeEnv (World m) -> m [UVar (World m)]
freeVarsInEnv env =
  concatMap toList <$> traverse (expandBoundTypes . TVar) monoVars
  where
    -- Every monomorphic variable mentioned in the environment. Some may
    -- already be bound, which is why each one is expanded above.
    monoVars = [uvar | ty <- M.elems env, Monomorphic uvar <- toList ty]
-- | The monomorphic variables of a type, in left-to-right order. The type is
-- assumed to be already expanded. (Currently not called in this file.)
freeVarsInExpandedTerm :: Type (EitherVar s) -> [UVar s]
freeVarsInExpandedTerm = foldr collect []
  where
    collect (Monomorphic uvar) acc = uvar : acc
    collect (Polymorphic _) acc = acc
-- | An infinite supply of type-variable names:
-- ["a", "b", ..., "z", "a1", "b1", ..., "z1", "a2", ...].
typeNames :: [String]
typeNames = concatMap withSuffix suffixes
  where
    suffixes = "" : map show [1 :: Integer ..]
    withSuffix suf = map (: suf) ['a' .. 'z']
-- | Give every PVar a printable name by indexing into 'typeNames'
-- (0 becomes "a", 1 becomes "b", ...). PVars are expected to be non-negative.
nameType :: Type PVar -> Type String
nameType = fmap (\pvar -> typeNames !! pvar)
-- | Type checking/inference happens in any monad that:
type Typechecking m =
  ( MonadST m                          -- can modify the STRefs used for unification,
  , MonadReader (TypeEnv (World m)) m  -- has access to the type environment, and
  , MonadError InferenceError m        -- can throw InferenceErrors.
  )
-- | Here we go: infer the type of an expression.
--
-- Lambda-bound variables stay monomorphic. A @let@-bound variable is
-- generalized over every UVar not free in the environment. @let@ is not
-- recursive: the bound expression is checked before its name is in scope.
typeof :: Typechecking m => Expr -> m (Type (UVar (World m)))
typeof = \case
  Lit lit -> return $ case lit of
    LInt _ -> TInt
    LBool _ -> TBool
  Var v -> do
    -- Each use of a variable gets a fresh instance of its type.
    found <- asks (M.lookup v)
    maybe (throwError UnknownVariable) monomorphize found
  App fun arg -> do
    funType <- typeof fun
    argType <- typeof arg
    result <- TVar <$> newUVar
    unify funType (argType :-> result)
    return result
  Lam v body -> do
    param <- TVar <$> newUVar
    bodyType <- local (M.insert v (fmap Monomorphic param)) (typeof body)
    return (param :-> bodyType)
  Let v bound body -> do
    boundType <- typeof bound
    envVars <- freeVarsInEnv =<< ask
    scheme <- polymorphize envVars boundType
    local (M.insert v scheme) (typeof body)
-- | Run a type-checking computation in the concrete transformer stack we use,
-- starting from an empty type environment. A "prelude" environment could be
-- passed here instead.
runTypechecking :: (forall s. ReaderT (TypeEnv s) (ExceptT InferenceError (ST s)) a) -> Either InferenceError a
runTypechecking action = runST (runExceptT (runReaderT action M.empty))
-- | Test helper: typecheck an expression, generalize every type variable in
-- the result, give the PVars names, and pretty-print the resulting type.
test :: Expr -> Either InferenceError String
test x = prettyType . nameType <$> runTypechecking action
  where
    action :: forall s. ReaderT (TypeEnv s) (ExceptT InferenceError (ST s)) (Type PVar)
    action = do
      inferred <- typeof x
      generalized <- polymorphize [] inferred
      return (fmap fromPolymorphic generalized)
    -- With an empty environment every variable is generalized, so the
    -- Monomorphic case does not arise here.
    fromPolymorphic (Polymorphic pvar) = pvar
-- | The omega combinator, @\x -> x x@. Here @x@ would need a type @a@ with
-- @a = a -> b@, so inference should fail with an 'InfiniteType' error.
ω :: Expr
ω = Lam "x" $ App (Var "x") (Var "x")
-- | Non-unifiable: uses the integer 3 as a function and applies it to a boolean.
-- Inference should fail with a 'Couldn'tUnify' error.
nonunifiable :: Expr
nonunifiable = App (Lit (LInt 3)) (Lit (LBool False))
-- | An unknown variable: @\x -> y@ mentions a @y@ that is never bound.
-- Inference should fail with an 'UnknownVariable' error.
unknownVariable :: Expr
unknownVariable = Lam "x" $ Var "y"
-- | Function composition, @\f -> \g -> \x -> g (f x)@: it applies @f@, then @g@.
-- Inference should give @(a -> b) -> (b -> c) -> a -> c@.
composition :: Expr
composition = Lam "f" . Lam "g" . Lam "x" $ App (Var "g") (App (Var "f") (Var "x"))
-- The examples from the `polymorphize` comment.

-- | @let id = (\x -> x) in id id 3@. It should have type @Int@ and not fail.
--
-- This needs let-polymorphism: the first @id@ is used at type
-- @(Int -> Int) -> (Int -> Int)@ and the second at @Int -> Int@. (The previous
-- version built @id (id 3)@, which uses @id@ only at @Int -> Int@ and so did
-- not exercise generalization at all.)
poly1 :: Expr
poly1 = Let "id" (Lam "x" $ Var "x") (App (App (Var "id") (Var "id")) (Lit (LInt 3)))
-- | @(\y -> let q = y in q) 2@. It should have type @Int@, not @a@: @y@'s
-- variable is in the environment when @q@ is generalized, so it stays
-- monomorphic.
poly2 :: Expr
poly2 = App (Lam "y" (Let "q" (Var "y") (Var "q"))) (Lit (LInt 2))
-- (End of the gist. The GitHub page footer read: "Sign up for free to join this
-- conversation on GitHub. Already have an account? Sign in to comment.")