Skip to content

Instantly share code, notes, and snippets.

@jozefg
Created August 3, 2017 13:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jozefg/92c73b378a71c0d6dc828b01b1d0d4a0 to your computer and use it in GitHub Desktop.
Save jozefg/92c73b378a71c0d6dc828b01b1d0d4a0 to your computer and use it in GitHub Desktop.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternGuards #-}
module Unification where
import Control.Monad
import Control.Monad.Gen
import Control.Monad.Trans
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.List (foldl')
import Data.Monoid
import qualified Data.Set as S
-- | Identifiers shared by free variables, metavariables and constants.
type Id = Int

-- | De Bruijn index of a bound variable; 0 refers to the innermost binder.
type Index = Int

-- | Terms of a small dependently typed lambda calculus.
--
-- The constructor order matters: the derived 'Ord' instance is used to
-- order constraint sets.
data Term
  = FreeVar Id Term   -- ^ free variable, annotated with its type
  | LocalVar Index    -- ^ bound variable (de Bruijn index)
  | MetaVar Id Term   -- ^ unification variable, annotated with its type
  | Constant Id Term  -- ^ named constant, annotated with its type
  | Uni               -- ^ the universe of types
  | Ap Term Term      -- ^ application: function, argument
  | Lam Term Term     -- ^ lambda: argument type, body
  | Pi Term Term      -- ^ dependent function type: domain, codomain
  deriving (Eq, Show, Ord)

-- | An equation between two terms that unification has to solve.
type Constraint = (Term, Term)

-- | What is known about a constant: opaque, or unfoldable to a term.
data Def
  = Axiom
  | Def Term

-- | Constant definitions, keyed by the constant's id.
type Env = M.Map Id Def
-- | @raise k t@ adds @k@ to every de Bruijn index of @t@ that is free in @t@,
-- i.e. that points past all binders inside @t@ itself. Use it when moving
-- a term underneath @k@ additional binders.
--
-- Indices bound inside @t@ are left unchanged. Free variables,
-- metavariables and constants are returned unchanged, and their type
-- annotations are not traversed.
raise :: Int -> Term -> Term
raise = go 0
  where
    -- @cutoff@ counts the binders crossed so far. An index below it refers
    -- to one of those binders and must not move. (The test is on the
    -- index @j@, not on the shift amount @k@.)
    go cutoff k t = case t of
      FreeVar x tp -> FreeVar x tp
      LocalVar j
        | j >= cutoff -> LocalVar (j + k)
        | otherwise -> LocalVar j
      MetaVar x tp -> MetaVar x tp
      Constant c tp -> Constant c tp
      Uni -> Uni
      Ap l r -> go cutoff k l `Ap` go cutoff k r
      Lam tp body -> Lam (go cutoff k tp) (go (cutoff + 1) k body)
      Pi tp body -> Pi (go cutoff k tp) (go (cutoff + 1) k body)
-- | @subst new i t@ replaces bound variable @i@ in @t@ with @new@ and
-- removes that binder. Indices above @i@ move down by one, because they
-- now have one fewer binder between them and their own binder.
--
-- Under each further binder the target index grows by one and @new@ is
-- shifted with 'raise', so that it keeps referring to the same variables.
-- With @i = 0@ this instantiates the body of a lambda or Pi, as in
-- beta-reduction.
--
-- Type annotations of free variables, metavariables and constants are not
-- traversed.
subst :: Term -> Int -> Term -> Term
subst new i t = case t of
  FreeVar x tp -> FreeVar x tp
  LocalVar j -> case compare j i of
    LT -> LocalVar j
    EQ -> new
    GT -> LocalVar (j - 1)
  MetaVar x tp -> MetaVar x tp
  Constant c tp -> Constant c tp
  Uni -> Uni
  Ap l r -> subst new i l `Ap` subst new i r
  Lam tp body -> Lam (subst new i tp) (subst (raise 1 new) (i + 1) body)
  Pi tp body -> Pi (subst new i tp) (subst (raise 1 new) (i + 1) body)
-- | @substMV new i t@ replaces every occurrence of metavariable @i@ in @t@
-- with @new@. It also rewrites inside the type annotations of free
-- variables and of other metavariables.
--
-- @new@ is shifted with 'raise' each time the traversal goes under a
-- binder. Constant type annotations are not traversed.
substMV :: Term -> Id -> Term -> Term
substMV new i t = case t of
  -- Bind the free variable's id to a fresh name: reusing @i@ here would
  -- shadow the metavariable being replaced.
  FreeVar x tp -> FreeVar x (substMV new i tp)
  LocalVar j -> LocalVar j
  MetaVar j tp
    | i == j -> new
    | otherwise -> MetaVar j (substMV new i tp)
  Constant c tp -> Constant c tp
  Uni -> Uni
  Ap l r -> substMV new i l `Ap` substMV new i r
  Lam tp body -> Lam (substMV new i tp) (substMV (raise 1 new) i body)
  Pi tp body -> Pi (substMV new i tp) (substMV (raise 1 new) i body)
-- | The ids of all metavariables occurring in a term.
--
-- This includes metavariables inside the type annotations of free
-- variables and metavariables. Constant type annotations are not
-- inspected.
metavars :: Term -> S.Set Id
metavars term = case term of
  MetaVar m tp -> S.insert m (metavars tp)
  FreeVar _ tp -> metavars tp
  Ap f a -> metavars f `S.union` metavars a
  Lam dom body -> metavars dom `S.union` metavars body
  Pi dom cod -> metavars dom `S.union` metavars cod
  -- LocalVar, Constant and Uni contribute nothing.
  _ -> S.empty
-- | Reduce a term to weak head normal form.
--
-- Constants with a 'Def' in the environment are unfolded, and beta-redexes
-- in head position are contracted. Nothing is reduced under a binder or
-- in argument position. Every other term, including axioms and unknown
-- constants, is returned as is.
reduce :: Env -> Term -> Term
reduce env = whnf
  where
    whnf term = case term of
      Constant c _
        | Just (Def body) <- M.lookup c env -> whnf body
      Ap f a -> case whnf f of
        Lam _ body -> whnf (subst a 0 body)
        f' -> Ap f' a
      _ -> term
-- | Computations that can draw fresh 'Id's and can fail (via 'Maybe').
type FreshM = GenT Id Maybe

-- | Infer the type of a closed term, together with the constraints that
-- must hold for it to be well typed.
--
-- Constraints whose two sides are already equal are dropped. The types of
-- free variables, metavariables and constants are read from their
-- annotations. The computation fails if a bound variable index does not
-- refer to an enclosing binder.
typeOf :: Term -> FreshM (Term, S.Set Constraint)
typeOf t = do
  (tp, cs) <- go [] t
  return (tp, S.filter (uncurry (/=)) cs)
  where
    -- @ctx@ holds the annotated types of the enclosing binders, innermost
    -- first.
    go ctx term = case term of
      FreeVar _ tp -> return (tp, S.empty)
      LocalVar j -> do
        guard (j >= 0)
        case drop j ctx of
          -- The stored type was written in the context outside binder j.
          -- Shift it past the j + 1 binders between there and here.
          tp : _ -> return (raise (j + 1) tp, S.empty)
          [] -> mzero
      MetaVar _ tp -> return (tp, S.empty)
      Constant _ tp -> return (tp, S.empty)
      Uni -> return (Uni, S.empty)
      Ap l r ->
        go ctx l >>= \case
          (Pi from to, cs1) -> do
            (from', cs2) <- go ctx r
            return (subst r 0 to, cs1 <> cs2 <> S.singleton (from, from'))
          (fTp, cs1) -> do
            -- The function's type is not yet known to be a Pi. Invent
            -- ?from : Uni and ?to : Pi ?from Uni, and require
            -- fTp = Pi ?from (?to x). The application then has type ?to r.
            (from', cs2) <- go ctx r
            from <- MetaVar <$> gen <*> return Uni
            to <- MetaVar <$> gen <*> return (Pi from Uni)
            return ( Ap to r
                   , cs1
                       <> cs2
                       <> S.fromList
                            [ (fTp, Pi from (Ap to (LocalVar 0)))
                            , (from, from')
                            ]
                   )
      Pi l r -> do
        (tp1, cs1) <- go ctx l
        (tp2, cs2) <- go (l : ctx) r
        return (Uni, cs1 <> cs2 <> S.fromList [(tp1, Uni), (tp2, Uni)])
      Lam tp body -> do
        (tp1, cs1) <- go ctx tp
        (to, cs2) <- go (tp : ctx) body
        return (Pi tp to, cs1 <> cs2 <> S.singleton (tp1, Uni))
-- | Whether a term is a constant. Arguments and annotations are ignored.
isRigid :: Term -> Bool
isRigid t = case t of
  Constant {} -> True
  _ -> False
-- | Whether a term is a metavariable, possibly applied to arguments.
-- Such a term cannot make progress until the metavariable is solved.
isStuck :: Term -> Bool
isStuck t = case t of
  MetaVar {} -> True
  Ap f _ -> isStuck f
  _ -> False
-- | Split an application spine @f a1 ... an@ into its head @f@ and its
-- arguments @[a1, ..., an]@, in application order. A non-application is
-- returned with an empty argument list.
peelApTelescope :: Term -> (Term, [Term])
peelApTelescope = peel []
  where
    peel args (Ap f a) = peel (a : args) f
    peel args hd = (hd, args)
-- | Rebuild an application spine: @applyApTelescope f [a1, ..., an]@ is
-- @f a1 ... an@. This is the inverse of 'peelApTelescope'.
applyApTelescope :: Term -> [Term] -> Term
applyApTelescope hd args = foldl' Ap hd args
-- | Build the non-dependent function type @Pi A1 (Pi A2 ... retTp)@ from a
-- return type and the argument types @[A1, A2, ...]@.
--
-- Everything to the right of each new binder is shifted by one with
-- 'raise', so it stays valid under that binder.
applyPiTelescope :: Term -> [Term] -> Term
applyPiTelescope retTp = foldr (\argTp inner -> Pi argTp (raise 1 inner)) retTp
-- | Try to view @t@ as a Pi type. On success, returns its domain, its
-- codomain (still under one binder) and the constraints collected along
-- the way.
--
-- There are three cases:
--
-- * If @t@ weak-head reduces to a 'Pi', its parts are returned together
--   with the constraints from typing @t@.
-- * If it reduces to an application spine headed by a metavariable, two
--   fresh metavariables are made for the domain and the codomain, each
--   applied to the spine's arguments. A constraint equating the resulting
--   Pi with the reduced term is emitted.
-- * Anything else fails.
--
-- NOTE(review): in this file it is referenced only from the commented-out
-- eta cases of 'simplify'.
assertPi :: Env -> Term -> FreshM (Term, Term, S.Set Constraint)
assertPi env t = do
  (tp, cs) <- typeOf t
  case reduce env t of
    Pi l r -> return (l, r, cs)
    t' -> case peelApTelescope t' of
      (MetaVar stuckMVar tp, cxt) -> do
        -- NOTE(review): the constraints @cs@ from @typeOf t@ are not part
        -- of this branch's result. Confirm this is intended.
        (cxtTps, css) <- unzip <$> mapM typeOf cxt
        let fromMVarTp = applyPiTelescope Uni cxtTps
        from <- MetaVar <$> gen <*> return fromMVarTp
        -- NOTE(review): the codomain metavariable's last argument type is
        -- the bare metavariable @from@, not @from@ applied to the spine's
        -- arguments. Confirm this is intended.
        let toMVarTp = applyPiTelescope Uni (cxtTps ++ [from])
        to <- MetaVar <$> gen <*> return toMVarTp
        let fromMVar = applyApTelescope from cxt
        let toMVar = applyApTelescope to cxt
        -- The codomain is @toMVar@ applied to the Pi's own bound variable,
        -- so it is shifted past that binder first.
        let cs' = S.singleton (Pi fromMVar (Ap (raise 1 toMVar) (LocalVar 0)), t')
        return (fromMVar, Ap (raise 1 toMVar) (LocalVar 0), cs' <> fold css)
      _ -> mzero
-- | Break one constraint into simpler constraints. The rules, tried in
-- order, are:
--
-- * Identical sides are discharged.
-- * A side that weak-head reduces is replaced by its reduct.
-- * Spines with the same rigid head (the same free variable or the same
--   constant) are split argument by argument. This fails if the spines
--   differ in length.
-- * Two lambdas, or two Pi types, are opened with one fresh free variable.
--   This yields a constraint on the bodies and one on the annotations.
-- * A constraint with a metavariable-headed side is kept unchanged.
-- * Anything else cannot be solved and fails.
simplify :: Env -> Constraint -> FreshM (S.Set Constraint)
simplify env (t1, t2)
  | t1 == t2 = return S.empty
  | t1' /= t1 = simplify env (t1', t2)
  | t2' /= t2 = simplify env (t1, t2')
  | (FreeVar i _, args1) <- peelApTelescope t1
  , (FreeVar j _, args2) <- peelApTelescope t2
  , i == j = decompose args1 args2
  | (Constant i _, args1) <- peelApTelescope t1
  , (Constant j _, args2) <- peelApTelescope t2
  , i == j = decompose args1 args2
  | Lam dom1 body1 <- t1
  , Lam dom2 body2 <- t2 = openBoth dom1 dom2 body1 body2
  | Pi dom1 cod1 <- t1
  , Pi dom2 cod2 <- t2 = openBoth dom1 dom2 cod1 cod2
  -- Eta-expansion cases, currently disabled:
  --   | Lam tp body <- t1 = do
  --       (from, to, cs) <- assertPi env t2
  --       cs' <- simplify env (t1, Lam from (raise 1 t2 `Ap` LocalVar 0))
  --       return (cs <> cs')
  --   | Lam tp body <- t2 = do
  --       (from, to, cs) <- assertPi env t1
  --       cs' <- simplify env (Lam from (raise 1 t1 `Ap` LocalVar 0), t2)
  --       return (cs <> cs')
  | isStuck t1 || isStuck t2 = return (S.singleton (t1, t2))
  | otherwise = mzero
  where
    t1' = reduce env t1
    t2' = reduce env t2
    -- Same rigid head: the spines must agree argument by argument.
    decompose args1 args2 = do
      guard (length args1 == length args2)
      fold <$> mapM (simplify env) (zip args1 args2)
    -- Open both binders with the same fresh free variable, typed by the
    -- left-hand annotation.
    openBoth dom1 dom2 b1 b2 = do
      v <- FreeVar <$> gen <*> return dom1
      return $ S.fromList [(subst v 0 b1, subst v 0 b2), (dom1, dom2)]
-- | A substitution: solutions for metavariables, keyed by metavariable id.
type Subst = M.Map Id Term

-- | Apply every binding of a substitution to a term, one binding at a time.
--
-- Bindings are applied from the largest id down to the smallest. A
-- metavariable introduced by a later-applied solution is not replaced
-- again.
manySubst :: Subst -> Term -> Term
manySubst s t = foldr (\(mv, sol) acc -> substMV sol mv acc) t (M.toAscList s)
-- | Combine two substitutions whose domains are disjoint. Overlapping
-- domains are treated as an internal error.
(<+>) :: Subst -> Subst -> Subst
s1 <+> s2
  | M.null (M.intersection s1 s2) = M.union s1 s2
  | otherwise = error "Impossible"
-- | Candidate solutions for a flex-rigid constraint.
--
-- One side must be an application spine headed by a metavariable @?m@
-- that does not occur on the other side. The result is the candidate
-- solutions for @?m@ plus the constraints gathered while typing @?m@'s
-- arguments.
--
-- Candidates are grouped by head arity: the n-th element of the
-- (infinite) list generates every candidate whose head is applied to n
-- arguments. If the spine's arguments have types A1 .. Ak, each candidate
-- has the shape
--
-- > ?m := \x1:A1 ... \xk:Ak. h (?H1 x1 .. xk) ... (?Hn x1 .. xk)
--
-- Here @h@ is either one of the bound @xi@ (projection) or the head of the
-- other side (imitation), and each @?Hi@ is a fresh metavariable.
--
-- Fails if neither side has a metavariable head that passes the occurs
-- check.
tryFlexRigid :: Constraint -> FreshM ([FreshM [Subst]], S.Set Constraint)
tryFlexRigid (t1, t2)
  | (MetaVar i _, cxt1) <- peelApTelescope t1
  , (stuckTerm, _) <- peelApTelescope t2
  , not (i `S.member` metavars t2) = candidatesFor i cxt1 stuckTerm
  | (MetaVar i _, cxt1) <- peelApTelescope t2
  , (stuckTerm, _) <- peelApTelescope t1
  , not (i `S.member` metavars t1) = candidatesFor i cxt1 stuckTerm
  | otherwise = mzero
  where
    -- Type the flex side's arguments, then enumerate candidates by head
    -- arity 0, 1, 2, ...
    candidatesFor mv cxt f = do
      (argTps, cs) <- fmap fold . unzip <$> mapM typeOf cxt
      return (map (generateSubst argTps mv f) [0 ..], cs)

    -- All candidate solutions whose head is applied to exactly @nargs@
    -- arguments.
    generateSubst argTps mv f nargs = do
      let nbound = length argTps
          -- The bound variables x1 .. xk, outermost first. Only indices
          -- 0 .. k-1 exist under k lambdas.
          boundVars = map LocalVar (reverse [0 .. nbound - 1])
          -- NOTE(review): argument types are used as lambda annotations
          -- without shifting. This is only right when they mention no
          -- bound variables.
          mkLam tm = foldr Lam tm argTps
      -- Make exactly @nargs@ fresh metavariables. Each has type
      -- Pi A1 ... Pi Ak. T, where T is itself a fresh metavariable, and is
      -- applied to all k bound variables.
      hs <- replicateM nargs $ do
        retTp <- MetaVar <$> gen <*> return Uni
        MetaVar <$> gen <*> return (applyPiTelescope retTp argTps)
      let args = map (`applyApTelescope` boundVars) hs
      -- NOTE(review): the imitated head @f@ is placed under the lambdas
      -- without shifting. This assumes it contains no bound variables.
      return [ M.singleton mv (mkLam (applyApTelescope hd args))
             | hd <- map LocalVar [0 .. nbound - 1] ++ [f]
             ]
-- | Run 'simplify' over every constraint in the set, and repeat until a
-- pass leaves the set unchanged.
repeatedlySimplify :: Env -> S.Set Constraint -> FreshM (S.Set Constraint)
repeatedlySimplify env = loop
  where
    loop current = do
      next <- S.unions <$> mapM (simplify env) (S.toList current)
      if next == current then return current else loop next
-- | Solve a set of constraints, extending the partial solution @s@.
--
-- Each step works as follows:
--
-- 1. Apply @s@ and simplify the constraints to a fixed point.
-- 2. If only flex-flex constraints (both sides stuck) remain, return @s@
--    together with them.
-- 3. Otherwise pick the largest remaining constraint, generate its
--    candidate solutions, and recurse with each one in turn. The first
--    success is kept.
unify :: Subst -> Env -> S.Set Constraint -> FreshM (Subst, S.Set Constraint)
unify s env cs = do
  simplified <- repeatedlySimplify env (S.map substBoth cs)
  let (stuckPairs, pending) = S.partition bothStuck simplified
  if S.null pending
    then return (s, stuckPairs)
    else do
      (candidates, extra) <- tryFlexRigid (S.findMax pending)
      search candidates (extra <> pending <> stuckPairs)
  where
    substBoth (l, r) = (manySubst s l, manySubst s r)
    bothStuck (l, r) = isStuck l && isStuck r
    -- Try each arity group of candidates in order. Within a group, try
    -- each candidate substitution in order.
    search [] _ = mzero
    search (mkCandidates : rest) remaining = do
      substs <- mkCandidates
      let viaThisGroup =
            foldr mplus mzero [unify (newS <+> s) env remaining | newS <- substs]
      viaThisGroup `mplus` search rest remaining
-- | Unify a single constraint, starting from an empty substitution and no
-- constant definitions.
--
-- Returns the solution together with any leftover flex-flex constraints,
-- or 'Nothing' if unification fails.
--
-- NOTE(review): fresh ids start from 'runGenT''s initial value. They are
-- not kept apart from ids already used in the input constraint. Confirm
-- that callers avoid collisions.
driver :: Constraint -> Maybe (Subst, S.Set Constraint)
driver c = runGenT (unify M.empty M.empty (S.singleton c))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment