Skip to content

Instantly share code, notes, and snippets.

@felko
Last active May 27, 2020 00:49
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save felko/90dfeecfd2795652b8902f6169285481 to your computer and use it in GitHub Desktop.
Save felko/90dfeecfd2795652b8902f6169285481 to your computer and use it in GitHub Desktop.
Linear lambda calculus typechecker
-- Build stanza for the typechecker; the whole program lives in src/Main.hs.
executable llc
  main-is:          Main.hs
  -- other-modules:
  -- uuid + MonadRandom supply the random ids attached to binder names;
  -- these / semialign / semialign-indexed provide the alignment used to
  -- compare the two contexts of an additive rule; pretty renders types,
  -- terms and judgements.
  build-depends:    base >=4.12 && <4.13
                  , containers >=0.6.2.1 && <0.7
                  , mtl >=2.2 && <2.3
                  , uuid
                  , MonadRandom
                  , these
                  , semialign
                  , semialign-indexed
                  , pretty
  hs-source-dirs:   src
  default-language: Haskell2010
{-# LANGUAGE
LambdaCase
, OverloadedLists
, OverloadedStrings
, RecordWildCards
, BlockArguments
, DeriveFunctor
, TypeApplications
, GeneralizedNewtypeDeriving
#-}
-- https://core.ac.uk/download/pdf/81933277.pdf
module Main where
import Control.Arrow ((>>>))
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Except
import Control.Monad.RWS
import Control.Monad.Random
import Data.Functor
import Data.Monoid
import Data.Maybe
import Data.Function
import Data.List (nub, intercalate)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as Set
import qualified Data.Map as Map
import Data.These
import Data.Semialign hiding (zip)
import Data.Semialign.Indexed
import Text.PrettyPrint.HughesPJClass hiding ((<>))
import Data.UUID hiding (null)
import Data.Coerce
import Debug.Trace
-- | A binder: the identifier as written in the source together with a random
-- 'UUID' that distinguishes different binders sharing the same spelling.
data Name = Name
  { display :: String  -- ^ identifier as it appears in the term
  , uid     :: UUID    -- ^ unique tag; the only field 'Eq' and 'Ord' inspect
  }
  deriving Show

-- | Two names are equal exactly when their 'uid's are; 'display' is ignored.
instance Eq Name where
  n1 == n2 = uid n1 == uid n2

-- | Names are ordered by 'uid' alone, consistently with 'Eq'.
instance Ord Name where
  compare n1 n2 = compare (uid n1) (uid n2)

-- | At pretty level 0 only the identifier is printed; at every other level
-- it is followed by @\@@ and the 'UUID'.
instance Pretty Name where
  pPrintPrec lvl _ Name{..} = case lvl of
    PrettyLevel 0 -> text display
    _             -> text display <> "@" <> text (show uid)
-- | Types of the linear calculus.
data Type
  = VarT String        -- ^ type variable
  | AppT Type Type     -- ^ type application
  | TensorT Type Type  -- ^ tensor, printed @A ⊗ B@
  | PlusT Type Type    -- ^ plus, printed @A ⊕ B@
  | WithT Type Type    -- ^ with, printed @A & B@
  | LolliT Type Type   -- ^ linear function, printed @A ⊸ B@
  | UnitT              -- ^ unit, printed @1@
  | OfCourseT Type     -- ^ of course, printed @!A@
  deriving (Eq, Show)

-- | Precedence-aware rendering. The binary connectives ⊗, ⊕ and & live at
-- level 3 and parenthesise any nested binary connective; ⊸ lives at level 1
-- and associates to the right.
instance Pretty Type where
  pPrintPrec l i ty = case ty of
      VarT n      -> text n
      UnitT       -> "1"
      OfCourseT a -> "!" <> pPrintPrec l 3 a
      TensorT a b -> connective "⊗" a b
      PlusT a b   -> connective "⊕" a b
      WithT a b   -> connective "&" a b
      LolliT a b  -> maybeParens (i >= 1) (pPrintPrec l 1 a <+> "⊸" <+> pPrintPrec l 0 b)
      AppT a b    -> maybeParens (i >= 2) (pPrintPrec l 1 a <+> pPrintPrec l 3 b)
    where
      -- shared layout of the three level-3 binary connectives
      connective sym a b = maybeParens (i >= 3) (pPrintPrec l 3 a <+> sym <+> pPrintPrec l 3 b)
-- | Surface terms. Eliminators name the variable they take apart, bind names
-- for its pieces, and continue with a body term.
data Term
  = Var String                        -- ^ @x@
  | Let String Term Term              -- ^ @let x = t, e@
  | Unit                              -- ^ @⋆@
  | Empty String Term                 -- ^ @empty x, e@
  | App Term Term                     -- ^ @f x@
  | Abs String Term                   -- ^ @λ x ⊸ e@
  | Pair Term Term                    -- ^ @⟨x, y⟩@
  | Choose String String Bool Term    -- ^ @choose x = p.fst, e@ ('False') or @p.snd@ ('True')
  | Tensor Term Term                  -- ^ @x ⊗ y@
  | Split String String String Term   -- ^ @split x ⊗ y = z, e@
  | Quote Term                        -- ^ @`e`@
  | Eval String String Term           -- ^ @eval x = u, e@
  | Copy String String String Term    -- ^ @copy (x, y) = u, e@
  | Ignore String Term                -- ^ @ignore x, e@
  | Inl Term                          -- ^ @<Left t>@
  | Inr Term                          -- ^ @<Right t>@
  | Case String String Term String Term
    -- ^ @case x of {Left l → p | Right r → q}@
  -- Not implemented yet:
  --   Cons String Term                       -- <Just 1>
  --   Case Term [(String, String, Term)]     -- case x of { <Failure err> e | <Success res> f }
  deriving Show
-- | Render terms in their concrete syntax.
--
-- Precedence levels: 0 is the top level; 1 is the function side of an
-- application or an operand of ⊗; 2 is the argument of an application.
-- Binding forms (let, λ, empty, choose, split, eval, copy, ignore, case) are
-- parenthesised at level 1 and above, and applications at level 2 and above.
instance Pretty Term where
  pPrintPrec lvl i = \case
    Var n -> text n
    Let x y e -> maybeParens (i >= 1) ("let " <> text x <> " = " <> pPrintPrec lvl 0 y <> ", " <> pPrintPrec lvl 0 e)
    Unit -> "⋆"
    Empty x e -> maybeParens (i >= 1) ("empty " <> text x <> ", " <> pPrintPrec lvl 0 e)
    Abs x e -> maybeParens (i >= 1) ("λ " <> text x <> " ⊸ " <> pPrintPrec lvl 0 e)
    -- The argument is printed at level 2 so that a nested application in
    -- argument position is parenthesised: App f (App g x) renders as
    -- "f (g x)" rather than the ambiguous "f g x".
    App f x -> maybeParens (i >= 2) (pPrintPrec lvl 1 f <+> pPrintPrec lvl 2 x)
    Pair x y -> "⟨" <> pPrintPrec lvl 0 x <> ", " <> pPrintPrec lvl 0 y <> "⟩"
    Choose x p False e -> maybeParens (i >= 1) ("choose " <> text x <> " = " <> text p <> ".fst, " <> pPrintPrec lvl 0 e)
    Choose y p True e -> maybeParens (i >= 1) ("choose " <> text y <> " = " <> text p <> ".snd, " <> pPrintPrec lvl 0 e)
    Tensor x y -> maybeParens (i >= 1) (pPrintPrec lvl 1 x <> " ⊗ " <> pPrintPrec lvl 1 y)
    Split x y z e -> maybeParens (i >= 1) ("split " <> text x <> " ⊗ " <> text y <> " = " <> text z <> ", " <> pPrintPrec lvl 0 e)
    Quote x -> "`" <> pPrintPrec lvl 0 x <> "`"
    Eval x u e -> maybeParens (i >= 1) ("eval " <> text x <> " = " <> text u <> ", " <> pPrintPrec lvl 0 e)
    Copy x y z e -> maybeParens (i >= 1) ("copy " <> parens (text x <> ", " <> text y) <> " = " <> text z <> ", " <> pPrintPrec lvl 0 e)
    Ignore x e -> maybeParens (i >= 1) ("ignore " <> text x <> ", " <> pPrintPrec lvl 0 e)
    Inl x -> "<Left " <> pPrintPrec lvl 2 x <> ">"
    Inr x -> "<Right " <> pPrintPrec lvl 2 x <> ">"
    -- "case b of {…}": the scrutinee is separated from the branches
    Case e x p y q -> maybeParens (i >= 1) ("case " <> text e <+> "of" <+> braces ("Left " <> text x <> " → " <> pPrintPrec lvl 0 p <> " | " <> "Right " <> text y <> " → " <> pPrintPrec lvl 0 q))
-- | Lexical scope: maps each source identifier to the 'Name' of the binder
-- currently in scope for it.
type Scope = Map.Map String Name

-- | A linear typing context: hypotheses and their types. The derived
-- 'Semigroup' is the underlying left-biased 'Map.union'.
newtype Context = Context
  { getCtx :: Map.Map Name Type }
  deriving (Show, Semigroup, Monoid)

-- | Comma-separated @name : type@ entries, in 'Name' order.
instance Pretty Context where
  pPrintPrec lvl _ (Context m) =
    cat (punctuate ", " [pPrintPrec lvl 0 n <+> ":" <+> pPrint t | (n, t) <- Map.toList m])

-- | Add the hypothesis @n : t@, replacing any existing entry for @n@.
introduce :: Name -> Type -> Context -> Context
introduce n t = Context . Map.insert n t . getCtx

-- | Drop the hypothesis for @n@, if there is one.
consume :: Name -> Context -> Context
consume n = Context . Map.delete n . getCtx
-- | Failures reported by the checker and the constraint solver.
data CheckError
  = ScopeError String             -- ^ identifier not in scope, or a hypothesis missing from a context
  | UnboundError String           -- ^ name left over in the context of the whole program
  | OverlapError Context          -- ^ hypotheses shared by two contexts that had to be disjoint
  | TypeError Type Type           -- ^ two types that could not be unified
  | UnusedError Context           -- ^ not currently raised anywhere in this module
  | OccursCheckError String Type  -- ^ a variable would be bound to a type containing it
  deriving Show

-- | State threaded through checking.
data CheckState = CheckState
  { tyVarSupply :: Int  -- ^ index of the next fresh @$n@ type variable
  }
  deriving Show

-- | A type equality, to be solved by unification.
data Constraint
  = Type :~ Type
  deriving Show

-- | @ctx :⊢ t@: using the hypotheses in @ctx@, a term has type @t@.
data Judgement = Context :⊢ Type
  deriving Show

-- | @Γ ⊢ A@.
instance Pretty Judgement where
  pPrintPrec lvl _ judgement = case judgement of
    ctx :⊢ t -> pPrintPrec lvl 0 ctx <+> "⊢" <+> pPrint t
-- | The checking monad.
--
-- * reader: the 'Scope' mapping source identifiers to binder 'Name's;
-- * writer: the 'Constraint's emitted so far;
-- * state: the fresh type-variable counter;
-- * 'RandT': random generator used to draw binder 'UUID's;
-- * 'Except': errors, wrapped in 'Last' around a non-empty list.
type Check a =
  RWST
    Scope
    [Constraint]
    CheckState
    (RandT StdGen (Except (Last (NonEmpty CheckError))))
    a
-- | Type recorded for hypothesis @n@; a 'ScopeError' naming it if the
-- context has no entry for it.
lookupCtx :: Name -> Context -> Check Type
lookupCtx n (Context ctx) = case Map.lookup n ctx of
  Just t  -> pure t
  Nothing -> checkError (ScopeError (display n))

-- | Multiplicative combination: the contexts must be disjoint, otherwise an
-- 'OverlapError' carries the shared hypotheses.
mergeCtx :: Context -> Context -> Check Context
mergeCtx (Context ctx) (Context ctx') =
  if Map.disjoint ctx ctx'
    then pure (Context (Map.union ctx ctx'))
    else checkError (OverlapError (Context (Map.intersection ctx ctx')))

-- | Additive combination: both contexts must contain the same hypotheses.
-- The two types of each shared hypothesis are constrained to be equal and the
-- left one is kept; a hypothesis present on one side only is a 'ScopeError'.
unifyCtx :: Context -> Context -> Check Context
unifyCtx (Context ctx) (Context ctx') = Context <$> sequenceA (ialignWith reconcile ctx ctx')
  where
    reconcile n = \case
      These a b -> a <$ require (a :~ b)
      This _    -> checkError (ScopeError (display n))
      That _    -> checkError (ScopeError (display n))

-- | Constrain every hypothesis in the context to have a type of the form @!A@
-- for some fresh @A@.
unrestrictedCtx :: Context -> Check ()
unrestrictedCtx (Context ctx) = forM_ ctx $ \t -> do
  t' <- freshTyVar
  require (t :~ OfCourseT t')
-- | Abort checking with a single error.
checkError :: CheckError -> Check a
checkError err = throwError (Last (Just (err :| [])))

-- | Emit an equality constraint for the solver.
require :: Constraint -> Check ()
require = tell . pure

-- | A fresh binder 'Name' for identifier @s@, tagged with a random 'UUID'.
unique :: String -> Check Name
unique s = do
  tag <- getRandom
  pure (Name s tag)

-- | Resolve identifier @s@ in the current 'Scope'; 'ScopeError' if absent.
bound :: String -> Check Name
bound s = do
  scope <- ask
  maybe (checkError (ScopeError s)) pure (Map.lookup s scope)

-- | A fresh type variable spelled @$i@, where @i@ is the current counter,
-- which is then incremented.
freshTyVar :: Check Type
freshTyVar = state $ \st ->
  let i = tyVarSupply st
  in (VarT ('$' : show i), st { tyVarSupply = i + 1 })
-- | Wrap a checker so that each judgement it produces is emitted with
-- 'trace' as @Γ ⊢ term : A@ and then returned unchanged.
debug :: (Term -> Check Judgement) -> Term -> Check Judgement
debug chk term = do
  result@(ctx :⊢ ty) <- chk term
  let line = render (pPrint ctx <+> "⊢" <+> pPrint term <+> ":" <+> pPrint ty)
  trace line (pure result)
-- | Infer the judgement @Γ ⊢ term : A@, bottom-up.
--
-- The returned context holds the hypotheses the term uses. Binders get fresh
-- 'Name's from 'unique' and extend the 'Scope' for their body; each variable
-- occurrence gets a fresh type variable; type equalities are emitted with
-- 'require' and solved later. Every judgement is traced by 'debug'.
--
-- Linearity:
--
-- * multiplicative parts are combined with 'mergeCtx' (disjoint contexts);
-- * additive branches are combined with 'unifyCtx' (identical contexts);
-- * rules that put the variable they take apart back into the context
--   (@empty@, @ignore@, @split@, @choose@, @case@, @copy@, @eval@) also go
--   through 'mergeCtx'. A body that already used that variable is therefore
--   rejected with 'OverlapError' instead of having its entry overwritten.
check :: Term -> Check Judgement
check = debug \case
    Var s -> do
      n <- bound s
      t <- freshTyVar
      pure (Context (Map.singleton n t) :⊢ t)
    Let x y e -> do
      xn <- unique x
      ctxy :⊢ a <- check y
      ctxe :⊢ t <- local (Map.insert x xn) (check e)
      a' <- lookupCtx xn ctxe
      require (a :~ a')
      ctx <- mergeCtx ctxy (consume xn ctxe)
      pure (ctx :⊢ t)
    Unit -> pure (mempty :⊢ UnitT)
    Empty x e -> do
      xn <- bound x
      ctx :⊢ t <- check e
      ctx' <- addHyp xn UnitT ctx
      pure (ctx' :⊢ t)
    Ignore x e -> do
      -- weakening: x is discarded, so it must have some type !A and must not
      -- also be used by the body
      xn <- bound x
      ctx :⊢ t <- check e
      a <- freshTyVar
      ctx' <- addHyp xn (OfCourseT a) ctx
      pure (ctx' :⊢ t)
    Tensor x y -> do
      ctx1 :⊢ a <- check x
      ctx2 :⊢ b <- check y
      ctx <- mergeCtx ctx1 ctx2
      pure (ctx :⊢ TensorT a b)
    Split x y z e -> do
      (xn, yn) <- (,) <$> unique x <*> unique y
      ctx :⊢ t <- local (Map.insert x xn . Map.insert y yn) (check e)
      (a, b) <- (,) <$> lookupCtx xn ctx <*> lookupCtx yn ctx
      zn <- bound z
      ctx' <- addHyp zn (TensorT a b) (consume xn (consume yn ctx))
      pure (ctx' :⊢ t)
    Abs x e -> do
      xn <- unique x
      ctx :⊢ b <- local (Map.insert x xn) (check e)
      a <- lookupCtx xn ctx
      pure (consume xn ctx :⊢ LolliT a b)
    App f x -> do
      (ctxf :⊢ tf, ctxx :⊢ tx) <- (,) <$> check f <*> check x
      ty <- freshTyVar
      require (tf :~ LolliT tx ty)
      ctx <- mergeCtx ctxf ctxx
      pure (ctx :⊢ ty)
    Pair x y -> do
      (ctxx :⊢ tx, ctxy :⊢ ty) <- (,) <$> check x <*> check y
      ctx <- unifyCtx ctxx ctxy
      pure (ctx :⊢ WithT tx ty)
    Choose y p False e -> do
      yn <- unique y
      pn <- bound p
      ctx :⊢ t <- local (Map.insert y yn) (check e)
      a <- lookupCtx yn ctx
      b <- freshTyVar
      ctx' <- addHyp pn (WithT a b) (consume yn ctx)
      pure (ctx' :⊢ t)
    Choose x p True e -> do
      xn <- unique x
      pn <- bound p
      ctx :⊢ t <- local (Map.insert x xn) (check e)
      a <- freshTyVar
      b <- lookupCtx xn ctx
      ctx' <- addHyp pn (WithT a b) (consume xn ctx)
      pure (ctx' :⊢ t)
    Inl x -> do
      ctx :⊢ a <- check x
      b <- freshTyVar
      pure (ctx :⊢ PlusT a b)
    Inr y -> do
      ctx :⊢ b <- check y
      a <- freshTyVar
      pure (ctx :⊢ PlusT a b)
    Case x l p r q -> do
      xn <- bound x
      (ln, rn) <- (,) <$> unique l <*> unique r
      ctxp :⊢ tp <- local (Map.insert l ln) (check p)
      ctxq :⊢ tq <- local (Map.insert r rn) (check q)
      require (tp :~ tq)
      (a, b) <- (,) <$> lookupCtx ln ctxp <*> lookupCtx rn ctxq
      ctx <- unifyCtx (consume ln ctxp) (consume rn ctxq)
      ctx' <- addHyp xn (PlusT a b) ctx
      pure (ctx' :⊢ tp)
    Copy x y z e -> do
      zn <- bound z
      (xn, yn) <- (,) <$> unique x <*> unique y
      ctx :⊢ t <- local (Map.insert x xn . Map.insert y yn) (check e)
      (ax, ay) <- (,) <$> lookupCtx xn ctx <*> lookupCtx yn ctx
      bangA <- OfCourseT <$> freshTyVar
      require (bangA :~ ax)
      require (bangA :~ ay)
      ctx' <- addHyp zn bangA (consume xn (consume yn ctx))
      pure (ctx' :⊢ t)
    Quote e -> do
      ctx :⊢ t <- check e
      unrestrictedCtx ctx
      pure (ctx :⊢ OfCourseT t)
    Eval x u e -> do
      xn <- unique x
      un <- bound u
      ctx :⊢ t <- local (Map.insert x xn) (check e)
      a <- lookupCtx xn ctx
      ctx' <- addHyp un (OfCourseT a) (consume xn ctx)
      pure (ctx' :⊢ t)
  where
    -- Add the hypothesis @n : ty@ to @ctx@. Goes through 'mergeCtx', so if
    -- @n@ is already used in @ctx@ (a second use of a linear variable) the
    -- result is an 'OverlapError' rather than a silent overwrite.
    addHyp :: Name -> Type -> Context -> Check Context
    addHyp n ty ctx = mergeCtx (Context (Map.singleton n ty)) ctx
-- | Run 'check' on a term with an empty scope, a zeroed type-variable
-- counter and a fresh random generator. Returns the unsolved type together
-- with the collected constraints, or the checking errors. If the final
-- context is not empty, its names are reported as 'UnboundError's.
runCheck :: Term -> ExceptT [CheckError] IO (Type, [Constraint])
runCheck term = do
  gen <- liftIO newStdGen
  let initialState = CheckState { tyVarSupply = 0 }
      -- unwrap the Last/NonEmpty error wrapper into a plain list
      unwrap (Identity r) = pure (either (Left . maybe [] NE.toList . getLast) Right r)
      checked = evalRandT (evalRWST (check term) mempty initialState) gen
  (Context ctx :⊢ typ, cs) <- mapExceptT unwrap checked
  unless (Map.null ctx) $
    throwError (UnboundError . display <$> Map.keys ctx)
  pure (typ, cs)
-- | A substitution from type-variable names to types.
newtype Subst = Subst
  { getSubst :: Map.Map String Type }
  deriving Show

-- | @{v ⇒ t, …}@, in variable-name order.
instance Pretty Subst where
  pPrint (Subst m) =
    braces (cat (punctuate ", " [text v <+> "⇒" <+> pPrint t | (v, t) <- Map.toList m]))

-- | Composition: substituting with @s1 <> s2@ is substituting with @s1@ and
-- then with @s2@. For a variable bound by both, the binding of @s1@ (with
-- @s2@ applied to it) wins.
instance Semigroup Subst where
  Subst m1 <> s2@(Subst m2) = Subst (Map.union (fmap (substitute s2) m1) m2)

-- | The empty substitution.
instance Monoid Subst where
  mempty = Subst Map.empty
-- | Type variables occurring in a type, without duplicates, in left-to-right
-- order of first occurrence.
freeTyVars :: Type -> [String]
freeTyVars ty = nub (collect ty [])
  where
    -- difference-list traversal: prepend the variables of @t@ to @acc@
    collect :: Type -> [String] -> [String]
    collect t acc = case t of
      VarT n      -> n : acc
      UnitT       -> acc
      OfCourseT a -> collect a acc
      TensorT a b -> collect a (collect b acc)
      LolliT a b  -> collect a (collect b acc)
      PlusT a b   -> collect a (collect b acc)
      WithT a b   -> collect a (collect b acc)
      AppT a b    -> collect a (collect b acc)

-- | Replace every bound variable of a type by its image in one pass (the
-- images themselves are not substituted again).
substitute :: Subst -> Type -> Type
substitute (Subst m) = go
  where
    go ty = case ty of
      VarT n      -> Map.findWithDefault ty n m
      UnitT       -> UnitT
      OfCourseT a -> OfCourseT (go a)
      TensorT a b -> TensorT (go a) (go b)
      LolliT a b  -> LolliT (go a) (go b)
      PlusT a b   -> PlusT (go a) (go b)
      WithT a b   -> WithT (go a) (go b)
      AppT a b    -> AppT (go a) (go b)

-- | The single binding @n ⇒ t@, or 'OccursCheckError' when @n@ occurs in @t@.
subst1 :: String -> Type -> Either CheckError Subst
subst1 n t
  | n `elem` freeTyVars t = Left (OccursCheckError n t)
  | otherwise             = Right (Subst (Map.singleton n t))
-- | Most general unifier of two types, or the first mismatch found.
--
-- * A variable unified with itself yields the empty substitution. Sending it
--   through 'subst1' would wrongly trip the occurs check.
-- * For binary connectives the left components are unified first, and their
--   substitution is applied to the right components before those are
--   unified. A variable constrained on both sides is therefore handled
--   consistently, instead of two independent answers being glued together.
unify :: Type -> Type -> Either CheckError Subst
unify = curry \case
    (VarT n, VarT m) | n == m -> pure mempty
    (VarT n, b) -> subst1 n b
    (a, VarT n) -> subst1 n a
    (TensorT a b, TensorT a' b') -> pairwise a b a' b'
    (LolliT a b, LolliT a' b') -> pairwise a b a' b'
    (PlusT a b, PlusT a' b') -> pairwise a b a' b'
    (WithT a b, WithT a' b') -> pairwise a b a' b'
    (AppT a b, AppT a' b') -> pairwise a b a' b'
    (OfCourseT a, OfCourseT b) -> unify a b
    (UnitT, UnitT) -> pure mempty
    (a, b) -> throwError (TypeError a b)
  where
    -- unify (a, b) with (a', b'), threading the first solution into the
    -- second problem; s1 <> s2 applies s1 and then s2
    pairwise a b a' b' = do
      s1 <- unify a a'
      s2 <- unify (substitute s1 b) (substitute s1 b')
      pure (s1 <> s2)
-- | Combine two substitutions. Variables bound by both are reconciled by
-- unifying their two images; those unifiers are composed after @s1 <> s2@.
-- Fails with the first unification error, in variable order.
mergeSubsts :: Subst -> Subst -> Either CheckError Subst
mergeSubsts s1@(Subst m1) s2@(Subst m2) = do
  unifiers <- sequence (Map.intersectionWith unify m1 m2)
  pure ((s1 <> s2) <> foldMap id unifiers)

-- | Solve constraints, collecting every error rather than stopping at the
-- first. Each constraint is unified on its own, and the results are merged
-- starting from the end of the list. When merging fails, the error is
-- recorded and the solution is still folded in by plain composition.
solve :: [Constraint] -> ([CheckError], Subst)
solve cs = foldr step ([], mempty) (map solveOne cs)
  where
    solveOne (a :~ b) = unify a b
    step solution (errs, acc) = case solution of
      Left err -> (err : errs, acc)
      Right s  -> case mergeSubsts s acc of
        Left err   -> (err : errs, acc <> s)
        Right acc' -> (errs, acc')
-- | Infer the type of a closed term. The steps are:
--
-- 1. check the term;
-- 2. solve the collected constraints (any solver errors are thrown
--    together);
-- 3. apply the solution;
-- 4. rename the remaining type variables, in order of first occurrence, to
--    @A@ … @Z@, then @A1@ … @Z1@, @A2@ …, so every variable gets a readable
--    name however many there are.
infer :: Term -> ExceptT [CheckError] IO Type
infer term = do
  (t, cs) <- runCheck term
  let (errs, subst) = solve cs
  -- liftIO (putStrLn (prettyShow t) >> print subst >> putStrLn (prettyShow (substitute subst t)))
  unless (null errs) $
    throwError errs
  let t' = substitute subst t
      tvs = freeTyVars t'
      -- unbounded supply: a 26-name list would make zip silently leave any
      -- further variables with their internal "$n" names
      prettyVars = [c : suffix | suffix <- "" : map show [1 :: Int ..], c <- ['A' .. 'Z']]
      subst' = Subst (Map.fromList (zip tvs (VarT <$> prettyVars)))
  pure (substitute subst' t')
-- | Reassociate a tensor: @(A ⊗ B) ⊗ C ⊸ A ⊗ (B ⊗ C)@.
tensorAssoc :: Term
tensorAssoc =
  Abs "xy_z" . Split "xy" "z" "xy_z" . Split "x" "y" "xy" $
    Tensor (Var "x") (Tensor (Var "y") (Var "z"))

-- | Use a sum of units to pick one side of an additive pair:
-- @A & A ⊸ 1 ⊕ 1 ⊸ A@.
boolIndex :: Term
boolIndex = Abs "p" (Abs "b" (Case "b" "u" (pick False) "u" (pick True)))
  where
    -- consume the unit payload u, then project one side out of p
    pick isSnd = Empty "u" (Choose "x" "p" isSnd (Var "x"))

-- | Split a reusable pair into two reusable components:
-- @!(A & B) ⊸ !A ⊗ !B@.
exponentialMap :: Term
exponentialMap =
  Abs "u" (Copy "x" "y" "u" (Tensor (project "p" "x" "a" False) (project "q" "y" "b" True)))
  where
    -- `eval pair = src, choose v = pair.(fst|snd), v`
    project pair src v isSnd = Quote (Eval pair src (Choose v pair isSnd (Var v)))
{-
A ⊗ (B ⅋ C) ⊸ ((A ⊗ B) ⅋ C)
= A ⊗ (~B ⊸ C) ⊸ (A ⊸ ~B) ⊸ C
= A ⊗ (B ⊸ C) ⊸ (A ⊸ B) ⊸ C
-}
-- | @A ⊗ (B ⊸ C) ⊸ (A ⊸ B) ⊸ C@.
linearDistribution :: Term
linearDistribution =
  Abs "af" . Abs "g" . Split "a" "f" "af" $
    App (Var "f") (App (Var "g") (Var "a"))

-- | Apply a reusable function twice: copy @f@, eval each copy, and apply
-- the two copies in sequence to @x@.
twice :: Term
twice =
  Abs "f" . Abs "x" . Copy "f1" "f2" "f" . Eval "f1e" "f1" . Eval "f2e" "f2" $
    App (Var "f2e") (App (Var "f1e") (Var "x"))
{-}
testLet = Abs "x" (Abs "f"
(Copy "f1" "f2" "f" (Eval "f1e" "f1" (Eval "f2e" "f2" (Let "y" (App (Var "f1e") (Var "x")) (App (Var "f2e") (Var "y")))))))
-}
-- | Infer a term's type and print it, or print each error on its own line.
test :: Term -> IO ()
test term = do
  outcome <- runExceptT (infer term)
  either (mapM_ print) (putStrLn . prettyShow) outcome

-- | Run 'test' on every example term, in order.
main :: IO ()
main = forM_ examples test
  -- print (mergeSubsts (Subst (Map.fromList [("$0", LolliT (VarT "$1") (VarT "$2"))])) (Subst (Map.fromList [("$0", LolliT UnitT (OfCourseT (VarT "$3")))])))
  where
    -- explicit signature: with OverloadedLists the literal is otherwise ambiguous
    examples :: [Term]
    examples = [tensorAssoc, boolIndex, exponentialMap, linearDistribution, twice]
{-
⊗-assoc : ∀ A B C. (A ⊗ B) ⊗ C ⊸ A ⊗ (B ⊗ C)
⊗-assoc = λ xy_z ⊸ split (xy, z) = xy_z,
split (x, y) = xy,
x ⊗ (y ⊗ z)
bool-index : ∀ A. A & A ⊸ 1 ⊕ 1 ⊸ A
bool-index = λ p ⊸ λ b ⊸ case b of
true -> choose x = p.fst, x
false -> choose y = p.snd, y
exponentialMap : ∀ A B. !(A & B) ⊸ !A ⊗ !B
exponentialMap = λ u ⊸
copy (x, y) = u, `eval p = x, choose a = p.fst, a`
⊗ `eval q = y, choose b = q.snd, b`
linearDistribution : ∀ A B C. A ⊗ (B ⊸ C) ⊸ (A ⊸ B) ⊸ C
linearDistribution = λ af ⊸ λ g ⊸ split (a, f) = af, f (g a)
-}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment