-- Type Inference for a small dependently typed language
-- GitHub gist christiaanb/4e0fd1777aee927999aa by @christiaanb
-- (last active August 29, 2015).
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable,
ScopedTypeVariables #-}
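{- |
Type inference for a small dependently typed language built with "Bound".
The @inferTypeN@ functions are variants of one algorithm that differ in how
they enter a binder and extend the typing context (see each one's comment);
'inferType3' is the variant documented as a non-solution.

Usage; the context @const undefined@ suffices because 'dollar' is closed:

>>> inferType1 (const undefined) dollar
>>> inferType2 (const undefined) dollar
>>> inferType3 (const undefined) dollar
>>> inferType4 (const undefined) id dollar
>>> inferType5 (const undefined) liftYoneda dollar
-}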
module DepLamInfer where
import Bound
import Bound.Name
import Bound.Var
import Control.Applicative
import Control.Comonad
import Control.Monad
import Data.Foldable
import Data.Functor.Yoneda
import Data.Traversable
import Prelude.Extras
-- | A binding form together with its bound variable: the 'Name' carries the
-- variable's name and its type annotation.
data Binder n a
  = Lam { binder :: Name n a } -- ^ Lambda abstraction
  | Pi  { binder :: Name n a } -- ^ Dependent function type
  deriving (Functor, Foldable, Traversable, Show, Eq)
-- | A binder holds exactly one annotation, so it is a comonad: 'extract'
-- returns the annotation, and 'extend' replaces it with the result of
-- applying the function to the whole binder.
instance Comonad (Binder n) where
  extract    = extract . binder
  extend f w = f w <$ w
-- | Terms of the language, with binder names of type @n@ and free variables
-- of type @a@. Types are terms too.
data Term n a
  = Var !a                      -- ^ Variable
  | Universe Integer            -- ^ Universe at the given level
  | App !(Term n a) !(Term n a) -- ^ Application
  | Bind !(Binder n (Term n a)) !(Scope (Name n ()) (Term n) a)
    -- ^ A binder (with its annotation) and its body, scoped over the
    -- bound variable
  deriving (Functor, Foldable, Traversable, Show, Eq)
-- | Empty instances: the methods come from the class defaults in
-- "Prelude.Extras", which reuse the derived 'Show' and 'Eq' instances of
-- 'Term'. Presumably needed by the 'Show' / 'Eq' instances of bound's
-- 'Scope', which 'Term' contains.
instance Show n => Show1 (Term n) where
instance Eq1 (Term n) where
-- | 'pure' embeds a variable; '<*>' is obtained from the monad via 'ap'.
instance Applicative (Term n) where
  pure      = Var
  mf <*> mx = ap mf mx
-- | The substitution monad: '>>=' replaces every free variable by a term
-- (see 'bindTerm').
instance Monad (Term n) where
  -- Canonical definition ('pure' is 'Var'); avoids the
  -- non-canonical-monad-instance warning on newer GHCs.
  return = pure
  (>>=)  = bindTerm
-- | Simultaneous substitution: replace each free variable @v@ with @f v@.
-- Inside a binder's body the substitution goes through bound's '>>>=',
-- which leaves the bound variable untouched.
bindTerm :: Term n a -> (a -> Term n b) -> Term n b
bindTerm term f = go term
  where
    go (Var v)      = f v
    go (Universe k) = Universe k
    go (App l r)    = App (go l) (go r)
    go (Bind b s)   = Bind (go <$> b) (s >>>= f)
-- | Infer the type of a term and require it to be a 'Pi' type. Returns the
-- binder's name, its domain (the bound variable's type) and the codomain,
-- still scoped over the bound variable. Calls 'error' if the inferred type
-- is anything else.
inferPi :: (Term n a -> Term n a) -- ^ Type inference function
        -> Term n a               -- ^ Term expected to have a function type
        -> (n, Term n a, Scope (Name n ()) (Term n) a)
inferPi inferTy tm = case inferTy tm of
  Bind (Pi b) s -> (name b, extract b, s)
  _             -> error "Function expected"
-- | Infer the type of a term and require it to be a universe, i.e. check
-- that the term is a type. Returns the universe's level; calls 'error' if
-- the inferred type is anything else.
inferUniverse :: (Term n a -> Term n a) -- ^ Type inference function
              -> Term n a               -- ^ Term expected to be a type
              -> Integer
inferUniverse inferTy ty = case inferTy ty of
  Universe i -> i
  _          -> error "Type expected"
-- | Type inference, variant 1: many traversals, using 'fromScope' to enter a
-- binder's body and 'toScope' to rebuild the result; both traverse the body.
--
-- The context maps each free variable to its type. Under a binder it is
-- extended: the bound variable gets the binder's annotation, and every type
-- is weakened with @fmap F@.
inferType1 :: Eq a
           => (a -> Term n a) -- ^ Context
           -> Term n a        -- ^ Term
           -> Term n a        -- ^ Inferred type
inferType1 ctx (Var a) = ctx a
inferType1 _ (Universe i) = Universe (i + 1)
inferType1 ctx (App e1 e2)
  | s == te   = instantiate1Name e2 t
  | otherwise = error "Mismatch"
  where
    te        = inferType1 ctx e2
    (_, s, t) = inferPi (inferType1 ctx) e1
inferType1 ctx (Bind bnd s) = case bnd of
    -- Pi formation: G |- A : Type k1 and G, x:A |- B : Type k2
    -- give G |- (x : A) -> B : Type (max k1 k2).
    Pi _  -> Universe (max k1 k2)
    Lam _ -> Bind (Pi b) (toScope bodyTy)
  where
    b      = binder bnd
    t      = extract b
    body   = fromScope s
    k1     = inferUniverse (inferType1 ctx) t
    -- The body is checked in the extended context. Substituting t for the
    -- bound variable instead would treat the variable as its own type.
    k2     = inferUniverse (inferType1 ctx') body
    bodyTy = inferType1 ctx' body
    ctx'   = unvar (const (fmap F t)) (fmap F . ctx)
-- | Type inference, variant 2: only traverse the new context.
--
-- A binder's body is entered with 'unscope' and its type is rebuilt with the
-- 'Scope' constructor, so the body itself is not traversed. Under the binder
-- the free variables are whole terms. The extended context gives the bound
-- variable its annotation and infers the type of any other variable's term,
-- weakening both with @fmap (F . Var)@, which traverses each type looked up.
inferType2 :: Eq a
           => (a -> Term n a) -- ^ Context
           -> Term n a        -- ^ Term
           -> Term n a        -- ^ Inferred type
inferType2 ctx (Var a) = ctx a
inferType2 _ (Universe i) = Universe (i + 1)
inferType2 ctx (App e1 e2)
  | s == te   = instantiate1Name e2 t
  | otherwise = error "Mismatch"
  where
    te        = inferType2 ctx e2
    (_, s, t) = inferPi (inferType2 ctx) e1
inferType2 ctx (Bind bnd s) = case bnd of
    -- Pi formation: G |- A : Type k1 and G, x:A |- B : Type k2
    -- give G |- (x : A) -> B : Type (max k1 k2).
    Pi _  -> Universe (max k1 k2)
    Lam _ -> Bind (Pi b) (Scope bodyTy)
  where
    b      = binder bnd
    t      = extract b
    body   = unscope s
    k1     = inferUniverse (inferType2 ctx) t
    -- The body is checked in the extended context. Substituting t for the
    -- bound variable instead would treat the variable as its own type.
    k2     = inferUniverse (inferType2 ctx') body
    bodyTy = inferType2 ctx' body
    ctx'   = unvar (const (fmap (F . Var) t))
                   (fmap (F . Var) . inferType2 ctx)
-- | Type inference, variant 3: no traversals in the context (a
-- non-solution).
--
-- Like 'inferType2', a binder's body is entered with 'unscope' and rebuilt
-- with 'Scope'. However, the context is extended with @Var . F@, so no type
-- in the context is ever traversed. The catch: under a binder, a looked-up
-- type comes back wrapped as @Var (F ty)@ rather than as @ty@. It then no
-- longer has the 'Bind' / 'Universe' shape that 'inferPi' and
-- 'inferUniverse' match on, and the structural '==' used for the argument
-- check may reject types that denote the same thing.
--
-- NOTE(review): the 'Pi' rule checks the body after substituting the domain
-- for the bound variable, treating the variable as if it were its own type.
-- Checking the body in the context extended with @Var (F dom)@ instead would
-- make a bound variable's type come back as @Var (F dom)@, which
-- 'inferUniverse' rejects; so the rule is left as is in this variant.
inferType3 :: Eq a
           => (a -> Term n a) -- ^ Context
           -> Term n a        -- ^ Term
           -> Term n a        -- ^ Inferred type
inferType3 ctx term = case term of
  Var v -> ctx v
  Universe k -> Universe (k + 1)
  App fun arg ->
    let argTy         = inferType3 ctx arg
        (_, dom, cod) = inferPi (inferType3 ctx) fun
    in if dom == argTy then instantiate1Name arg cod else error "Mismatch"
  Bind (Pi bnd) body ->
    let dom = extract bnd
        lvl = inferUniverse (inferType3 ctx)
    in Universe (max (lvl dom) (lvl (instantiate1Name dom body)))
  Bind (Lam bnd) body ->
    let extended = unvar (\_ -> Var (F (extract bnd)))
                         (Var . F . inferType3 ctx)
    in Bind (Pi bnd) (Scope (inferType3 extended (unscope body)))
-- | Type inference, variant 4: delay the traversal of the context (no better
-- than 'inferType2').
--
-- A binder extends the context with @Var . F@ and its body is entered with
-- 'unscope', as in 'inferType3'. Every type that comes out of the context is
-- passed through @dist@. Under a binder, @dist@ turns the wrapped form
-- @Var (F ty)@ back into a plain term with @fmap (F . Var)@, after first
-- applying the enclosing level's @dist@ to @ty@. A top-level call passes
-- 'id'.
inferType4 :: Eq a
           => (a -> Term n a)        -- ^ Context
           -> (Term n a -> Term n a) -- ^ Quotient and distribute (Var . F)
           -> Term n a               -- ^ Term
           -> Term n a               -- ^ Inferred type
inferType4 ctx dist (Var a) = dist (ctx a)
inferType4 _ _ (Universe i) = Universe (i + 1)
inferType4 ctx dist (App e1 e2)
  | s == te   = instantiate1Name e2 t
  | otherwise = error "Mismatch"
  where
    te        = inferType4 ctx dist e2
    (_, s, t) = inferPi (inferType4 ctx dist) e1
inferType4 ctx dist (Bind bnd s) = case bnd of
    -- Pi formation: G |- A : Type k1 and G, x:A |- B : Type k2
    -- give G |- (x : A) -> B : Type (max k1 k2).
    Pi _  -> Universe (max k1 k2)
    Lam _ -> Bind (Pi b) (Scope bodyTy)
  where
    b      = binder bnd
    t      = extract b
    body   = unscope s
    k1     = inferUniverse (inferType4 ctx dist) t
    -- The body is checked in the extended context. Substituting t for the
    -- bound variable instead would treat the variable as its own type.
    k2     = inferUniverse (inferType4 ctx' dist') body
    bodyTy = inferType4 ctx' dist' body
    ctx'   = unvar (const (Var (F t))) (Var . F . inferType4 ctx dist)
    dist' (Var (F ty)) = fmap (F . Var) (dist ty)
    dist' e            = e
-- | A 'Term' wrapped in 'Yoneda', so that successive 'fmap's compose into a
-- single function instead of each one rebuilding the term.
type YTerm n a = Yoneda (Term n) a

-- | Type inference, variant 5: delay the traversal of the context, as in
-- 'inferType4', but @dist@ returns a 'YTerm'. The intent is for 'Yoneda' to
-- merge the @fmap (F . Var)@ added by each enclosing binder, so that
-- 'lowerYoneda' applies them in one pass when a variable's type is looked
-- up. This is hopefully better than 'inferType2'. At the top level,
-- 'liftYoneda' plays the role that 'id' plays for 'inferType4'.
inferType5 :: forall a n . Eq a
           => (a -> Term n a)         -- ^ Context
           -> (Term n a -> YTerm n a) -- ^ Quotient and distribute (Var . F)
           -> Term n a                -- ^ Term
           -> Term n a                -- ^ Inferred type
inferType5 ctx dist (Var a) = lowerYoneda . dist . ctx $ a
inferType5 _ _ (Universe i) = Universe (i + 1)
inferType5 ctx dist (App e1 e2)
  | s == te   = instantiate1Name e2 t
  | otherwise = error "Mismatch"
  where
    te        = inferType5 ctx dist e2
    (_, s, t) = inferPi (inferType5 ctx dist) e1
inferType5 ctx dist (Bind bnd s) = case bnd of
    -- Pi formation: G |- A : Type k1 and G, x:A |- B : Type k2
    -- give G |- (x : A) -> B : Type (max k1 k2).
    Pi _  -> Universe (max k1 k2)
    Lam _ -> Bind (Pi b) (Scope bodyTy)
  where
    b      = binder bnd
    t      = extract b
    body   = unscope s
    k1     = inferUniverse (inferType5 ctx dist) t
    -- The body is checked in the extended context. Substituting t for the
    -- bound variable instead would treat the variable as its own type.
    k2     = inferUniverse (inferType5 ctx' dist') body
    bodyTy = inferType5 ctx' dist' body
    ctx'   = unvar (const (Var (F t))) (Var . F . inferType5 ctx dist)
    dist' :: Term n (Var (Name n ()) (Term n a))
          -> YTerm n (Var (Name n ()) (Term n a))
    dist' (Var (F ty)) = fmap (F . Var) (dist ty)
    dist' e            = liftYoneda e
-- Example terms, using strings both for binder names and for variables.

-- | Names of variables and binders in the examples.
type LVar = String

-- | Example terms: binder names and free variables are both 'LVar'.
type CoreTerm = Term LVar LVar

-- | The universe at level 0.
uni :: CoreTerm
uni = Universe 0
-- | @lam (v, ty) e@ abstracts @e@ over the variable @v@, whose type
-- annotation is @ty@.
lam :: (LVar,CoreTerm) -> CoreTerm -> CoreTerm
lam (v, ty) body = Bind (Lam annotated) scoped
  where
    annotated = Name v ty
    scoped    = abstract1Name v body
-- | @pi_ (v, ty) e@ is the dependent function type @(v : ty) -> e@; the
-- variable @v@ may occur in @e@.
pi_ :: (LVar,CoreTerm) -> CoreTerm -> CoreTerm
pi_ (v, ty) body = Bind (Pi annotated) scoped
  where
    annotated = Name v ty
    scoped    = abstract1Name v body
-- | The polymorphic application combinator: takes types @a@ and @b@, a
-- function @f : a -> b@ and an argument @x : a@, and returns @f x@.
dollar :: CoreTerm
dollar =
  lam ("a", uni) .
  lam ("b", uni) .
  lam ("f", pi_ ("_", Var "a") (Var "b")) .
  lam ("x", Var "a") $
    App (Var "f") (Var "x")
-- End of gist.