Skip to content

Instantly share code, notes, and snippets.

@lexi-lambda
Last active July 5, 2023 18:01
Show Gist options
  • Star 20 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save lexi-lambda/287dc8513f6a20424457b9d3eda5026a to your computer and use it in GitHub Desktop.
Save lexi-lambda/287dc8513f6a20424457b9d3eda5026a to your computer and use it in GitHub Desktop.
Minimal Haskell implementation of “Complete and Easy Bidirectional Typechecking for Higher-Rank Polymorphism” (Dunfield & Krishnaswami, ICFP 2013)
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Language.HigherRank.Main
( Expr(..)
, EVar(..)
, Type(..)
, TVar(..)
, TEVar(..)
, runInfer
) where
import qualified Data.Sequence as S
import Control.Monad (unless)
import Control.Monad.Except (MonadError, ExceptT, runExceptT, throwError)
import Control.Monad.State (MonadState, State, evalState, get, gets, put, modify)
import Data.Foldable (toList)
import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Sequence (Seq)
-- | The name of a term-level (expression) variable.
newtype EVar = MkEVar
  { unEVar :: String -- ^ the variable's textual name
  }
  deriving (Show, Eq, Ord)
-- | Surface expressions accepted by the checker.
data Expr
  = EUnit          -- ^ the unit value
  | EVar EVar      -- ^ a term variable
  | EAnn Expr Type -- ^ an expression with a type annotation, @(e : A)@
  | ELam EVar Expr -- ^ an unannotated lambda abstraction
  | EApp Expr Expr -- ^ function application
  deriving (Show, Eq, Ord)
-- | The name of a universally quantified type variable.
newtype TVar = MkTVar
  { unTVar :: String -- ^ the variable's textual name
  }
  deriving (Show, Eq, Ord)
-- | The name of an existential (to-be-solved) type variable.
newtype TEVar = MkTEVar
  { unTEVar :: String -- ^ the variable's textual name
  }
  deriving (Show, Eq, Ord)
-- | Types of the language.
data Type
  = TUnit          -- ^ the unit type
  | TVar TVar      -- ^ a universal type variable
  | TEVar TEVar    -- ^ an existential type variable
  | TArr Type Type -- ^ a function type, argument then result
  | TAll TVar Type -- ^ universal quantification, @forall v. A@
  deriving (Show, Eq, Ord)
-- | A type is a monotype when no 'TAll' occurs anywhere inside it.
isMono :: Type -> Bool
isMono ty = case ty of
  TAll _ _     -> False
  TArr dom cod -> isMono dom && isMono cod
  _            -> True -- TUnit, TVar and TEVar are all monotypes
-- | One entry of the ordered algorithmic typing context.
data CtxMember
  = CtxVar TVar          -- ^ a universal type variable in scope
  | CtxAssump EVar Type  -- ^ an assumption @x : A@ for a term variable
  | CtxEVar TEVar        -- ^ an existential variable that is not yet solved
  | CtxSolved TEVar Type -- ^ an existential variable together with its solution
  | CtxMarker TEVar      -- ^ a scope marker; the context is later truncated back to it
  deriving (Show, Eq, Ord)
-- | The ordered algorithmic context. New entries are added on the right
-- with '|>' and contexts are spliced back together with '<>'.
--
-- 'Semigroup' must be derived alongside 'Monoid': since base-4.11 (GHC 8.4)
-- it is a superclass of 'Monoid', and deriving 'Monoid' alone is rejected.
newtype Ctx = Ctx (Seq CtxMember)
  deriving (Eq, Show, Semigroup, Monoid)
-- | Extend a context with one more entry on its right-hand end.
(|>) :: Ctx -> CtxMember -> Ctx
(|>) (Ctx entries) entry = Ctx (entries S.>< S.singleton entry)
-- | Does this exact entry appear anywhere in the context?
ctxElem :: CtxMember -> Ctx -> Bool
ctxElem member (Ctx entries) = any (== member) entries
-- | Split the context around the leftmost occurrence of an entry
-- (the paper's Γ[m] notation).
--
-- Returns the entries strictly before and strictly after it; the entry
-- itself is dropped. Returns 'Nothing' when the entry is not present.
ctxHole :: CtxMember -> Ctx -> Maybe (Ctx, Ctx)
ctxHole mem (Ctx ctx) =
  case S.viewl fromHit of
    S.EmptyL        -> Nothing
    _ S.:< afterHit -> Just (Ctx beforeHit, Ctx afterHit)
  where
    (beforeHit, fromHit) = S.breakl (== mem) ctx
-- | Split the context around two entries, where the second must occur to
-- the right of the first (the paper's Γ[m][m']).
--
-- Yields the segments before @mem@, between the two, and after @mem'@.
-- Returns 'Nothing' if either split fails.
ctxHole2 :: CtxMember -> CtxMember -> Ctx -> Maybe (Ctx, Ctx, Ctx)
ctxHole2 mem mem' ctx =
  ctxHole mem ctx >>= \(left, rest) ->
    fmap (\(middle, right) -> (left, middle, right)) (ctxHole mem' rest)
-- | Look up the type assumed for a term variable, or 'Nothing' if unbound.
--
-- The search takes the rightmost (most recently added) assumption. When a
-- binder is shadowed, as in @\\x -> \\x -> x@, the innermost binding
-- therefore wins. Previously such input reached an @error@ call (with a
-- message mislabelled "ctxSolution") and crashed 'runInfer'.
ctxAssump :: Ctx -> EVar -> Maybe Type
ctxAssump (Ctx ctx) x =
  case S.viewr (S.filter isAssump ctx) of
    _ S.:> CtxAssump _ t -> Just t
    _                    -> Nothing
  where
    isAssump (CtxAssump y _) = x == y
    isAssump _               = False
-- | Look up the solution recorded for an existential variable, if any.
--
-- More than one solution for the same variable is an internal invariant
-- violation and aborts with 'error'.
ctxSolution :: Ctx -> TEVar -> Maybe Type
ctxSolution (Ctx ctx) v = pick [m | m@(CtxSolved u _) <- toList ctx, u == v]
  where
    pick [CtxSolved _ t] = Just t
    pick []              = Nothing
    pick others          = error $ "ctxSolution: internal error — multiple solutions for variable: " ++ show others
-- | Truncate the context just before an entry, discarding that entry and
-- everything to its right. If the entry is absent, the context is returned
-- unchanged.
--
-- The cut is made at the rightmost (most recently added) occurrence. When
-- an inner scope's entry duplicates an outer one (e.g. nested @forall a.@
-- binders, or two lambdas binding the same variable at the same type),
-- leaving the inner scope then does not also throw away the outer scope and
-- everything added after it.
ctxUntil :: CtxMember -> Ctx -> Ctx
ctxUntil m (Ctx ctx) = Ctx $ case S.elemIndexR m ctx of
  Just i  -> S.take i ctx
  Nothing -> ctx
-- | Well-formedness of a type under a context (Γ ⊢ A).
--
-- Every universal variable must be bound by a 'CtxVar' entry or by an
-- enclosing 'TAll'. Every existential must be declared or solved in the
-- context. @(⊢)@ is an operator alias for 'typeWF'.
typeWF, (⊢) :: Ctx -> Type -> Either String ()
typeWF ctx ty = case ty of
  TUnit -> Right ()
  TVar v
    | CtxVar v `ctxElem` ctx -> Right ()
    | otherwise -> Left $ "unbound type variable ‘" ++ unTVar v ++ "’"
  TEVar v
    | CtxEVar v `ctxElem` ctx || isJust (ctxSolution ctx v) -> Right ()
    | otherwise -> Left $ "unbound existential variable ‘" ++ unTEVar v ++ "’"
  TArr dom cod -> typeWF ctx dom *> typeWF ctx cod
  TAll v body -> typeWF (ctx |> CtxVar v) body
(⊢) = typeWF
-- | The existential variables occurring in a type, in left-to-right order
-- with duplicates kept. Universal variables are not included.
freeVars :: Type -> [TEVar]
freeVars ty = collect ty []
  where
    collect TUnit acc      = acc
    collect (TVar _) acc   = acc
    collect (TEVar v) acc  = v : acc
    collect (TArr a b) acc = collect a (collect b acc)
    collect (TAll _ t) acc = collect t acc
-- | Apply a context as a substitution ([Γ]A): replace every solved
-- existential by its solution, recursively. Unsolved existentials are left
-- as they are.
applySubst :: Ctx -> Type -> Type
applySubst ctx = go
  where
    go ty = case ty of
      TEVar v      -> maybe ty go (ctxSolution ctx v)
      TArr dom cod -> TArr (go dom) (go cod)
      TAll v body  -> TAll v (go body)
      _            -> ty -- TUnit and TVar are unaffected
-- | @inst (v, s) t@ replaces the free occurrences of the universal variable
-- @v@ in @t@ with @s@ (the paper's [s/v]t).
--
-- A nested @TAll v _@ rebinds @v@, so substitution stops there. Without
-- this check, @inst (a, s) (∀a. a)@ would wrongly give @∀a. s@. Capture of
-- variables free in @s@ is not handled. Every call site in this module
-- substitutes a 'TEVar', which a 'TAll' cannot capture.
inst :: (TVar, Type) -> Type -> Type
inst _ TUnit = TUnit
inst (v, s) t@(TVar v')
  | v == v' = s
  | otherwise = t
inst _ t@(TEVar _) = t
inst sub (TArr a b) = TArr (inst sub a) (inst sub b)
inst sub@(v, _) t@(TAll v' body)
  -- The inner quantifier shadows v, so occurrences in body are not ours.
  | v == v' = t
  | otherwise = TAll v' (inst sub body)
--------------------------------------------------------------------------------
-- | State threaded through the checker.
data CheckState = CheckState
  { checkCtx      :: Ctx     -- ^ the current algorithmic context
  , checkNextEVar :: Integer -- ^ counter used to name the next fresh existential
  }
  deriving (Show, Eq)
-- | Initial checker state: an empty context, with fresh existential numbering
-- starting at 1.
defCheckState :: CheckState
defCheckState = CheckState { checkCtx = mempty, checkNextEVar = 1 }
-- | Read the current context.
getCtx :: CheckM Ctx
getCtx = checkCtx <$> get
-- | Replace the current context, leaving the rest of the state untouched.
putCtx :: Ctx -> CheckM ()
putCtx ctx = modify $ \st -> st { checkCtx = ctx }
-- | Update the current context with a pure function.
modifyCtx :: (Ctx -> Ctx) -> CheckM ()
modifyCtx f = do
  ctx <- getCtx
  putCtx (f ctx)
-- | Allocate a fresh existential variable named @a<n>@ and bump the counter.
-- The variable is not added to the context; callers do that themselves.
freshEVar :: CheckM TEVar
freshEVar = do
  n <- gets checkNextEVar
  modify (\st -> st { checkNextEVar = n + 1 })
  return (MkTEVar ("a" ++ show n))
-- | Check a type for well-formedness under the current context, raising the
-- 'typeWF' message as a checker error on failure.
checkTypeWF :: Type -> CheckM ()
checkTypeWF t = do
  ctx <- getCtx
  case typeWF ctx t of
    Left err -> throwError err
    Right () -> return ()
-- | The checker monad: 'String' errors layered over 'CheckState' state.
-- Its instances come from the underlying transformer stack via
-- GeneralizedNewtypeDeriving.
newtype CheckM a = CheckM (ExceptT String (State CheckState) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadError String
    , MonadState CheckState
    )
-- | Run a checker computation from 'defCheckState', discarding the final state.
runCheckM :: CheckM a -> Either String a
runCheckM (CheckM action) = runExceptT action `evalState` defCheckState
-- | @tySub a b@ checks that @a@ is a subtype of @b@ (Γ ⊢ A <: B ⊣ Δ).
-- The context is updated with any existential solutions found along the way.
-- Throws a "type mismatch" error when no rule applies.
tySub :: Type -> Type -> CheckM ()
tySub TUnit TUnit = return () -- <:Unit
tySub (TVar a) (TVar b) | a == b = return () -- <:Var
tySub (TEVar a) (TEVar b) | a == b = return () -- <:Exvar
tySub (TArr a1 a2) (TArr b1 b2) = do -- <:→
  tySub b1 a1 -- contravariant in the argument
  -- The first premise may have solved existentials occurring in the result
  -- types. Apply the new context ([Θ]A2 <: [Θ]B2) so an already-solved
  -- existential is not handed to instL/instR as if it were still unsolved.
  ctx <- getCtx
  tySub (applySubst ctx a2) (applySubst ctx b2)
-- <:∀R must be tried before <:∀L. If ∀L ran first on e.g.
-- ∀a. a -> a <: ∀b. b -> b, its fresh existential would sit to the left of b
-- in the context and could never be solved to b, rejecting a valid subtyping.
tySub a (TAll v b) = do -- <:∀R
  modifyCtx (|> CtxVar v)
  tySub a b
  modifyCtx (ctxUntil (CtxVar v))
tySub (TAll v a) b = do -- <:∀L
  â <- freshEVar
  let a' = inst (v, TEVar â) a
  modifyCtx (\c -> c |> CtxMarker â |> CtxEVar â)
  tySub a' b
  modifyCtx (ctxUntil (CtxMarker â))
-- <:InstantiateL / <:InstantiateR, guarded by the occurs check.
tySub (TEVar â) a | â `notElem` freeVars a = instL â a
tySub a (TEVar â) | â `notElem` freeVars a = instR a â
tySub a b = throwError $ "type mismatch: expected " ++ show b ++ ", given " ++ show a
-- | @instL â t@ instantiates the unsolved existential @â@ so that @â <: t@
-- (Γ[â] ⊢ â :=< A ⊣ Δ).
--
-- Throws a checker error, rather than calling 'error', when no rule applies.
-- An ill-typed program then makes 'runInfer' return 'Left' instead of
-- crashing.
instL :: TEVar -> Type -> CheckM ()
instL â t = getCtx >>= go
  where
    -- Defer to a helper function so we can pattern match/guard against the
    -- current context.
    go ctx -- InstLSolve
      | isMono t
      , Just (l, r) <- ctxHole (CtxEVar â) ctx
      , Right _ <- l ⊢ t
      = putCtx $ l |> CtxSolved â t <> r
    go ctx -- InstLReach
      | TEVar â' <- t
      , Just (l, m, r) <- ctxHole2 (CtxEVar â) (CtxEVar â') ctx
      = putCtx $ l |> CtxEVar â <> m |> CtxSolved â' (TEVar â) <> r
    go ctx -- InstLArr
      | Just (l, r) <- ctxHole (CtxEVar â) ctx
      , TArr a b <- t
      = do
          â1 <- freshEVar
          â2 <- freshEVar
          putCtx $ l |> CtxEVar â2 |> CtxEVar â1 |> CtxSolved â (TArr (TEVar â1) (TEVar â2)) <> r
          instR a â1
          ctx' <- getCtx
          instL â2 (applySubst ctx' b)
    go ctx -- InstLAllR
      | TAll b s <- t
      = do
          putCtx $ ctx |> CtxVar b
          instL â s
          -- Drop b and everything after it. Unlike a failable `Just .. <-`
          -- bind, this needs no MonadFail instance, which CheckM lacks
          -- (required since GHC 8.8).
          modifyCtx (ctxUntil (CtxVar b))
    go _ = throwError $ "instL: failed to instantiate " ++ show â ++ " to " ++ show t
-- | @instR t â@ instantiates the unsolved existential @â@ so that @t <: â@
-- (Γ[â] ⊢ A =:< â ⊣ Δ).
--
-- Throws a checker error, rather than calling 'error', when no rule applies.
-- An ill-typed program then makes 'runInfer' return 'Left' instead of
-- crashing.
instR :: Type -> TEVar -> CheckM ()
instR t â = getCtx >>= go
  where
    -- Defer to a helper function so we can pattern match/guard against the
    -- current context.
    go ctx -- InstRSolve
      | isMono t
      , Just (l, r) <- ctxHole (CtxEVar â) ctx
      , Right _ <- l ⊢ t
      = putCtx $ l |> CtxSolved â t <> r
    go ctx -- InstRReach
      | TEVar â' <- t
      , Just (l, m, r) <- ctxHole2 (CtxEVar â) (CtxEVar â') ctx
      = putCtx $ l |> CtxEVar â <> m |> CtxSolved â' (TEVar â) <> r
    go ctx -- InstRArr
      | Just (l, r) <- ctxHole (CtxEVar â) ctx
      , TArr a b <- t
      = do
          â1 <- freshEVar
          â2 <- freshEVar
          putCtx $ l |> CtxEVar â2 |> CtxEVar â1 |> CtxSolved â (TArr (TEVar â1) (TEVar â2)) <> r
          instL â1 a
          ctx' <- getCtx
          instR (applySubst ctx' b) â2
    go ctx -- InstRAllL
      | TAll b s <- t
      = do
          â' <- freshEVar
          putCtx $ ctx |> CtxMarker â' |> CtxEVar â'
          instR (inst (b, TEVar â') s) â
          -- Drop the marker and everything after it. Unlike a failable
          -- `Just .. <-` bind, this needs no MonadFail instance, which CheckM
          -- lacks (required since GHC 8.8).
          modifyCtx (ctxUntil (CtxMarker â'))
    go _ = throwError $ "instR: failed to instantiate " ++ show â ++ " to " ++ show t
-- | @check e a@ checks expression @e@ against type @a@ (Γ ⊢ e ⇐ A ⊣ Δ).
--
-- The rules are tried in order: unit, ∀-introduction, λ-introduction
-- against an arrow, and finally subsumption (infer, then subtype using the
-- current context's solutions).
check :: Expr -> Type -> CheckM ()
check expr ty = case (expr, ty) of
  (EUnit, TUnit)       -> return ()
  (_, TAll v body)     -> withEntry (CtxVar v) (check expr body)
  (ELam x e, TArr a b) -> withEntry (CtxAssump x a) (check e b)
  _ -> do
    inferred <- infer expr
    ctx <- getCtx
    tySub (applySubst ctx inferred) (applySubst ctx ty)
  where
    -- Run an action with an extra context entry, then truncate back to it.
    withEntry :: CtxMember -> CheckM () -> CheckM ()
    withEntry entry action = do
      modifyCtx (|> entry)
      action
      modifyCtx (ctxUntil entry)
-- | @infer e@ synthesizes a type for @e@ (Γ ⊢ e ⇒ A ⊣ Δ). The returned type
-- may mention existentials whose solutions live in the updated context.
infer :: Expr -> CheckM Type
infer expr = case expr of
  EUnit -> return TUnit
  EVar x -> do
    ctx <- getCtx
    case ctxAssump ctx x of
      Just t  -> return t
      Nothing -> throwError $ "unbound variable " ++ show x
  EAnn e a -> do
    checkTypeWF a
    check e a
    return a
  ELam x e -> do
    argTy <- freshEVar
    resTy <- freshEVar
    let assump = CtxAssump x (TEVar argTy)
    modifyCtx (\c -> c |> CtxEVar argTy |> CtxEVar resTy |> assump)
    check e (TEVar resTy)
    -- Leave the lambda's scope but keep the two existentials.
    modifyCtx (ctxUntil assump)
    return $ TArr (TEVar argTy) (TEVar resTy)
  EApp e1 e2 -> do
    funTy <- infer e1
    ctx <- getCtx
    inferApp (applySubst ctx funTy) e2
-- | @inferApp a e@ infers the result type of applying a function of type @a@
-- to the argument @e@ (Γ ⊢ A • e ⇒⇒ C ⊣ Δ).
inferApp :: Type -> Expr -> CheckM Type
inferApp (TAll v a) e = do -- ∀App
  â <- freshEVar
  modifyCtx (|> CtxEVar â)
  inferApp (inst (v, TEVar â) a) e
inferApp (TEVar â) e = do -- α̂App
  â1 <- freshEVar
  â2 <- freshEVar
  ctx <- getCtx
  -- Solve â in place: Γ[â] becomes Γ[â2, â1, â = â1 → â2]. Appending these
  -- entries at the end instead would leave the CtxEVar â entry in place next
  -- to a CtxSolved â entry, so â would look both unsolved and solved.
  case ctxHole (CtxEVar â) ctx of
    Just (l, r) ->
      putCtx $ l |> CtxEVar â2 |> CtxEVar â1 |> CtxSolved â (TArr (TEVar â1) (TEVar â2)) <> r
    Nothing ->
      throwError $ "inferApp: internal error — existential variable not in context: " ++ show â
  check e (TEVar â1)
  return $ TEVar â2
inferApp (TArr a c) e = check e a >> return c -- →App
inferApp t e = throwError $ "cannot apply expression of type " ++ show t ++ " to expression " ++ show e
-- | Infer the type of a closed expression, starting from an empty context.
--
-- The result has the final context's existential solutions substituted in
-- (the paper's [Δ]A). For example, @(\\x -> x) ()@ yields 'TUnit' rather
-- than an opaque existential. Existentials left unsolved remain as 'TEVar's.
runInfer :: Expr -> Either String Type
runInfer e = runCheckM $ do
  ty <- infer e
  ctx <- getCtx
  return (applySubst ctx ty)
@lazear
Copy link

lazear commented May 30, 2020

Just commenting to let you know that this was incredibly helpful to me. Even as a Haskell noob, I was able to follow along and cross-compare this with the paper to clear up some of the confusing aspects. Thanks!

@bts
Copy link

bts commented Jan 3, 2022

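For anyone following along at home, the TEVar case of inferApp should solve â in place (replacing its entry in the context) rather than appending new entries to the end of the context: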

inferApp (TEVar â) e = do
  â1 <- freshEVar
  â2 <- freshEVar
  ctx <- getCtx
  let Just (l, r) = ctxHole (CtxEVar â) ctx
  putCtx $ l |> CtxEVar â2 |> CtxEVar â1 |> CtxSolved â (TArr (TEVar â1) (TEVar â2)) <> r
  check e (TEVar â1)
  return $ TEVar â2

per this commit in https://github.com/lexi-lambda/higher-rank

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment