{-
  Gist phadej/cd9914823ac83291f6a693854bed3daf
  Author: @phadej, created November 19, 2019 03:37.
-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall #-}
module Unification where
import Control.Monad (ap, forM, forM_)
import Data.Foldable (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Functor.Classes
-- | A term built from layers of the structure functor @t@, whose leaves are
-- unification variables of type @v@.
data UTerm t v
  = -- | A unification variable.
    UVar !v
  | -- | Some structure containing subterms.
    UTerm !(t (UTerm t v))
  deriving (Functor, Foldable, Traversable)
-- | 'pure' makes a variable leaf; '<*>' is the monadic 'ap', spelled out via
-- the substitution ('>>=') defined in the 'Monad' instance.
instance Functor t => Applicative (UTerm t) where
  pure = UVar
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)
-- | Substitution: @term >>= k@ replaces every variable leaf @v@ of @term@
-- with @k v@, leaving the surrounding structure intact.
--
-- 'return' is deliberately not defined here. It defaults to 'pure' (which is
-- 'UVar'), the canonical definition; overriding it with a non-canonical
-- right-hand side triggers GHC's @-Wnoncanonical-monad-instances@ warning.
instance Functor t => Monad (UTerm t) where
  UVar v >>= k = k v
  UTerm t >>= k = UTerm (fmap (>>= k) t)
-- | Structures that can be matched one layer at a time.
--
-- @zipMatch x y@ is expected to pair up the children of @x@ and @y@ when both
-- have the same outer shape, and to return 'Nothing' when the shapes differ.
class Traversable t => ZipMatch t where
  zipMatch :: t a -> t b -> Maybe (t (a, b))
-- | An equality constraint: a variable on the left-hand side, a term on the
-- right-hand side. The derived 'fmap' renames variables on both sides.
data Equality t v
  = Equality v (UTerm t v)
  deriving (Functor)
-- | Reasons why 'solve' can fail.
data Error t v
  = -- | The variable occurs inside the structure it would be bound to.
    OccursCheck v (t (UTerm t v))
  | -- | Two structure layers whose shapes do not match ('zipMatch' gave 'Nothing').
    CannotUnify (t (UTerm t v)) (t (UTerm t v))
-- | Unify two structure layers, producing equalities between their
-- corresponding children. 'Equality' is skewed (variable on the left, term on
-- the right), so two terms cannot be equated directly; this helper is needed.
-- Fails with 'CannotUnify' when the layers have different shapes.
zipMatchRec :: ZipMatch t => t (UTerm t v) -> t (UTerm t v) -> Either (Error t v) [Equality t v]
zipMatchRec x y =
  maybe
    (Left (CannotUnify x y))
    (fmap concat . traverse (uncurry zipMatchRec') . toList)
    (zipMatch x y)
-- | Unify two terms. If either side is a variable, emit one equality with that
-- variable on the left (the left-hand variable wins when both are variables);
-- otherwise unify the two structure layers with 'zipMatchRec'.
zipMatchRec' :: ZipMatch t => UTerm t v -> UTerm t v -> Either (Error t v) [Equality t v]
zipMatchRec' lhs rhs = case (lhs, rhs) of
  (UVar a, _) -> Right [Equality a rhs]
  (_, UVar b) -> Right [Equality b lhs]
  (UTerm x, UTerm y) -> zipMatchRec x y
-- | Solve a list of equality constraints by successive substitution.
--
-- On success, each eliminated variable is mapped to a term in which no
-- eliminated variable occurs (the returned substitution is already
-- flattened). Variables left free by the constraints are not keys of the map;
-- they may appear as 'UVar' leaves inside other solutions.
--
-- Fails with 'OccursCheck' when a variable would have to contain itself, and
-- with 'CannotUnify' when two structure layers do not match.
--
-- >>> solve [Equality 'x' (UTerm TyBool), Equality 'x' (UVar 'y')]
-- Right (fromList [('x',UTerm TyBool),('y',UTerm TyBool)])
--
-- >>> solve [Equality 'x' (UTerm (TyArrow (UVar 'a') (UVar 'b'))), Equality 'x' (UTerm (TyArrow (UVar 'x') (UVar 'c')))]
-- Left (OccursCheck 'a' (TyArrow (UVar 'a') (UVar 'b')))
solve :: forall t v. (ZipMatch t, Ord v) => [Equality t v] -> Either (Error t v) (Map v (UTerm t v))
-- No constraints: we are done.
solve [] = Right Map.empty
-- The RHS is a variable.
solve (Equality v (UVar v') : eqs)
  -- Same variable on both sides: trivially satisfied, skip the constraint.
  | v == v' = solve eqs
  -- Otherwise rename v to v' in the remaining constraints and record v := v'.
  | otherwise = do
      let subst :: v -> v
          subst u
            | u == v = v'
            | otherwise = u
      -- Solve the remaining constraints, with v renamed away.
      solution <- solve (map (fmap subst) eqs)
      -- If v' has a solution, share it with v; otherwise v is an alias of v'.
      return $ Map.insert v (Map.findWithDefault (UVar v') v' solution) solution
-- The RHS is a term.
solve (Equality v ut@(UTerm t) : eqs)
  -- Occurs check: v cannot be bound to a term that contains v.
  | v `elem` ut = Left (OccursCheck v t)
  -- Otherwise eliminate v from every remaining constraint.
  | otherwise = do
      let subst :: v -> UTerm t v
          subst u
            | u == v = ut
            | otherwise = UVar u

          -- Rewrite one remaining constraint so that v no longer occurs in it.
          eliminate :: Equality t v -> Either (Error t v) [Equality t v]
          eliminate (Equality v' ut')
            -- A different LHS variable: substitute v := ut in its RHS.
            | v /= v' = Right [Equality v' (ut' >>= subst)]
            | otherwise = case ut' of
                -- v = v is trivially satisfied.
                UVar u
                  | u == v -> Right []
                  -- v = u together with v = ut means u = ut. Keeping v = u
                  -- instead would lose that constraint, because the binding
                  -- for v computed below overrides whatever the recursive
                  -- call decides for v.
                  | otherwise -> Right [Equality u ut]
                -- Both constraints bind v to a term: unify the two layers.
                -- Substitute v := ut in the other layer first, so that v
                -- cannot re-enter the remaining constraints and bypass the
                -- occurs check (which would yield a cyclic solution).
                UTerm t' -> zipMatchRec t (fmap (>>= subst) t')
      eqs' <- forM eqs eliminate
      solution <- solve (concat eqs')
      -- This substitution flattens the terms: replace every solved variable
      -- in ut by its solution.
      let subst' :: v -> UTerm t v
          subst' u = Map.findWithDefault (UVar u) u solution -- e.g. here we could perform defaulting.
      return $ Map.insert v (ut >>= subst') solution
-------------------------------------------------------------------------------
-- Example
-------------------------------------------------------------------------------
-- | Example type language: booleans and function arrows, with the argument
-- and result types as the children of each layer.
data Ty a
  = TyBool
  | TyArrow a a
  deriving (Show, Functor, Foldable, Traversable)
-- | Two 'Ty' layers match when they use the same constructor; arrows pair up
-- their argument types and their result types.
instance ZipMatch Ty where
  zipMatch l r = case (l, r) of
    (TyBool, TyBool) -> Just TyBool
    (TyArrow a b, TyArrow c d) -> Just (TyArrow (a, c) (b, d))
    _ -> Nothing
-- | A small constraint system: @x@ must equal both @Bool -> a@ and @b -> b@.
--
-- >>> printSolution (solve example)
-- 'a' -> UTerm TyBool
-- 'b' -> UTerm TyBool
-- 'x' -> UTerm (TyArrow (UTerm TyBool) (UTerm TyBool))
--
-- >>> solve [ Equality 'x' $ UTerm $ TyArrow (UVar 'x') (UVar 'x') ]
-- Left (OccursCheck 'x' (TyArrow (UVar 'x') (UVar 'x')))
--
-- >>> solve [ Equality 'x' $ UTerm $ TyArrow (UVar 'y') (UVar 'z'), Equality 'x' $ UTerm TyBool ]
-- Left (CannotUnify (TyArrow (UVar 'y') (UVar 'z')) TyBool)
--
example :: [Equality Ty Char]
example = [boolToA, bToB]
  where
    boolToA = Equality 'x' (UTerm (TyArrow (UTerm TyBool) (UVar 'a')))
    bToB = Equality 'x' (UTerm (TyArrow (UVar 'b') (UVar 'b')))
-------------------------------------------------------------------------------
-- Printing: printSolution and the Show1 / Show instances
-------------------------------------------------------------------------------
-- | Print every binding of each solution, one @variable -> term@ per line.
--
-- Accepts any 'Foldable' container of maps so it can be applied directly to
-- the result of 'solve': a 'Left' (failure) prints nothing.
printSolution :: (Foldable t, Show a, Show b) => t (Map a b) -> IO ()
printSolution solution =
  forM_ solution $ \s ->
    -- forM_ rather than forM: the per-line results are () and are discarded,
    -- so there is no point collecting them into a list.
    forM_ (Map.toList s) $ \(v, t) ->
      putStrLn $ show v ++ " -> " ++ show t
-- | Show a 'Ty' layer, delegating to the supplied printer for its children.
instance Show1 Ty where
  liftShowsPrec sp _ d ty = case ty of
    TyBool -> showString "TyBool"
    TyArrow a b -> showsBinaryWith sp sp "TyArrow" d a b
-- | Constructor-style rendering, e.g. @UTerm (TyArrow (UVar 'a') (UVar 'b'))@.
instance (Show v, Show1 t) => Show (UTerm t v) where
  showsPrec d term = case term of
    UVar v -> showsUnaryWith showsPrec "UVar" d v
    UTerm t -> showsUnaryWith (liftShowsPrec showsPrec showList) "UTerm" d t
-- | Constructor-style rendering of unification errors.
instance (Show v, Show1 t) => Show (Error t v) where
  showsPrec d err = case err of
      OccursCheck v t -> showsBinaryWith showsPrec showsLayer "OccursCheck" d v t
      CannotUnify x y -> showsBinaryWith showsLayer showsLayer "CannotUnify" d x y
    where
      -- Show one structure layer whose children are terms.
      showsLayer :: Int -> t (UTerm t v) -> ShowS
      showsLayer = liftShowsPrec showsPrec showList
-- Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment