-- Created November 22, 2023 12:11
-- Pattern unification for a small lambda calculus with metavariables,
-- driven by a stream-based goal search.
module CatPu where
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), foldM, forM, forM_, when)
import Control.Monad.Except
import Control.Monad.State
import Data.List (elemIndex, intersect)
-- | A goal maps a solver state to the stream of states in which it holds;
-- the empty stream ('Nil') means the goal failed.
type Goal =
  SolverState -> Stream SolverState
-- | Terms, with variables as de Bruijn indices.
data Tm
  = Lam Tm
    -- ^ lambda abstraction
  | Var Int [Tm]
    -- ^ a variable index applied to a spine of arguments
  | Meta Int Ren
    -- ^ a metavariable (by id) applied to a renaming of variables
  deriving (Show, Eq)

-- | A renaming: the variable indices a metavariable is applied to.
type Ren = [Int]

-- | @Iot args@: the type of something taking arguments of the listed types
-- ('Iot' @[]@ is the base type).
data Ty = Iot [Ty]
  deriving (Show, Eq)

-- | Presumably a typing context for bound variables (unused in the solver
-- code visible here).
type Ctx = [Ty]

-- | Metavariable context: each meta id paired with its type.
type MCtx = [(Int, Ty)]

-- | Solutions found so far: meta id paired with its lambda-wrapped solution.
type Subs = [(Int, Tm)]

-- | Solver state: meta types, solutions, and the next fresh meta id.
type SolverState = (MCtx, Subs, Int)
-- | Reasons a unification problem is rejected.
data Blocker
  = Occurs
    -- ^ the metavariable occurs in the term it would be assigned
  | InversionFailed
    -- ^ the term mentions a variable the meta's renaming cannot reach
  | Different
    -- ^ rigid mismatch between the two terms
  deriving (Show)
-- | The unification monad: a 'SolverState' threaded through a computation
-- that may abort early with a 'Blocker'.
newtype Solver a
  = Solver (StateT SolverState (Except Blocker) a)
-- | Run a solver computation from an initial state, returning its result
-- and final state, or the 'Blocker' it stopped on (the state is then lost).
runSolver :: SolverState -> Solver a -> Either Blocker (a, SolverState)
runSolver st (Solver action) = runExcept $ runStateT action st
-- | Like 'runSolver', but keeps only the final state.
execSolver :: SolverState -> Solver a -> Either Blocker SolverState
execSolver st solver = snd <$> runSolver st solver
-- Monad-stack instances for 'Solver', derived from the wrapped
-- @StateT SolverState (Except Blocker)@ (requires StandaloneDeriving and
-- GeneralizedNewtypeDeriving).
deriving instance Functor Solver
deriving instance Applicative Solver
deriving instance Monad Solver
deriving instance MonadState SolverState Solver
deriving instance MonadError Blocker Solver
-- | The base type: takes no arguments.
iot :: Ty
iot = Iot []
-- | Resolve solved metavariables throughout a term.
--
-- Each solved meta is replaced by its recorded solution instantiated with
-- the meta's renaming ('mapp'), and the result is resolved again.
-- Unsolved metas are left in place.
walk :: Tm -> SolverState -> Tm
walk t0 (_, subs, _) = resolve t0
  where
    resolve :: Tm -> Tm
    resolve t = case t of
      Lam body -> Lam (resolve body)
      Var x args -> Var x (map resolve args)
      Meta i ren -> maybe t (\sol -> resolve (mapp sol ren)) (lookup i subs)
-- | Instantiate a lambda-wrapped metavariable solution with a renaming.
--
-- Solutions are built as @lams tys body@: one 'Lam' per argument, where
-- the first renaming entry corresponds to the outermost lambda (index
-- @n - 1@ inside an @n@-lambda body, as set up by 'assign').
--
-- @mapp u p@ strips @length p@ leading lambdas from @u@. It then replaces
-- every stripped binder by its renaming entry simultaneously, so
-- replacements never interfere with each other.
--
-- * Variables bound inside the body are left alone.
-- * Entries substituted under @d@ inner lambdas are shifted up by @d@.
-- * Variables free beyond the stripped binders are shifted down by
--   @length p@.
--
-- Lambdas beyond @length p@ are kept (partial application). Calls 'error'
-- if @p@ has more entries than @u@ has leading lambdas.
mapp :: Tm -> [Int] -> Tm
mapp u p = subst 0 (strip arity u)
  where
    arity = length p
    -- renaming entries indexed by stripped binder, 0 = innermost
    revP = reverse p

    -- remove the first n lambdas
    strip :: Int -> Tm -> Tm
    strip 0 t = t
    strip n (Lam t) = strip (n - 1) t
    strip _ _ = error "mapp: renaming is longer than the solution's lambda prefix"

    -- rename a variable index seen under d binders of the stripped body
    var :: Int -> Int -> Int
    var d i
      | i < d = i
      | j < arity = revP !! j + d
      | otherwise = i - arity
      where
        j = i - d

    subst :: Int -> Tm -> Tm
    subst d (Lam t) = Lam (subst (d + 1) t)
    subst d (Var i e) = Var (var d i) (map (subst d) e)
    subst d (Meta k q) = Meta k (map (var d) q)
-- | Goal asserting that two terms unify.
--
-- Both terms are resolved against the incoming state first. On success
-- the goal yields exactly one state (with any new solutions). If unification
-- is blocked, it yields no states; the 'Blocker' reason is discarded.
(===) :: Tm -> Tm -> Goal
(t1 === t2) st = either (const Nil) pure (execSolver st problem)
  where
    problem = unify (walk t1 st) (walk t2 st)
-- | Allocate a fresh metavariable of the given type.
--
-- Returns the new id together with a state that records the meta's type
-- and has the id counter advanced.
fresh' :: Ty -> SolverState -> (Int, SolverState)
fresh' ty (mctx, subs, next) = (next, st')
  where
    st' = ((next, ty) : mctx, subs, next + 1)
-- | The recorded type of metavariable @k@.
--
-- Calls 'error' (naming @k@) if the meta was never allocated, instead of
-- failing with an anonymous irrefutable-pattern error.
mlookup :: Int -> SolverState -> Ty
mlookup k (mctx, _, _) = case lookup k mctx of
  Just ty -> ty
  Nothing -> error ("mlookup: unknown metavariable " ++ show k)
-- | Record @u@ as the solution of metavariable @k@.
-- The new entry is prepended, so it is the one 'lookup' finds.
extS :: Int -> Tm -> Solver ()
extS k u = modify (\(mctx, subs, next) -> (mctx, (k, u) : subs, next))
-- | Wrap a term in one 'Lam' per list entry; only the list's length matters.
lams :: [Ty] -> Tm -> Tm
lams tys body = foldr (\_ inner -> Lam inner) body tys
-- | Unify two terms, recording metavariable solutions in the solver state.
--
-- The arguments are expected to be walked already (as '===' does).
-- Arguments of rigid applications are re-walked before recursing, because
-- earlier argument pairs may have solved metas that occur in later ones.
--
-- Throws 'Different' on a rigid mismatch. Flex-rigid problems go to
-- 'assign', which may throw 'Occurs' or 'InversionFailed'.
unify :: Tm -> Tm -> Solver ()
unify (Lam t) (Lam u) = unify t u
unify (Var x e) (Var y e')
  | x == y && length e == length e' =
      forM_ (zip e e') $ \(a, b) -> do
        a' <- gets (walk a)
        b' <- gets (walk b)
        unify a' b'
unify (Meta k p) (Meta k' p')
  -- identical flex terms: nothing to solve
  | k == k' && p == p' = pure ()
  -- same meta, different renamings: keep only the argument positions on
  -- which both renamings agree, via a fresh narrower meta.
  | k == k' = do
      Iot tys <- gets (mlookup k)
      let n = length tys
          kept =
            [ (ix, ty)
            | (ix, ty, agree) <- zip3 (reverse [0 .. n - 1]) tys (zipWith (==) p p')
            , agree
            ]
      j <- state (fresh' (Iot (map snd kept)))
      extS k (lams tys (Meta j (map fst kept)))
  -- different metas: both become one fresh meta over the variables
  -- visible in both renamings.
  | otherwise = do
      Iot tys <- gets (mlookup k)
      Iot tys' <- gets (mlookup k')
      let shared = intersect p p'
          -- types of the shared variables, read from k's argument types
          sharedTys = [ty | i <- shared, Just ty <- [lookup i (zip p tys)]]
          -- de Bruijn index, under the lambdas of a solution for a meta
          -- applied to renaming r, of the binder standing for variable i
          binderOf r i = lookup i (zip r (reverse [0 .. length r - 1]))
          binders r = [ix | i <- shared, Just ix <- [binderOf r i]]
      j <- state (fresh' (Iot sharedTys))
      extS k (lams tys (Meta j (binders p)))
      -- k' needs as many lambdas as its own arity (tys'), not k's
      extS k' (lams tys' (Meta j (binders p')))
unify (Meta k p) e = assign k p e
unify e (Meta k p) = assign k p e
unify _ _ = throwError Different
-- | Solve metavariable @k@, applied to renaming @p@, with term @u@.
--
-- 1. Fails with 'Occurs' if @k@ occurs in @u@.
-- 2. Inverts @p@: its first entry maps to the outermost binder (index
--    @length p - 1@) and its last entry to index 0.
-- 3. Rewrites @u@ over those binders with 'replace', which may prune
--    other metas or throw 'InversionFailed'.
-- 4. Records the result, wrapped in @k@'s lambdas, as @k@'s solution.
assign :: Int -> Ren -> Tm -> Solver ()
assign k p u = do
  Iot tys <- gets (mlookup k)
  when (occurs k u) $ throwError Occurs
  let inverse = zip p [length p - 1, length p - 2 .. 0]
  body <- replace inverse 0 u
  extS k (lams tys body)
-- | Rename one variable seen under @d@ lambdas using the inverse renaming @m@.
--
-- Locally bound variables (@i < d@) are kept. Outer variables are looked
-- up in @m@ and shifted back under the @d@ lambdas. A variable missing
-- from @m@ throws 'InversionFailed'.
replace_var :: [(Int,Int)] -> Int -> Int -> Solver Int
replace_var m d i
  | i < d = pure i
  | otherwise = maybe (throwError InversionFailed) (pure . (+ d)) (lookup (i - d) m)
-- | Rewrite a term through the inverse renaming @m@, under @d@ lambdas.
--
-- Variables are renamed with 'replace_var'. A metavariable whose renaming
-- mentions an unreachable variable is handled in three steps:
--
-- 1. Prune it to the reachable variables (local binders @0 .. d-1@ plus
--    @m@'s domain shifted by @d@).
-- 2. Walk the now-solved meta.
-- 3. Rewrite the result again.
replace :: [(Int,Int)] -> Int -> Tm -> Solver Tm
replace m d t = case t of
  Lam body -> Lam <$> replace m (d + 1) body
  Var x args -> Var <$> replace_var m d x <*> traverse (replace m d) args
  Meta k p ->
    catchError
      (Meta k <$> traverse (replace_var m d) p)
      (\_ -> pruneAndRetry k p)
  where
    allowed = [0 .. d - 1] ++ map ((+ d) . fst) m
    pruneAndRetry k p = do
      prune allowed k p
      t' <- gets (walk (Meta k p))
      replace m d t'
-- | Prune away inner variables by assigning a new metavariable.
--
-- Solves meta @k@ (applied to renaming @p@) with a fresh meta that takes
-- only the entries of @p@ that also appear in @allowed@, in the order they
-- appear in @allowed@.
prune :: [Int] -> Int -> Ren -> Solver ()
prune allowed k p = do
  Iot tys <- gets (mlookup k)
  let survivors = intersect allowed p
      -- position of a surviving variable in p (survivors come from p)
      posIn i = case elemIndex i p of
        Just q -> q
      newTy = Iot [tys !! posIn i | i <- survivors]
      -- binder index, under k's lambdas, of that argument position
      ren = [length p - posIn i - 1 | i <- survivors]
  j <- state (fresh' newTy)
  extS k (lams tys (Meta j ren))
-- | Does metavariable @i@ occur anywhere in the term?
-- Uses 'any', so the search stops at the first occurrence in a spine.
occurs :: Int -> Tm -> Bool
occurs i (Meta j _) = i == j
occurs i (Var _ e) = any (occurs i) e
occurs i (Lam u) = occurs i u
-- | Run a goal-building continuation with a fresh metavariable of type @ty@.
-- The continuation receives the meta as a function from a renaming to a term.
fresh :: Ty -> ((Ren -> Tm) -> Goal) -> Goal
fresh ty k st = k (Meta c) st'
  where
    (c, st') = fresh' ty st
-- | Goal disjunction: the answers of both goals, combined with 'mplus'.
disj :: Goal -> Goal -> Goal
disj g1 g2 = \st -> mplus (g1 st) (g2 st)
-- | Goal conjunction: run the second goal on every answer of the first.
conj :: Goal -> Goal -> Goal
conj g1 g2 = (>>= g2) . g1
-- | A lazy stream of answers. 'Delayed' marks a suspension point at which
-- '<|>' swaps to the other branch, interleaving the two.
data Stream a
  = Nil
    -- ^ no (more) answers
  | Cons a (Stream a)
    -- ^ an answer followed by the rest
  | Delayed (Stream a)
    -- ^ a suspended stream
  deriving (Eq, Show)
-- | Bind feeds every answer to the continuation and interleaves the
-- resulting streams with 'mplus'; suspensions are preserved.
instance Monad Stream where
  s >>= f = case s of
    Nil -> Nil
    Cons x rest -> f x `mplus` (rest >>= f)
    Delayed rest -> Delayed (rest >>= f)
-- | Uses the class defaults: @mzero = empty@ and @mplus = (<|>)@.
instance MonadPlus Stream
-- | Interleaving choice: after each answer or suspension of the left
-- stream, the operands swap so both branches make progress.
instance Alternative Stream where
  empty = Nil
  xs <|> ys = case xs of
    Nil -> ys
    Cons x rest -> Cons x (ys <|> rest)
    Delayed rest -> Delayed (ys <|> rest)
-- | Map over every answer, keeping the stream's shape and suspensions.
instance Functor Stream where
  fmap f = go
    where
      go Nil = Nil
      go (Cons a rest) = Cons (f a) (go rest)
      go (Delayed rest) = Delayed (go rest)
-- | @pure@ is a single answer. @<*>@ applies every function to every
-- argument, interleaving the results with '<|>'.
instance Applicative Stream where
  pure a = Cons a Nil
  fs <*> as = case (fs, as) of
    (Nil, _) -> Nil
    (_, Nil) -> Nil
    (Cons f rest, _) -> fmap f as <|> (rest <*> as)
    (Delayed rest, _) -> Delayed (rest <*> as)
-- | The goal with no answers.
failure :: Goal
failure = const empty
-- | Wrap a goal's answer stream in a 'Delayed' suspension.
delay :: Goal -> Goal
delay g = Delayed . g
-- | Empty solver state: no metas, no solutions, next id 0.
initialState :: SolverState
initialState = (mempty, mempty, 0)
-- | Collect up to @n@ answers from a stream, skipping 'Delayed' markers.
--
-- A non-positive @n@ yields no answers; previously a negative count never
-- reached the @0@ base case and consumed the whole stream. This does not
-- terminate on a stream that keeps producing 'Delayed' without answers.
takeS :: Int -> Stream a -> [a]
takeS n _ | n <= 0 = []
takeS _ Nil = []
takeS n (Delayed s) = takeS n s
takeS n (a `Cons` as) = a : takeS (n - 1) as
