Skip to content

Instantly share code, notes, and snippets.

@AndrasKovacs
Last active June 14, 2021 07:23
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AndrasKovacs/9c899ea04e6a24be51911f44b2691f7c to your computer and use it in GitHub Desktop.
Save AndrasKovacs/9c899ea04e6a24be51911f44b2691f7c to your computer and use it in GitHub Desktop.
simple universe polymorphism
{-# language LambdaCase, Strict, BangPatterns, ViewPatterns, OverloadedStrings #-}
{-# options_ghc -Wincomplete-patterns #-}
module UnivPoly where
import Data.Foldable
import Data.Maybe
import Data.String
import Debug.Trace
--------------------------------------------------------------------------------
-- | Surface-level variable names.
type Name = String
-- | Closure for a binder body: a Haskell function from the bound value to the
-- body's value (built by 'eval' as @\u -> eval (u:e) t@).
type Clos = Val -> Val
-- | Evaluation environment; the head is the most recently bound variable,
-- so a de Bruijn index is a position in this list.
type Env = [Val]
-- | Core types are core terms.
type Ty = Tm
-- | Raw (surface) types are raw terms.
type RTy = RTm
-- | Semantic types are values.
type VTy = Val
-- | De Bruijn index: counts binders outward from the use site.
type Ix = Int
-- | De Bruijn level: counts binders from the outermost one (used by values).
type Lvl = Int
-- | Raw terms: the named, unelaborated input syntax accepted by 'infer' and
-- 'check'.
data RTm
  = RVar Name              -- ^ variable, resolved by name in the context
  | RApp RTm RTm           -- ^ application
  | RLam Name RTm          -- ^ lambda; can only be checked, not inferred
  | RPi Name RTm RTm       -- ^ dependent function type @(x : A) -> B@
  | RLet Name RTy RTm RTm  -- ^ @let x : A = t in u@
  | RUFin RTm              -- ^ universe @U l@ at a finite level @l@
  | RFinLvl                -- ^ the type of finite levels
  | RL0                    -- ^ level zero
  | RLS RTm                -- ^ level successor
  | RLMax RTm RTm          -- ^ maximum of two finite levels
  deriving Show
-- | Universe levels in core syntax.
data Level
  = Fin Tm  -- ^ finite level: a term of type 'FinLvl'
  | Omega   -- ^ omega; 'FinLvl' itself is inferred to live in @U Omega@
  deriving Show
-- | Core terms produced by elaboration. Variables are de Bruijn indices
-- ('Ix'); binder names are carried along but never used for lookup.
data Tm
  = Var Ix             -- ^ variable (de Bruijn index)
  | App Tm Tm          -- ^ application
  | Lam Name Tm        -- ^ lambda
  | Pi Name Tm Tm      -- ^ dependent function type
  | Let Name Ty Tm Tm  -- ^ let: name, type annotation, definition, body
  | U Level            -- ^ universe at a finite level or at omega
  | FinLvl             -- ^ the type of finite levels
  | L0                 -- ^ level zero
  | LS Tm              -- ^ level successor
  | LMax Tm Tm         -- ^ maximum of two finite levels
  deriving Show
-- | Semantic universe levels (the result of evaluating a 'Level').
data VLevel
  = VFin Val  -- ^ finite level: a value of type 'VFinLvl'
  | VOmega    -- ^ omega
-- | Semantic values produced by 'eval'. Variables are de Bruijn levels
-- ('Lvl'); binders hold Haskell closures ('Clos').
data Val
  = VVar Lvl           -- ^ variable (de Bruijn level)
  | VApp Val ~Val      -- ^ application with a non-lambda head; the argument
                       --   field is lazy (@~@ overrides the module's Strict)
  | VLam Name Clos     -- ^ lambda
  | VPi Name VTy Clos  -- ^ dependent function type
  | VU VLevel          -- ^ universe
  | VFinLvl            -- ^ the type of finite levels
  | VL0                -- ^ level zero
  | VLS Val            -- ^ level successor
  | VLMax Val Val      -- ^ maximum that 'finmax' could not simplify
-- | Maximum of two finite level values, with light simplification:
-- @max 0 r = r@, @max l 0 = l@ and @max (S a) (S b) = S (max a b)@.
-- Every other combination is kept as a 'VLMax' node.
finmax :: Val -> Val -> Val
finmax VL0 r = r
finmax l VL0 = l
finmax (VLS a) (VLS b) = VLS (finmax a b)
finmax l r = VLMax l r
-- | Evaluate a core universe level in an environment. Omega maps directly to
-- 'VOmega'; a finite level is evaluated with 'eval'.
level :: Env -> Level -> VLevel
level e lv = case lv of
  Fin t -> VFin (eval e t)
  Omega -> VOmega
-- | Evaluate a core term to a value in the given environment.
--
-- Binders ('Lam', 'Pi') become closures that push their argument onto the
-- environment. An application whose head evaluates to a 'VLam' is
-- beta-reduced immediately; any other head produces a 'VApp'. 'LMax' goes
-- through 'finmax'. Variables are looked up with '!!', so the term is
-- assumed to be well-scoped in @rho@.
eval :: Env -> Tm -> Val
eval rho tm = case tm of
  Var i -> rho !! i
  App f a ->
    case (eval rho f, eval rho a) of
      (VLam _ body, arg) -> body arg
      (fun, arg)         -> VApp fun arg
  Lam x body  -> VLam x (\v -> eval (v : rho) body)
  Pi x a b    -> VPi x (eval rho a) (\v -> eval (v : rho) b)
  Let _ _ t u -> eval (eval rho t : rho) u
  U l         -> VU (level rho l)
  FinLvl      -> VFinLvl
  L0          -> VL0
  LS t        -> VLS (eval rho t)
  LMax t u    -> finmax (eval rho t) (eval rho u)
-- | Read a semantic universe level back into core syntax, in a context of
-- @l@ bound variables.
quoteLevel :: Lvl -> VLevel -> Level
quoteLevel l lv = case lv of
  VOmega -> Omega
  VFin v -> Fin (quote l v)
-- | Read a value back into a core term, in a context of @l@ bound variables.
-- Levels are converted to indices, and binders are opened by applying the
-- closure to the fresh variable @VVar l@.
quote :: Lvl -> Val -> Tm
quote l v = case v of
  VVar x    -> Var (l - x - 1)
  VApp f a  -> App (go f) (go a)
  VLam x c  -> Lam x (under c)
  VPi x a c -> Pi x (go a) (under c)
  VU lv     -> U (quoteLevel l lv)
  VFinLvl   -> FinLvl
  VL0       -> L0
  VLS t     -> LS (go t)
  VLMax t u -> LMax (go t) (go u)
  where
    -- quote a sub-value in the same context
    go = quote l
    -- quote a binder body one level deeper
    under c = quote (l + 1) (c (VVar l))
-- | Normal form of a closed core term: evaluate it in the empty environment,
-- then read the value back at level 0.
nf :: Tm -> Tm
nf t = quote 0 (eval [] t)
-- | Conversion check: are two values equal in a context of @l@ bound
-- variables?
--
-- Binders are compared by applying both closures to the fresh variable
-- @VVar l@. Lambdas are compared up to eta: a lambda against any other value
-- applies that value to the fresh variable. 'VLMax' nodes are compared
-- component-wise, without reordering their arguments.
conv :: Lvl -> Val -> Val -> Bool
conv l = go
  where
    fresh = VVar l
    deeper = conv (l + 1)

    -- The eta cases only match a VLam, so they are disjoint from the
    -- rigid cases below and can safely come first.
    go (VLam _ f) (VLam _ g)       = deeper (f fresh) (g fresh)
    go (VLam _ f) v                = deeper (f fresh) (VApp v fresh)
    go v          (VLam _ g)       = deeper (VApp v fresh) (g fresh)
    go (VVar x) (VVar y)           = x == y
    go (VApp f a) (VApp g b)       = go f g && go a b
    go (VPi _ a f) (VPi _ b g)     = go a b && deeper (f fresh) (g fresh)
    go VL0 VL0                     = True
    go (VLS a) (VLS b)             = go a b
    go (VLMax a b) (VLMax c d)     = go a c && go b d
    go (VU VOmega) (VU VOmega)     = True
    go (VU (VFin a)) (VU (VFin b)) = go a b
    go VFinLvl VFinLvl             = True
    go _ _                         = False
-- | Elaboration context.
data Cxt = Cxt
  { env   :: Env                   -- ^ values of bound variables, innermost first
  , types :: [(Name, (Lvl, VTy))]  -- ^ name |-> (binding level, type), innermost first
  , lvl   :: Lvl                   -- ^ number of bound variables; the next fresh level
  }
-- | Extend the context with a defined variable @x : a = t@. The irrefutable
-- pattern on @t@ keeps the value lazy despite the module-wide Strict pragma.
define :: Name -> VTy -> Val -> Cxt -> Cxt
define x a ~t cxt =
  Cxt { env   = t : env cxt
      , types = (x, (lvl cxt, a)) : types cxt
      , lvl   = lvl cxt + 1
      }
-- | Extend the context with an abstract variable of type @a@. Its value is
-- the fresh variable at the current level.
bind :: Name -> VTy -> Cxt -> Cxt
bind x a cxt = define x a (VVar (lvl cxt)) cxt
-- | Check a raw term against an expected type, returning the elaborated
-- core term.
--
-- * A lambda against a Pi type: bind the parameter at the Pi's domain, then
--   check the body against the codomain instantiated with the fresh variable.
-- * A let: elaborate the annotation, check the definition against it, then
--   check the body against the expected type. The body is checked in the
--   context extended with the definition.
-- * Anything else: infer its type and compare it to the expected one with
--   'conv'. Returns 'Left' with both types printed on mismatch.
check :: Cxt -> RTm -> VTy -> Either String Tm
check cxt t a = case (t, a) of
  (RLam x t, VPi _ a b) ->
    Lam x <$> check (bind x a cxt) t (b (VVar (lvl cxt)))

  (RLet x a t u, b) -> do
    (a, _) <- checkTy cxt a
    let ~va = eval (env cxt) a
    -- The definition is checked against its annotation in the current
    -- context...
    t <- check cxt t va
    -- ...and the body against the expected type, with x bound to the
    -- definition's value. (Previously the definition was checked twice and
    -- the fields of the resulting Let were swapped.)
    u <- check (define x va (eval (env cxt) t) cxt) u b
    pure (Let x a t u)

  (t, a) -> do
    (t, a') <- infer cxt t
    if conv (lvl cxt) a a'
      then pure t
      else Left $ "Type mismatch, expected\n\n "
        ++ show (quote (lvl cxt) a)
        ++ "\n\ninferred\n\n " ++ show (quote (lvl cxt) a')
-- | Elaborate a raw term that must be a type. Its inferred type must be a
-- universe; returns the term together with that universe's level.
checkTy :: Cxt -> RTm -> Either String (Tm, VLevel)
checkTy cxt t = infer cxt t >>= \(t', ty) -> case ty of
  VU l -> Right (t', l)
  _    -> Left "expected a type"
-- | Maximum of two universe levels. Two finite levels are combined with
-- 'finmax'; if either side is omega, the result is omega.
lmax :: VLevel -> VLevel -> VLevel
lmax l l' = case (l, l') of
  (VFin a, VFin b) -> VFin (finmax a b)
  _                -> VOmega
-- | Strengthen a universe level over a bound variable: remove the variable at
-- level @x@ from the level's context, as 'str' does for values. Omega
-- mentions no variables, so it always succeeds.
strLevel :: Lvl -> Lvl -> VLevel -> Either String Level
strLevel l x lv = case lv of
  VFin t -> fmap Fin (str l x t)
  VOmega -> Right Omega
-- | Strengthening. Read back the value @v@, which lives in a context of
-- @l + 1@ variables, as a term in the @l@-variable context obtained by
-- deleting the variable at level @x@. Fails with "illegal universe level
-- dependency" if @v@ mentions that variable.
--
-- Variables below @x@ keep their position. Variables above @x@ move down by
-- one, hence index @l - x'@ rather than @l - x' - 1@.
str :: Lvl -> Lvl -> Val -> Either String Tm
str l x = \case
  VVar x' -> case compare x' x of
    EQ -> Left "illegal universe level dependency"
    LT -> pure (Var (l - x' - 1))
    GT -> pure (Var (l - x'))
  VApp t u   -> App <$> str l x t <*> str l x u
  VLam x' t  -> Lam x' <$> str (l + 1) x (t fresh)
  VPi x' a b -> Pi x' <$> str l x a <*> str (l + 1) x (b fresh)
  VU t       -> U <$> strLevel l x t
  VFinLvl    -> pure FinLvl
  VL0        -> pure L0
  VLS t      -> LS <$> str l x t
  VLMax t u  -> LMax <$> str l x t <*> str l x u
  where
    -- The value still lives in the unstrengthened context of size l + 1, so
    -- a binder's fresh variable is at level l + 1. Using `VVar l` collides
    -- with the removed variable whenever x == l (as at the call in 'infer'),
    -- and would report a spurious dependency for any lambda/Pi in a level.
    fresh = VVar (l + 1)
-- | Infer the type of a raw term, returning the elaborated core term and its
-- type as a value.
--
-- A Pi type lives at the maximum of its domain's and codomain's levels. The
-- codomain's level is first strengthened over the bound variable, which fails
-- if it depends on that variable. If the domain is at omega, the Pi is at
-- omega. Lambdas cannot be inferred.
infer :: Cxt -> RTm -> Either String (Tm, VTy)
infer cxt raw = case raw of
  RVar x ->
    maybe (Left ("Name not in scope: " ++ x))
          (\(l, a) -> Right (Var (lvl cxt - l - 1), a))
          (lookup x (types cxt))

  RApp rf ra -> do
    (f, fty) <- infer cxt rf
    case fty of
      VPi _ dom cod -> do
        a <- check cxt ra dom
        Right (App f a, cod (eval (env cxt) a))
      _ -> Left "expected a function"

  RLam{} -> Left "can't infer type for lambda"

  RPi x ra rb -> do
    (a, al) <- checkTy cxt ra
    (b, bl) <- checkTy (bind x (eval (env cxt) a) cxt) rb
    case al of
      VOmega -> Right (Pi x a b, VU VOmega)
      VFin _ -> do
        bl' <- strLevel (lvl cxt) (lvl cxt) bl
        Right (Pi x a b, VU (lmax al (level (env cxt) bl')))

  RLet x ra rdef rbody -> do
    (a, _) <- checkTy cxt ra
    let ~va = eval (env cxt) a
    t <- check cxt rdef va
    (u, uty) <- infer (define x va (eval (env cxt) t) cxt) rbody
    Right (Let x a t u, uty)

  RUFin rl -> do
    l <- check cxt rl VFinLvl
    Right (U (Fin l), VU (VFin (VLS (eval (env cxt) l))))

  RFinLvl -> Right (FinLvl, VU VOmega)

  RL0 -> Right (L0, VFinLvl)

  RLS rl -> do
    l <- check cxt rl VFinLvl
    Right (LS l, VFinLvl)

  RLMax rl rr -> do
    l <- check cxt rl VFinLvl
    r <- check cxt rr VFinLvl
    Right (LMax l r, VFinLvl)
-- | Elaborate a closed raw term and print the outcome. On failure, prints the
-- error message. On success, prints the elaborated term, its normal form
-- and its type.
elab :: RTm -> IO ()
elab rt = either putStrLn report (infer (Cxt [] [] 0) rt)
  where
    report (t, a) = do
      putStrLn "---- term ----"
      print t
      putStrLn "---- nf ----"
      print (nf t)
      putStrLn "---- type ----"
      print (quote 0 a)
--------------------------------------------------------------------------------
-- | With OverloadedStrings, a string literal in an 'RTm' position denotes a
-- variable, so @"x"@ can be written instead of @RVar "x"@.
instance IsString RTm where
  fromString = RVar
-- | Infix application of raw terms; left-associative, and binds tighter than
-- '==>'.
infixl 7 $$
($$) :: RTm -> RTm -> RTm
f $$ a = RApp f a

-- | Non-dependent function type: @a ==> b@ is @RPi "_" a b@.
-- Right-associative.
infixr 4 ==>
(==>) :: RTm -> RTm -> RTm
a ==> b = RPi "_" a b
--------------------------------------------------------------------------------
-- | @(f : U0 -> FinLvl) -> (A : U0) -> U (f A) -> U0@
--
-- The level of the innermost function type mentions the bound variable @A@
-- (through @U (f A)@). Strengthening that level out of @A@'s scope therefore
-- fails, so elaborating this term is expected to report "illegal universe
-- level dependency".
illegal :: RTm
illegal =
  RPi "f" (RPi "_" (RUFin RL0) RFinLvl) $
    RPi "A" (RUFin RL0) $
      RPi "_" (RUFin (RApp (RVar "f") (RVar "A"))) (RUFin RL0)
-- | Example exercising let-bindings and level polymorphism; try @elab p1@.
--
-- @id : (l : FinLvl) -> (A : U l) -> A -> A@ is instantiated at level @S 0@,
-- then applied to the type @(A : U0) -> A -> A@ and to the polymorphic
-- identity function. @foo@ and @bar@ are defined but not used in the body.
p1 :: RTm
p1 =
  RLet "id" idTy (RLam "l" polyId) $
  RLet "foo" (idTy ==> RFinLvl) (RLam "f" RL0) $
  RLet "bar" (RUFin RL0 ==> RFinLvl) (RLam "A" RL0) $
  "id" $$ RLS RL0 $$ RPi "A" (RUFin RL0) ("A" ==> "A") $$ polyId
  where
    -- (l : FinLvl) -> (A : U l) -> A -> A
    idTy = RPi "l" RFinLvl (RPi "A" (RUFin "l") ("A" ==> "A"))
    -- \A x. x
    polyId = RLam "A" (RLam "x" "x")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment