Skip to content

Instantly share code, notes, and snippets.

@Heimdell
Created July 16, 2022 21:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Heimdell/1fbad9c325253928c9ca8c6342847083 to your computer and use it in GitHub Desktop.
Save Heimdell/1fbad9c325253928c9ca8c6342847083 to your computer and use it in GitHub Desktop.
Unification ex-di
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Fixpoint where
import Control.Monad ((<=<))
import Data.Data (Data, Typeable)
import GHC.Generics (Generic)
import Data.Foldable (fold)
-- | Fixpoint of a functor @f@: a recursive structure in which every layer is
-- an @f@ whose children are again @Fix f@ values.
--
-- Unlike 'Term', a @Fix f@ has no extra leaves (no variables / holes); the
-- only way to terminate recursion is a layer of @f@ with no children.
newtype Fix f = Fix
  { unFix :: f (Fix f)
    -- ^ Peel off the outermost layer.
  }
  deriving stock (Generic)

-- Standalone deriving lets the instance contexts mention @f (Fix f)@
-- explicitly (this relies on UndecidableInstances, enabled above).
deriving stock instance (Show (f (Fix f))) => Show (Fix f)
deriving stock instance (Eq (f (Fix f))) => Eq (Fix f)
deriving stock instance (Ord (f (Fix f))) => Ord (Fix f)
deriving stock instance (Data (f (Fix f)), Typeable f) => Data (Fix f)
-- | Catamorphism for 'Fix': collapse the structure bottom-up, applying
-- @alg@ to each layer after all of its children have been collapsed.
cataFix :: (Functor f) => (f a -> a) -> Fix f -> a
cataFix alg (Fix layer) = alg (fmap (cataFix alg) layer)
-- | Monadic catamorphism for 'Fix'. The children of a layer are folded
-- first, in the order given by @f@'s 'Traversable' instance. @alg@ then runs
-- on the layer of results.
cataFixM :: (Traversable f, Monad m) => (f a -> m a) -> Fix f -> m a
cataFixM alg (Fix layer) = traverse (cataFixM alg) layer >>= alg
-- | A variant of 'Fix' in which some positions hold a leaf of type @a@
-- instead of another layer of @f@.
--
-- Structurally this is the free monad over @f@, with 'Leaf' playing the role
-- of @pure@. Note that this module defines no 'Monad' instance for it; the
-- deriving clause only provides 'Generic', 'Functor', 'Foldable' and
-- 'Traversable'.
data Term f a
  = Node (f (Term f a)) -- ^ One layer of structure.
  | Leaf a              -- ^ A leaf / hole of type @a@.
  deriving stock (Generic, Functor, Foldable, Traversable)
-- | A 'String' whose 'Show' instance emits it verbatim (no quotes, no
-- escaping). It is used to splice already-rendered subterms into a layer.
newtype Unshow = Unshow { unShow :: String }

instance Show Unshow where
  show (Unshow rendered) = rendered
-- | Render a 'Term' bottom-up.
--
-- Each leaf is shown with its own 'Show' instance. Each layer is shown via
-- @Show (f Unshow)@ after its children have already been rendered into
-- 'Unshow' strings. This avoids requiring @Show (f (Term f a))@.
instance (Show a, Show (f Unshow), Functor f) => Show (Term f a) where
  show term = unShow (cataTerm renderLayer renderLeaf term)
    where
      renderLayer layer = Unshow (show layer)
      renderLeaf leaf = Unshow (show leaf)
-- Standalone-derived so that the contexts can mention @f (Term f a)@
-- directly.
deriving stock instance (Eq a, Eq (f (Term f a))) => Eq (Term f a)
deriving stock instance (Ord a, Ord (f (Term f a))) => Ord (Term f a)
deriving stock instance (Data a, Data (f (Term f a)), Typeable f) => Data (Term f a)
-- | Fold a 'Term' bottom-up. @leaf@ handles every 'Leaf'. @node@ collapses
-- every 'Node' layer once its children have been folded.
cataTerm :: (Functor f) => (f b -> b) -> (a -> b) -> Term f a -> b
cataTerm node leaf = go
  where
    go (Leaf x) = leaf x
    go (Node layer) = node (fmap go layer)
-- | Monadic variant of 'cataTerm'. For a 'Node', the children are folded
-- first, in 'Traversable' order. @node@ then runs on the folded layer.
cataTermM :: (Traversable f, Monad m) => (f b -> m b) -> (a -> m b) -> Term f a -> m b
cataTermM node leaf = go
  where
    go (Leaf x) = leaf x
    go (Node layer) = traverse go layer >>= node
-- | Embed a 'Fix' into 'Term'. The result contains no 'Leaf', so the leaf
-- type @a@ can be chosen freely.
unfreeze :: (Functor f) => Fix f -> Term f a
unfreeze (Fix layer) = Node (fmap unfreeze layer)
-- | Convert a 'Term' back into a 'Fix'. Yields 'Nothing' as soon as any
-- 'Leaf' is found, so only leaf-free terms can be frozen.
freeze :: (Traversable f) => Term f a -> Maybe (Fix f)
freeze = \case
  Leaf _ -> Nothing
  Node layer -> Fix <$> traverse freeze layer
cabal-version: 2.4
name: unification-xd
version: 0.1.0.0
-- A short (one-line) description of the package.
-- synopsis:
-- A longer description of the package.
-- description:
-- A URL where users can report bugs.
-- bug-reports:
-- The license under which the package is released.
-- license:
author: Kirill Andreev
maintainer: Kirill.Andreev@kaspersky.com
-- A copyright notice.
-- copyright:
-- category:
extra-source-files: CHANGELOG.md
library
    -- NOTE(review): hs-source-dirs is specified twice (src here and app
    -- below). Cabal appears to combine repeated list fields, so both
    -- directories end up searched; confirm that including app in the
    -- library is intended.
    hs-source-dirs: src
    build-depends: base, mtl, transformers, containers, microlens-platform, shower
    hs-source-dirs: app
    default-language: Haskell2010
    -- Modules exposed by the library.
    exposed-modules: Unification, Fixpoint
    -- Extensions enabled for every library module. Unification.hs has no
    -- LANGUAGE pragmas of its own and relies entirely on this list.
    default-extensions:
        LambdaCase
        BlockArguments
        DerivingStrategies
        StandaloneDeriving
        UndecidableInstances
        DeriveDataTypeable
        DeriveGeneric
        MultiParamTypeClasses
        FunctionalDependencies
        TypeOperators
        DefaultSignatures
        FlexibleContexts
        FlexibleInstances
        ImportQualifiedPost
        GeneralizedNewtypeDeriving
        TemplateHaskell
        ExplicitForAll
        TypeApplications
        DeriveAnyClass
        DeriveFunctor
        DeriveFoldable
        DeriveTraversable
        ScopedTypeVariables
executable unification-xd
    main-is: Main.hs
    -- Modules included in this executable, other than Main.
    -- other-modules:
    -- LANGUAGE extensions used by modules in this package.
    -- other-extensions:
    -- NOTE(review): this stanza does not list the library
    -- (unification-xd) in build-depends. Add it if Main.hs imports
    -- Unification or Fixpoint.
    build-depends: base ^>=4.14.3.0
    hs-source-dirs: app
    default-language: Haskell2010
module Unification where
import Control.Monad.Writer
import Control.Monad.Except
import Control.Monad.State
import Control.Monad.Identity
import Data.IntMap qualified as IntMap
import Data.IntMap (IntMap)
import Data.IntSet qualified as IntSet
import Data.IntSet (IntSet)
import Data.Functor.Compose
import Data.Foldable (fold)
import Data.Traversable (for)
import GHC.Generics
import Lens.Micro.Platform
import Shower
import Fixpoint
-- | Functors whose layers can be matched structurally.
--
-- @match a b@ returns 'Nothing' when the two layers have different shapes.
-- Otherwise it returns a layer of the same shape, in which each position
-- holds one of two values:
--
-- * 'Right' pairs the children found at that position in @a@ and @b@, so
--   they can be unified further.
-- * 'Left' carries a single child that needs no further matching.
--
-- A default implementation via 'Generic1' is provided, so instances can be
-- obtained with @deriving anyclass (Unifiable)@.
class Traversable f => Unifiable f where
  match :: f a -> f a -> Maybe (f (Either a (a, a)))
  default match :: (Generic1 f, Unifiable (Rep1 f)) => f a -> f a -> Maybe (f (Either a (a, a)))
  match a b = to1 <$> match (from1 a) (from1 b)
-- Base cases of the generic 'Unifiable' machinery.

-- | Empty representation: nothing to compare, reuse the first argument.
instance Unifiable V1 where
  match v _ = Just (fmap Left v)

-- | Nullary constructor: always matches.
instance Unifiable U1 where
  match u _ = Just (fmap Left u)

-- | An occurrence of the parameter itself: pair both children up so they
-- can be unified recursively.
instance Unifiable Par1 where
  match (Par1 x) (Par1 y) = Just (Par1 (Right (x, y)))
-- | Application of another functor to the parameter: delegate to it.
instance (Unifiable f) => Unifiable (Rec1 f) where
  match (Rec1 x) (Rec1 y) = fmap Rec1 (match x y)

-- | A field not involving the parameter: both sides must be equal.
instance (Eq c) => Unifiable (K1 i c) where
  match (K1 x) (K1 y) =
    if x == y then Just (K1 x) else Nothing

-- | Metadata wrapper: transparent for matching.
instance (Unifiable f) => Unifiable (M1 i c f) where
  match (M1 x) (M1 y) = fmap M1 (match x y)
-- | Sum: both sides must use the same alternative.
instance (Unifiable f, Unifiable g) => Unifiable (f :+: g) where
  match (L1 x) (L1 y) = L1 <$> match x y
  match (R1 x) (R1 y) = R1 <$> match x y
  match _ _ = Nothing

-- | Product: both components must match.
instance (Unifiable f, Unifiable g) => Unifiable (f :*: g) where
  match (x1 :*: x2) (y1 :*: y2) = (:*:) <$> match x1 y1 <*> match x2 y2
-- | Composition @f (g p)@. The outer @f@ layers are matched first. Then
-- every pair of inner @g@ layers found at the same position is matched.
-- A 'Left' inner layer is passed through with all its positions marked
-- 'Left'.
instance (Unifiable f, Unifiable g) => Unifiable (f :.: g) where
  match (Comp1 outerA) (Comp1 outerB) =
    match outerA outerB >>= fmap Comp1 . traverse matchInner
    where
      matchInner (Left inner) = Just (fmap Left inner)
      matchInner (Right (innerA, innerB)) = match innerA innerB
-- | Types usable as unification variables, convertible to and from an
-- 'Int' id. The id is the key of the binding map in 'UnificationT'.
class Variable v where
  getVarId :: v -> Int
  makeVar :: Int -> v
-- | Monads that maintain a store of variable bindings for terms over @t@.
--
-- By the functional dependencies, the monad together with either the term
-- functor or the variable type determines the other one.
class
  ( Variable v
  , Unifiable t
  , Monad m
  )
  => BindingMonad m t v
  | m t -> v
  , m v -> t
  where
    -- | Look up the current binding of a variable, if any.
    find :: v -> m (Maybe (Term t v))
    -- | Allocate a new variable.
    fresh :: m v
    -- | Allocate a new variable and bind it to the given term.
    new :: Term t v -> m v
    -- | Bind (or rebind) a variable, returning the term it is bound to.
    (=:) :: v -> Term t v -> m (Term t v)

    new term = fresh >>= \var -> (var =: term) >> pure var
-- | Lift the binding operations of @m@ through any monad transformer @h@.
-- The instance is OVERLAPPABLE so that a dedicated instance, such as the one
-- for 'UnificationT', takes precedence.
instance {-# OVERLAPPABLE #-}
  ( BindingMonad m t v
  , MonadTrans h
  , Monad (h m)
  ) => BindingMonad (h m) t v
  where
    find var = lift (find var)
    fresh = lift fresh
    var =: term = lift (var =: term)
-- | State of the 'UnificationT' interpreter.
data UnifState t v = UnifState
  { _usMap :: IntMap (Term t v)
    -- ^ Current bindings, keyed by 'getVarId' of the bound variable.
  , _usCounter :: Int
    -- ^ Id passed to 'makeVar' by the next call to 'fresh'.
  }

deriving stock instance (Show a, Show (f Unshow), Functor f) => Show (UnifState f a)

-- Generates the lenses usMap and usCounter.
makeLenses ''UnifState
-- | Initial interpreter state: no bindings, and variable ids start at 0.
startUnifState :: UnifState t v
startUnifState = UnifState { _usMap = IntMap.empty, _usCounter = 0 }
-- | Errors raised during unification.
data UnificationError t v
  = Occurs v (Term t v)
    -- ^ Thrown by 'seenAs' when a variable is encountered a second time.
    -- It carries the layer previously recorded for that variable.
  | Mismatch (Term t v) (Term t v)
    -- ^ Presumably two terms whose layers failed to 'match'. Not
    -- constructed anywhere in this module yet (TODO confirm).
deriving stock instance (Show a, Show (f Unshow), Functor f) => Show (UnificationError f a)
-- | 'UnificationT' over 'Identity', i.e. pure unification.
type Unification t v = UnificationT t v Identity

-- | Unification monad transformer. It is a 'StateT' holding the binding
-- store ('UnifState'), layered over an 'ExceptT' for 'UnificationError's.
-- No 'MonadState' instance is derived, so the state is reachable only by
-- unwrapping the 'UnificationT' constructor.
newtype UnificationT t v m a = UnificationT
  { runUnificationT
      :: StateT (UnifState t v)
           ( ExceptT (UnificationError t v)
               m ) a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadError (UnificationError t v)
    )
-- | Lift a base-monad action through both the 'ExceptT' and 'StateT' layers.
instance MonadTrans (UnificationT t v) where
  lift action = UnificationT (lift (lift action))
-- | Run a pure unification computation starting from 'startUnifState'.
-- Returns either the first error thrown, or the result paired with the
-- final state.
runUnification :: forall t v a. Unification t v a -> Either (UnificationError t v) (a, UnifState t v)
runUnification unif =
  let stateful = runUnificationT unif
      withErrors = runStateT stateful startUnifState
   in runIdentity (runExceptT withErrors)
-- | The concrete binding store.
--
-- Variables are looked up in and written to '_usMap', keyed by 'getVarId'.
-- 'fresh' hands out consecutive ids taken from '_usCounter'.
instance
  (Variable v, Unifiable t, Monad m)
  => BindingMonad (UnificationT t v m) t v
  where
    find var = UnificationT (IntMap.lookup (getVarId var) <$> use usMap)

    fresh = UnificationT $ do
      nextId <- use usCounter
      usCounter += 1
      pure (makeVar nextId)

    var =: term = do
      UnificationT (usMap %= IntMap.insert (getVarId var) term)
      pure term
-- | Fully dereference a term, with path compression.
--
-- * A 'Node' is returned unchanged.
-- * An unbound 'Leaf' is returned as-is.
-- * For a bound 'Leaf', the bound term is pruned recursively and the
--   variable is rebound directly to that result, which is then returned.
prune :: BindingMonad m t v => Term t v -> m (Term t v)
prune term@Node{} = return term
prune term@(Leaf var) =
  find var >>= maybe (return term) (\bound -> prune bound >>= (var =:))
-- | Follow a chain of variable-to-variable bindings to its last variable,
-- compressing the path along the way.
--
-- A 'Node' is returned unchanged. Otherwise the result is the 'Leaf' of the
-- last variable in the chain, which is either unbound or bound to a 'Node'.
-- Every variable visited before it is rebound directly to that 'Leaf'.
semiprune :: BindingMonad m t v => Term t v -> m (Term t v)
semiprune term@Node{} = return term
semiprune term@(Leaf var) = chase var term
  where
    chase current self =
      find current >>= \case
        Just next@(Leaf nextVar) -> chase nextVar next >>= (current =:)
        _ -> return self
-- | Transformer used for the occurs check. Its state maps each variable id
-- to the layer that 'seenAs' recorded for that variable.
newtype OccursCheckT t v m a = OccursCheckT
  { runOccursCheckT :: StateT (IntMap (t (Term t v))) m a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadError e
    )
-- | Lift a base-monad action through the occurs-check state layer.
instance MonadTrans (OccursCheckT t v) where
  lift action = OccursCheckT (lift action)
-- | Monads that can record that a variable has been encountered, together
-- with the layer associated with it.
class (Monad m) => MonadOccurs m t v where
  seenAs :: v -> t (Term t v) -> m ()
-- | Record @var@ as seen with @layer@. If @var@ was already recorded, throw
-- 'Occurs', which carries the previously recorded layer rather than the new
-- one.
--
-- NOTE(review): entries are only ever inserted, never removed. Any second
-- visit of the same variable is therefore reported, including a variable
-- that merely occurs twice in a non-cyclic term, unless callers scope the
-- state themselves. TODO confirm this is the intended semantics.
instance (MonadError (UnificationError t v) m, Variable v) => MonadOccurs (OccursCheckT t v m) t v where
  seenAs var layer = do
    recorded <- OccursCheckT get
    maybe
      (OccursCheckT (modify (IntMap.insert (getVarId var) layer)))
      (\previous -> throwError (Occurs var (Node previous)))
      (IntMap.lookup (getVarId var) recorded)
-- | Collect the distinct unbound variables reachable from the given terms.
--
-- Variable chains are shortened with 'semiprune' during the walk, so this
-- may rebind variables as a side effect. Each variable id is processed at
-- most once. The result is ordered by 'getVarId'.
getFreeVars
  :: forall m t v list
   . (BindingMonad m t v, Traversable list)
  => list (Term t v) -> m [v]
getFreeVars list = do
  idSet <- evalStateT (fold <$> traverse loop list) IntSet.empty
  return $ map makeVar $ IntSet.toList idSet
  where
    -- Returns the ids of free variables under @term@ that have not been
    -- visited yet. The state holds every variable id visited so far.
    loop :: Term t v -> StateT IntSet m IntSet
    loop term =
      semiprune term >>= \case
        Node layer -> fold <$> traverse loop layer
        Leaf v -> do
          let vId = getVarId v
          done <- gets (IntSet.member vId)
          if done
            then return mempty
            else do
              modify (IntSet.insert vId)
              find v >>= \case
                -- Unbound: the variable itself is free. 'vId' is already
                -- the Int id, so no further conversion is needed.
                Nothing -> return $ IntSet.singleton vId
                Just bound -> loop bound
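-- TODO: unfinished sketch of applyBindings (presumably: substitute bound
-- variables throughout the given terms). It is kept commented out because
-- the body below is incomplete.
-- applyBindings
--   :: forall m t v list
--   . (BindingMonad m t v, Traversable list)
--   => list (Term t v) -> m (list (Term t v))
-- applyBindings list = do
--   evalStateT () IntMap.empty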
--------------------------------------------------------------------------------
-- | Example type language used by 'test' and 'test1'.
--
-- The 'Unifiable' instance comes from the 'Generic1'-based default of
-- 'match'. 'TRec' additionally relies on the 'Unifiable' instances for @[]@
-- and @(,) String@. As written, two records therefore match only when they
-- list the same field names, in the same order.
data Type self
  = TArr self self
  | TCon String
  | TSet self
  | TMap self self
  | TRec [(String, self)]
  deriving stock (Generic1, Functor, Foldable, Traversable, Show)
  deriving anyclass (Unifiable)
-- Generic-derived 'Unifiable' instances for the functors occurring inside
-- 'TRec', plus 'Maybe'. The first component of a pair is compared with '=='.
deriving anyclass instance (Eq a) => Unifiable ((,) a)
deriving anyclass instance Unifiable []
deriving anyclass instance Unifiable Maybe

-- | Plain 'Int's used directly as variables; the id is the value itself.
instance Variable Int where
  getVarId n = n
  makeVar n = n
-- | Demo of 'semiprune'.
--
-- Builds the chain @a -> b -> c@, with @c@ bound to @TCon "Int"@, then
-- semiprunes @a@. The expected result is the leaf for @c@, with @a@ (and
-- @b@) bound directly to it in the printed state.
test :: IO ()
test = printer do
  runUnification @Type @Int do
    a <- fresh
    b <- fresh
    c <- new $ Node $ TCon "Int"
    _ <- a =: Leaf b
    _ <- b =: Leaf c
    semiprune (Leaf a)
-- | Demo of 'getFreeVars'.
--
-- Builds the chain @a -> b -> c@, with @c@ bound to @TCon "Int"@, then
-- collects the free variables of @TArr a 42@. Only variable 42 is unbound,
-- so the expected result is @[42]@.
test1 :: IO ()
test1 = printer do
  runUnification @Type @Int do
    a <- fresh
    b <- fresh
    c <- new $ Node $ TCon "Int"
    _ <- a =: Leaf b
    _ <- b =: Leaf c
    getFreeVars [Node $ TArr (Leaf a) (Leaf 42)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment