Skip to content

Instantly share code, notes, and snippets.

@Heimdell
Created October 19, 2022 19:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Heimdell/bf6fa1eae56c0b326f53d12313bfa332 to your computer and use it in GitHub Desktop.
Save Heimdell/bf6fa1eae56c0b326f53d12313bfa332 to your computer and use it in GitHub Desktop.
Unifier
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Bind where
import Control.Monad (when)
import Control.Monad.Except (MonadError(throwError))
import Control.Monad.State
import Data.Foldable (for_)
import Data.Map qualified as Map
import Data.Set qualified as Set
import Data.String (IsString(fromString))
import Var
import Unify
import Free
-- | Errors raised while unifying terms.
data TypeError n t e
  -- | Two structured terms whose top layers could not be coalesced.
  = Mismatch (Term t n) (Term t n)
  -- | Binding the variable to the term would make the variable occur in its
  -- own binding (raised by the occurs check during unification).
  | Cycle (Var n) (Term t n)
  -- | An error of the caller-chosen type @e@; nothing in this module
  -- constructs it.
  | Other e
-- | Debug rendering of unification errors; every message begins on a new
-- line.
instance (Show e, forall x. Show x => Show (t x), Show n) => Show (TypeError n t e) where
  show err = case err of
      Mismatch lhs rhs -> pair "\nmismatch" " =/=" (show lhs) (show rhs)
      Cycle v term -> pair "\ncycle" " ~" (show v) (show term)
      Other e -> "\n" <> show e
    where
      -- header, space, left operand, separator, space, right operand
      pair header sep lhs rhs = header <> " " <> lhs <> sep <> " " <> rhs
-- | Monads that can record and look up bindings from variables to terms
-- during unification.
--
-- Functional dependencies: the name type and the monad determine the term
-- functor; the term functor and the monad determine the name type; the name
-- type and the term functor determine the extra error type @e@.
class
  ( HasVars n m
  , MonadError (TypeError n t e) m
  , Unifiable t
  )
  =>
  Binds n t e m
  | n m -> t
  , t m -> n
  , n t -> e
  where
    -- | Record that the variable is bound to the term. 'prune' calls this on
    -- variables that 'see' already reported as bound (path compression), so
    -- instances must allow overwriting an existing binding.
    (=:) :: Var n -> Term t n -> m ()
    -- | Look up the current binding of a variable. 'prune' and 'update'
    -- treat 'Nothing' as "the variable is unbound".
    see :: Var n -> m (Maybe (Term t n))
-- | Lift binding operations through any monad transformer. The transformed
-- monad must provide its own 'MonadError' for the same error type; only
-- '(=:)' and 'see' are lifted here.
instance {-# OVERLAPPABLE #-}
  ( Binds n t e m
  , MonadTrans mt
  , MonadError (TypeError n t e) (mt m)
  )
  =>
  Binds n t e (mt m)
  where
    v =: term = lift (v =: term)
    see v = lift (see v)
-- | Resolve a term to its representative.
--
-- A 'Free' node is returned unchanged; its children are not resolved. For a
-- variable, follow the chain of variable-to-variable bindings, compressing
-- it as we go. Return either the last, unbound variable or the structured
-- term that the chain ends in. Returning that structure, rather than the
-- bound variable, is what lets unification check against an existing
-- binding instead of silently overwriting it.
--
-- Calls 'error' if the chain of variable bindings loops back on itself.
prune :: (Binds n t e m) => Term t n -> m (Term t n)
prune = \case
  t@Free {} -> return t
  Pure v -> go [v] v
  where
    -- @visited@ holds every variable already seen on this chain, including
    -- the current one, so meeting any of them again means a cycle.
    go visited v =
      see v >>= \case
        Nothing -> return (Pure v)
        -- The chain ends in structure: hand it back to the caller.
        Just t@Free {} -> return t
        Just (Pure v') -> do
          when (v' `elem` visited) do
            error $ "vars form a cycle: " <> show visited
          end <- go (v' : visited) v'
          -- Path compression: point v straight at the end of the chain.
          v =: end
          return end
-- | Unify two terms, recording the variable bindings that make them equal.
--
-- Both sides are first resolved with 'prune'.
--
-- * Two distinct unbound variables are both bound to a fresh variable named
--   @t@; a variable unified with itself is left alone.
-- * A variable against structure is bound after an occurs check.
-- * Two structured layers are matched with 'coalesce' and their paired
--   children are unified recursively; @Left@ positions are skipped.
--
-- Throws 'Mismatch' when two layers cannot be coalesced. Throws 'Cycle' when
-- binding a variable would make it occur in its own binding.
(=:=) :: (Binds n t e m) => Term t n -> Term t n -> m ()
l0 =:= r0 = do
  l <- prune l0
  r <- prune r0
  case (l, r) of
    (Pure a, Pure b)
      -- Unifying a variable with itself requires no binding.
      | a == b -> return ()
      | otherwise -> do
          t <- var (fromString "t")
          a =: Pure t
          b =: Pure t
    (Pure a, b) -> assign a b
    (a, Pure b) -> assign b a
    (Free _ t, Free _ u) ->
      case coalesce t u of
        Nothing -> throwError $ Mismatch l r
        Just t' ->
          for_ t' \case
            Left {} -> return ()
            Right (l', r') -> l' =:= r'
  where
    -- The occurs check must see through existing bindings: a variable inside
    -- @term@ may already be bound to something that mentions @v@, and
    -- missing that would create cyclic bindings on which 'update' never
    -- terminates.
    assign v term = do
      resolved <- update term
      if occurs v resolved
        then throwError $ Cycle v resolved
        else v =: term
-- | Substitute every recorded binding throughout a term, so that the result
-- only mentions unbound variables. Each rebuilt node gets its cached
-- variable set recomputed by 'wrap'. Does not terminate if the bindings are
-- cyclic.
update :: (Binds n t e m) => Term t n -> m (Term t n)
update (Free _ layer) = wrap <$> traverse update layer
update (Pure v) = see v >>= maybe (return (Pure v)) update
-- | Replace every variable of a term with a new one obtained from 'fresh'
-- (which keeps the name). Repeated occurrences of the same variable map to
-- the same replacement.
refreshVarNames :: forall n t e m. (Binds n t e m) => Term t n -> m (Term t n)
refreshVarNames term = evalStateT (traverseFree rename term) Map.empty
  where
    -- Reuse the replacement chosen earlier for this variable, if any.
    rename :: Var n -> StateT (Map.Map (Var n) (Var n)) m (Var n)
    rename old = gets (Map.lookup old) >>= maybe (allocate old) return

    -- First sighting: allocate a replacement and remember it.
    allocate :: Var n -> StateT (Map.Map (Var n) (Var n)) m (Var n)
    allocate old = do
      new <- fresh old
      modify (Map.insert old new)
      return new
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE UndecidableInstances #-}
module Free where
import Data.Set (Set)
import Var (Var)
import qualified Data.Set as Set
import Data.Traversable (for)
-- | A term tree: either a variable leaf ('Pure') or a node ('Free') holding
-- one layer of the functor @t@ with subterms as its children.
--
-- The 'Set' in a 'Free' node caches the variables occurring below it. 'wrap'
-- computes it, and 'allVars' / 'occurs' read it without walking the tree.
-- Nodes built directly with the constructor are assumed to keep it in sync.
data Free t v
  = Free (Set v) (t (Free t v))
  | Pure v

-- | Terms whose variables are 'Var's over the name type @n@.
type Term t n = Free t (Var n)
-- | Build a node from one layer of children, caching the union of the
-- children's variable sets.
wrap :: (Ord v, Foldable t) => t (Free t v) -> Free t v
wrap layer = Free childVars layer
  where
    childVars = foldMap allVars layer
-- | The variables of a term: the cached set of a 'Free' node, or the single
-- variable of a 'Pure' leaf. Does not walk the tree.
allVars :: (Ord v) => Free t v -> Set v
allVars (Pure v) = Set.singleton v
allVars (Free cached _) = cached
-- | Whether a variable appears in a term, judged by 'allVars'. It only sees
-- variables that are syntactically present, not ones reachable through
-- bindings.
occurs :: (Ord v) => v -> Free t v -> Bool
occurs v = Set.member v . allVars
-- | Replace every variable in a term using an effectful function.
--
-- @f@ runs once per variable leaf, depth-first, in the order the functor's
-- 'traverse' visits children. Each rebuilt node's cached variable set is
-- recomputed from its renamed children (via 'wrap'). The cache therefore
-- always agrees with the leaves, even when @f@ returns different results
-- for the same input (e.g. allocates a new variable on every call).
traverseFree :: (Monad f, Traversable t, Ord n, Ord m) => (n -> f m) -> Free t n -> f (Free t m)
traverseFree f = \case
  Free _ layer -> wrap <$> traverse (traverseFree f) layer
  Pure v -> Pure <$> f v
-- | Show only the term structure: a node's cached variable set is omitted,
-- and a leaf shows as its variable.
instance (forall x. Show x => Show (t x), Show v) => Show (Free t v) where
  show term = case term of
    Pure v -> show v
    Free _ layer -> show layer
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
module Scheme where
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set

import Bind
import Free
import Var
-- | A type scheme: a body together with the set of variables it quantifies
-- over (shown as @forall a b. body@).
data Scheme t v = Scheme
  { schemeTypeArgs :: Set (Var v)
  -- ^ The quantified (bound) variables.
  , schemeTypeBody :: Term t v
  -- ^ The term they scope over.
  }
-- | Apply all recorded bindings to a term, then quantify over every variable
-- left in the result.
--
-- NOTE(review): no typing environment is consulted, so variables that are
-- free in an enclosing context get quantified too. Confirm this is intended.
generalise :: (Binds n t e m) => Term t n -> m (Scheme t n)
generalise t = toScheme <$> update t
  where
    toScheme body =
      Scheme
        { schemeTypeArgs = allVars body
        , schemeTypeBody = body
        }
-- | Instantiate a scheme: give each quantified variable a fresh copy (same
-- name, new variable from 'fresh') and substitute the copies into the body.
--
-- Body variables that the scheme does not quantify are left untouched, so
-- monomorphic variables shared with the surrounding context keep their
-- identity.
instantiate :: (Binds n t e m) => Scheme t n -> m (Term t n)
instantiate scheme = do
  let quantified = Set.toList (schemeTypeArgs scheme)
  copies <- traverse fresh quantified
  let renaming = Map.fromList (zip quantified copies)
      rename v = Map.findWithDefault v v renaming
  traverseFree (pure . rename) (schemeTypeBody scheme)
-- | Render as @forall v1 v2. body@, listing binders in 'Set' order.
instance (Show v, forall x. Show x => Show (t x)) => Show (Scheme t v) where
  show scheme = concat ["forall ", binders, ". ", show (schemeTypeBody scheme)]
    where
      binders = unwords (show <$> Set.toList (schemeTypeArgs scheme))
module Unify where
-- | Functors whose individual layers can be matched against each other.
class Traversable t => Unifiable t where
  -- | Match two layers. 'Nothing' means their shapes are incompatible.
  -- Otherwise the result is one layer in which each position is either
  -- @Right (l, r)@, a pair of children that must be unified, or @Left x@.
  -- Unification in Bind skips @Left@ positions; presumably they need no
  -- unification (TODO confirm the intended meaning per instance).
  coalesce :: t a -> t a -> Maybe (t (Either a (a, a)))
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
module Var where
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.String (IsString)
-- | A variable: a name plus an integer index. Presumably the index keeps
-- apart distinct variables that share a name (confirm against 'HasVars'
-- instances). The derived 'Eq'/'Ord' compare the index first, then the name.
data Var n = Var
  { varIndex :: Int
  -- ^ Index 0 is shown as the bare name; any other index is shown after a
  -- @#@ (see the 'Show' instance).
  , varName :: n
  -- ^ The human-readable name; 'fresh' reuses it.
  }
  deriving stock (Eq, Ord)
-- | Index 0 shows as the bare name; any other index is appended after @#@.
instance Show n => Show (Var n) where
  show (Var idx name)
    | idx == 0 = show name
    | otherwise = show name <> "#" <> show idx
-- | Monads that can allocate variables from names of type @n@.
class (Show n, Ord n, IsString n, Monad m) => HasVars n m where
  -- | Create a variable with the given name. Presumably every call yields a
  -- distinct variable (renaming in Bind relies on this); the guarantee is
  -- up to each instance.
  var :: n -> m (Var n)
-- | A new variable with the same name as the given one. The old index is
-- discarded; the new variable comes from 'var'.
fresh :: HasVars n m => Var n -> m (Var n)
fresh v = case v of
  Var _ name -> var name
-- | Lift variable allocation through any monad transformer. This lets code
-- running in, e.g., @StateT s m@ allocate variables from @m@.
instance {-# OVERLAPPABLE #-}
  ( HasVars n m
  , MonadTrans t
  , Monad (t m)
  )
  =>
  HasVars n (t m)
  where
    var name = lift (var name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment