Created
November 19, 2019 03:37
-
-
Save phadej/cd9914823ac83291f6a693854bed3daf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# OPTIONS_GHC -Wall #-} | |
module Unification where | |
import Control.Monad (ap, forM, forM_) | |
import Data.Foldable (toList) | |
import Data.Map.Strict (Map) | |
import qualified Data.Map.Strict as Map | |
import Data.Functor.Classes | |
-- | A term built from the structure functor @t@ whose leaves may be
-- unification variables of type @v@. Both fields are strict.
data UTerm t v
  = -- | A unification variable.
    UVar !v
  | -- | Some structure containing subterms.
    UTerm !(t (UTerm t v))
  deriving (Functor, Foldable, Traversable)
-- | 'pure' wraps a value as a bare unification variable; '<*>' is recovered
-- from the 'Monad' instance via 'ap'.
instance Functor t => Applicative (UTerm t) where
  (<*>) = ap
  pure = UVar
-- | Substitution monad: @m >>= k@ replaces every variable @v@ in @m@ by the
-- term @k v@, keeping the surrounding structure intact.
--
-- 'return' is intentionally not defined here. The class default
-- (@return = pure@, i.e. 'UVar') is the canonical definition. Overriding it
-- triggers @-Wnoncanonical-monad-instances@ under @-Wall@.
instance Functor t => Monad (UTerm t) where
  UVar v >>= k = k v
  UTerm t >>= k = UTerm (fmap (>>= k) t)
-- | Structures that can be matched shape-for-shape.
--
-- @zipMatch x y@ is expected to pair up the corresponding children of @x@
-- and @y@ when both have the same outer shape. It should return 'Nothing'
-- when the shapes differ; 'zipMatchRec' reports that case as 'CannotUnify'.
class Traversable t => ZipMatch t where
  zipMatch ::
    t a ->
    t b ->
    Maybe (t (a, b))
-- | An equality constraint @v = term@: a variable on the left-hand side and
-- a term on the right-hand side.
--
-- The derived 'fmap' renames variables on both sides of the constraint.
data Equality t v
  = Equality v (UTerm t v)
  deriving (Functor)
-- | Reasons why 'solve' can fail.
data Error t v
  = -- | The variable occurs inside the term it is being equated with.
    OccursCheck v (t (UTerm t v))
  | -- | Two terms whose outer shapes differ ('zipMatch' returned 'Nothing').
    CannotUnify (t (UTerm t v)) (t (UTerm t v))
-- | Unify two non-variable terms.
--
-- 'Equality' is skewed: a variable on the left, a term on the right. A
-- term-term constraint therefore has to be broken down into variable-term
-- constraints by walking through the matching structure. The first shape
-- mismatch is reported as 'CannotUnify'.
zipMatchRec :: ZipMatch t => t (UTerm t v) -> t (UTerm t v) -> Either (Error t v) [Equality t v]
zipMatchRec x y =
  maybe (Left (CannotUnify x y)) matchChildren (zipMatch x y)
  where
    -- Unify each pair of corresponding children and collect the constraints.
    matchChildren pairs = concat <$> mapM (uncurry zipMatchRec') (toList pairs)
-- | Unify two arbitrary terms, producing variable-term constraints.
--
-- A variable on either side yields a single 'Equality'. The left variable
-- wins when both sides are variables. Two non-variables are handed to
-- 'zipMatchRec'.
zipMatchRec' :: ZipMatch t => UTerm t v -> UTerm t v -> Either (Error t v) [Equality t v]
zipMatchRec' lhs rhs = case (lhs, rhs) of
  (UVar var, _) -> Right [Equality var rhs]
  (_, UVar var) -> Right [Equality var lhs]
  (UTerm l, UTerm r) -> zipMatchRec l r
-- | Solve a list of equality constraints by eliminating one variable at a
-- time.
--
-- On success the result maps every eliminated variable to its solution.
-- Solutions are flattened: they mention only variables that are not keys of
-- the map, and those variables are left unconstrained.
--
-- Failure modes:
--
-- * 'OccursCheck': a variable would have to contain itself.
-- * 'CannotUnify': two terms have different shapes.
--
-- When a variable is eliminated, every other constraint on that same
-- variable has the variable substituted away before it is matched. This
-- ensures that neither aliases (@a = b@) nor self-references (@a = F a@)
-- are silently dropped:
--
-- >>> solve [ Equality 'a' $ UTerm TyBool, Equality 'a' $ UVar 'b', Equality 'b' $ UTerm $ TyArrow (UVar 'c') (UVar 'c') ]
-- Left (CannotUnify TyBool (TyArrow (UVar 'c') (UVar 'c')))
--
-- >>> solve [ Equality 'a' $ UTerm $ TyArrow (UVar 'b') (UVar 'b'), Equality 'a' $ UTerm $ TyArrow (UVar 'a') (UTerm TyBool) ]
-- Left (OccursCheck 'b' (TyArrow (UVar 'b') (UVar 'b')))
solve :: forall t v. (ZipMatch t, Ord v) => [Equality t v] -> Either (Error t v) (Map v (UTerm t v))
-- No constraints: we are done.
solve [] = Right Map.empty
-- A variable on the right-hand side.
solve (Equality v (UVar v') : eqs)
  -- @v = v@ holds trivially: drop it.
  | v == v' = solve eqs
  -- @v = v'@: rename v to v' in every remaining constraint and solve those.
  -- v then shares whatever v' was solved to, or v' itself if unsolved.
  | otherwise = do
      let rename :: v -> v
          rename u
            | u == v = v'
            | otherwise = u
      solution <- solve (map (fmap rename) eqs)
      return (Map.insert v (Map.findWithDefault (UVar v') v' solution) solution)
-- A term on the right-hand side.
solve (Equality v ut@(UTerm t) : eqs)
  -- v must not occur inside its own definition.
  | v `elem` ut = Left (OccursCheck v t)
  -- Eliminate v from every remaining constraint and solve the rest. Then
  -- express v's definition in terms of that solution.
  | otherwise = do
      let subst :: v -> UTerm t v
          subst u
            | u == v = ut
            | otherwise = UVar u

          -- Rewrite one constraint so that v no longer occurs in it.
          eliminate :: Equality t v -> Either (Error t v) [Equality t v]
          eliminate (Equality v' ut')
            -- Different LHS: substitute v in the RHS.
            | v /= v' = Right [Equality v' (ut' >>= subst)]
            -- Same LHS: @v = ut'@ together with @v = ut@ means @ut' = ut@.
            -- ut' may itself mention v, so substitute before matching.
            -- After substitution, a bare variable here cannot be v.
            | otherwise = case ut' >>= subst of
                UVar u -> Right [Equality u ut]
                UTerm t' -> zipMatchRec t t'

      eqs' <- forM eqs eliminate
      solution <- solve (concat eqs')
      -- This substitution flattens the terms: solved variables in ut are
      -- replaced by their solutions. Unsolved ones stay as variables, which
      -- is where defaulting could be performed.
      let flatten :: v -> UTerm t v
          flatten u = Map.findWithDefault (UVar u) u solution
      return (Map.insert v (ut >>= flatten) solution)
------------------------------------------------------------------------------- | |
-- Example | |
------------------------------------------------------------------------------- | |
-- | A tiny example type language used to exercise 'solve'.
data Ty a
  = -- | A nullary (leaf) type constructor.
    TyBool
  | -- | A binary type constructor with two children.
    TyArrow a a
  deriving (Functor, Foldable, Traversable, Show)
-- | Two 'Ty' values match exactly when they use the same constructor.
-- Children are paired positionally.
instance ZipMatch Ty where
  zipMatch lhs rhs = case (lhs, rhs) of
    (TyBool, TyBool) -> Just TyBool
    (TyArrow x1 y1, TyArrow x2 y2) -> Just (TyArrow (x1, x2) (y1, y2))
    _ -> Nothing
-- | Two constraints on @x@ that together force @a@ and @b@ to be 'TyBool'.
--
-- >>> printSolution (solve example)
-- 'a' -> UTerm TyBool
-- 'b' -> UTerm TyBool
-- 'x' -> UTerm (TyArrow (UTerm TyBool) (UTerm TyBool))
--
-- A variable that occurs in its own right-hand side is rejected:
--
-- >>> solve [ Equality 'x' $ UTerm $ TyArrow (UVar 'x') (UVar 'x') ]
-- Left (OccursCheck 'x' (TyArrow (UVar 'x') (UVar 'x')))
--
-- Terms with different outer constructors cannot be unified:
--
-- >>> solve [ Equality 'x' $ UTerm $ TyArrow (UVar 'y') (UVar 'z'), Equality 'x' $ UTerm TyBool ]
-- Left (CannotUnify (TyArrow (UVar 'y') (UVar 'z')) TyBool)
example :: [Equality Ty Char]
example =
  [ Equality 'x' (UTerm (TyArrow (UTerm TyBool) (UVar 'a')))
  , Equality 'x' (UTerm (TyArrow (UVar 'b') (UVar 'b')))
  ]
------------------------------------------------------------------------------- | |
-- Pretty-printing: printSolution and the Show / Show1 instances
------------------------------------------------------------------------------- | |
-- | Print every binding of each map held in the 'Foldable' container.
-- Each binding is printed as one @key -> value@ line, in ascending key
-- order.
--
-- For an 'Either' result only 'Right' is visited, so an error prints
-- nothing. Use 'print' on the result to see errors.
printSolution :: (Foldable t, Show a, Show b) => t (Map a b) -> IO ()
printSolution solution =
  forM_ solution $ \s ->
    -- forM_ rather than forM: the per-line results are discarded, so there
    -- is no point building a list of units.
    forM_ (Map.toList s) $ \(v, t) ->
      putStrLn (show v ++ " -> " ++ show t)
-- | Renders 'Ty' in constructor syntax, using the supplied printer for the
-- children.
instance Show1 Ty where
  liftShowsPrec showElem _ prec ty = case ty of
    TyBool -> showString "TyBool"
    TyArrow x y -> showsBinaryWith showElem showElem "TyArrow" prec x y
-- | Renders a term in constructor syntax. The structure layer is printed
-- through its 'Show1' instance.
instance (Show v, Show1 t) => Show (UTerm t v) where
  showsPrec prec term = case term of
    UVar var -> showsUnaryWith showsPrec "UVar" prec var
    -- showsPrec1 is liftShowsPrec showsPrec showList.
    UTerm sub -> showsUnaryWith showsPrec1 "UTerm" prec sub
-- | Renders an error in constructor syntax. The structure payloads are
-- printed through 'Show1'.
instance (Show v, Show1 t) => Show (Error t v) where
  showsPrec prec err = case err of
    OccursCheck var term ->
      showsBinaryWith showsPrec showsPrec1 "OccursCheck" prec var term
    CannotUnify lhs rhs ->
      showsBinaryWith showsPrec1 showsPrec1 "CannotUnify" prec lhs rhs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment