Skip to content

Instantly share code, notes, and snippets.

@DarinM223
Last active April 9, 2023 16:00
Show Gist options
  • Save DarinM223/44c7bf5d0a98232f6b01a7435a570810 to your computer and use it in GitHub Desktop.
Save DarinM223/44c7bf5d0a98232f6b01a7435a570810 to your computer and use it in GitHub Desktop.
-- Translation of `sound_eager.ml` to Haskell; see https://okmij.org/ftp/ML/generalization.html
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module SoundEager where
import Control.Monad (unless)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.ST (runST)
import Data.Char (chr, ord)
import Data.Functor ((<&>))
import Data.Functor.Classes (Eq1, eq1, liftEq)
import Data.Functor.Identity (Identity (Identity), runIdentity)
import Data.Maybe (fromJust)
import Data.Primitive (MutVar, newMutVar, readMutVar, writeMutVar)
import Data.Primitive.PrimVar (PrimVar, modifyPrimVar, newPrimVar, readPrimVar, writePrimVar)
import Data.Text (Text)
import Data.Text qualified as T
import Unsafe.Coerce (unsafeCoerce)
-- | Name of a term-level variable, as bound by 'Lam' and 'Let' and referenced
-- by 'Var'.
type Varname = Text
-- | Abstract syntax of the expression language whose types are inferred.
data Exp
  = -- | Reference to a variable by name.
    Var Varname
  | -- | Application: function expression, then argument expression.
    App Exp Exp
  | -- | Lambda abstraction: bound variable name, then body.
    Lam Varname Exp
  | -- | @Let x e rest@: the type of @e@ is generalized before @x@ is added to
    -- the environment used for @rest@.
    Let Varname Exp Exp
-- | Name of a quantified type variable (the payload of 'QVar').
type Qname = Text
-- | Level recorded in an 'Unbound' type variable; 'gen' compares it with the
-- current level to decide whether the variable is generalized.
type Level = Int
-- | Types of the object language, parameterised over the container @r@ that
-- holds the state of each type variable (a mutable 'Ref' during inference,
-- 'Identity' for a frozen snapshot).
data Typ r
  = -- | Type variable whose state ('Tv') is stored inside an @r@.
    TVar (r (Tv r))
  | -- | Quantified type variable, identified by name.
    QVar Qname
  | -- | Function type: argument type, then result type.
    TArrow (Typ r) (Typ r)
-- | Debug rendering of a frozen ('Identity'-based) type.
--
-- A function type in argument position is wrapped in parentheses, so
-- @(a -> b) -> c@ and @a -> b -> c@ no longer print identically. Arrows in
-- result position are printed without parentheses (right-associative).
instance Show (Typ Identity) where
  show (TVar (Identity tv)) = "TVar (" ++ show tv ++ ")"
  show (QVar name) = "QVar " ++ T.unpack name
  show (TArrow ty1 ty2) = showArg ty1 ++ " -> " ++ show ty2
    where
      -- Only an arrow on the left-hand side is ambiguous without parentheses.
      showArg ty@TArrow {} = "(" ++ show ty ++ ")"
      showArg ty = show ty
-- | Structural equality on types. Type-variable containers are compared with
-- 'eq1'; quantified variables by name; arrows component-wise. Different
-- constructors are never equal.
instance (Eq1 r) => Eq (Typ r) where
  lhs == rhs = case (lhs, rhs) of
    (TVar a, TVar b) -> eq1 a b
    (QVar a, QVar b) -> a == b
    (TArrow a1 b1, TArrow a2 b2) -> a1 == a2 && b1 == b2
    _ -> False
-- | State of a type variable: either still 'Unbound' (carrying its name and
-- level) or a 'Link' to the type it has been bound to by 'unify'.
data Tv r = Unbound Text Level | Link (Typ r) deriving (Eq)
-- | Debug rendering of a frozen type-variable state.
instance Show (Tv Identity) where
  show = \case
    Unbound name level ->
      concat ["Unbound (", T.unpack name, " ", show level, ")"]
    Link ty ->
      concat ["Link (", show ty, ")"]
-- | Mutable 'Int' variable ('PrimVar') usable in the primitive monad @m@.
type IntRef m = PrimVar (PrimState m) Int
-- | Mutable reference indexed by the monad @m@ rather than by its state
-- token: a thin wrapper around @'MutVar' ('PrimState' m)@.
newtype Ref m a = Ref {unRef :: MutVar (PrimState m) a} deriving (Eq)
-- | Heterogeneous comparison of two references. The element comparison
-- function is ignored; the wrapped 'MutVar's are compared directly.
instance Eq1 (Ref m) where
  -- 'liftEq' may be handed references with different element types, so the
  -- right-hand variable is coerced to the left-hand variable's type before
  -- 'MutVar' equality is applied.
  -- NOTE(review): this is only safe if 'MutVar' equality is reference
  -- identity and never inspects the stored value — confirm against the
  -- @primitive@ documentation.
  liftEq _ lhs rhs = unRef lhs == unsafeCoerce (unRef rhs)
-- | Read the current contents of a 'Ref'.
readRef :: (PrimMonad m) => Ref m a -> m a
readRef (Ref var) = readMutVar var
{-# INLINE readRef #-}
-- | Overwrite the contents of a 'Ref' with a new value.
writeRef :: (PrimMonad m) => Ref m a -> a -> m ()
writeRef (Ref var) = writeMutVar var
{-# INLINE writeRef #-}
-- | Rebuild a type with a different container for its type-variable states.
--
-- The second argument extracts the 'Tv' stored in a source container; the
-- first argument wraps a converted 'Tv' in a target container. 'Link'
-- targets are converted recursively before being wrapped.
transformRef ::
  (Monad m) =>
  (Tv r' -> m (r' (Tv r'))) ->
  (r (Tv r) -> m (Tv r)) ->
  Typ r ->
  m (Typ r')
transformRef wrap unwrap = go
  where
    go = \case
      TVar cell -> do
        state <- unwrap cell
        converted <- case state of
          Unbound name level -> pure (Unbound name level)
          Link target -> Link <$> go target
        TVar <$> wrap converted
      QVar name -> pure (QVar name)
      TArrow dom cod -> TArrow <$> go dom <*> go cod
-- | Freeze a mutable type into an 'Identity'-based snapshot by reading every
-- reference it contains. Implemented via 'transformRef'.
toIdentity' :: (PrimMonad m) => Typ (Ref m) -> m (Typ Identity)
toIdentity' = transformRef (\tv -> pure (Identity tv)) readRef
-- | Convert an 'Identity'-based type into a mutable one. A new reference is
-- allocated for each 'TVar' encountered, so two occurrences of a variable in
-- the input end up in separate references.
toRef :: (PrimMonad m) => Typ Identity -> m (Typ (Ref m))
toRef =
  transformRef
    (\tv -> Ref <$> newMutVar tv)
    (\(Identity tv) -> pure tv)
-- | Freeze a mutable type into an 'Identity'-based snapshot, reading each
-- reference and recursively freezing the target of every 'Link'.
toIdentity :: (PrimMonad m) => Typ (Ref m) -> m (Typ Identity)
toIdentity = \case
  TVar ref -> do
    contents <- readRef ref
    frozen <- case contents of
      Unbound name level -> pure (Unbound name level)
      Link target -> Link <$> toIdentity target
    pure (TVar (Identity frozen))
  QVar name -> pure (QVar name)
  TArrow dom cod -> TArrow <$> toIdentity dom <*> toIdentity cod
-- | Produce a fresh type-variable name and advance the @?gensym@ counter.
--
-- Counter values 0 to 25 map to the single letters @a@ to @z@; larger values
-- give @t@ followed by the counter value.
gensym :: (PrimMonad m) => (?gensym :: IntRef m) => m Text
gensym = do
  counter <- readPrimVar ?gensym
  writePrimVar ?gensym (counter + 1)
  pure (nameFor counter)
  where
    nameFor i
      | i < 26 = T.singleton (chr (ord 'a' + i))
      | otherwise = "t" <> T.pack (show i)
-- | Increment the current level stored in @?level@.
enterLevel :: (PrimMonad m) => (?level :: IntRef m) => m ()
enterLevel = do
  depth <- readPrimVar ?level
  writePrimVar ?level (depth + 1)
-- | Decrement the current level stored in @?level@.
leaveLevel :: (PrimMonad m) => (?level :: IntRef m) => m ()
leaveLevel = do
  depth <- readPrimVar ?level
  writePrimVar ?level (depth - 1)
-- | Implicit-parameter context shared by the inference functions: the
-- fresh-name counter (@?gensym@) and the current level (@?level@).
type Constr m = (?gensym :: IntRef m, ?level :: IntRef m)
-- | Allocate a fresh unbound type variable with a newly generated name,
-- tagged with the current level.
newVar :: (PrimMonad m, Constr m) => m (Typ (Ref m))
newVar = do
  name <- gensym
  level <- readPrimVar ?level
  TVar . Ref <$> newMutVar (Unbound name level)
-- | Occurs check combined with level adjustment.
--
-- Walks the given type. Meeting the reference @tvr@ itself raises
-- @error "Occurs check"@. Every other unbound variable met has its level
-- lowered to the minimum of its own level and the level of @tvr@; if @tvr@
-- is not currently 'Unbound', its level is left as it was.
occurs :: (PrimMonad m) => Ref m (Tv (Ref m)) -> Typ (Ref m) -> m ()
occurs tvr = go
  where
    go (TVar cell)
      | tvr == cell = error "Occurs check"
      | otherwise =
          readRef cell >>= \case
            Unbound name cellLevel -> do
              target <- readRef tvr
              let newLevel = case target of
                    Unbound _ ownLevel -> min ownLevel cellLevel
                    _ -> cellLevel
              writeRef cell (Unbound name newLevel)
            Link ty -> go ty
    go (TArrow t1 t2) = go t1 >> go t2
    go _ = pure ()
-- | Unify two types by destructively binding unbound type variables.
--
-- Both sides are first resolved through any chain of 'Link's. This matters
-- when an unbound variable is unified with a variable that is already linked
-- back to it: without resolution the two look different and the occurs check
-- fails spuriously on a well-typed program. After resolution:
--
-- * equal types need no work;
-- * an unbound variable on either side is bound to the other side, after
--   'occurs' has checked for cycles and adjusted levels;
-- * two arrows are unified component-wise;
-- * anything else raises @error "Invalid types for unification"@.
unify :: (PrimMonad m) => Typ (Ref m) -> Typ (Ref m) -> m ()
unify t1 t2 = do
  r1 <- repr t1
  r2 <- repr t2
  unless (r1 == r2) $ case (r1, r2) of
    -- After 'repr', any 'TVar' seen here is unbound.
    (TVar tv, ty) -> bind tv ty
    (ty, TVar tv) -> bind tv ty
    (TArrow tyl1 tyl2, TArrow tyr1 tyr2) ->
      unify tyl1 tyr1 >> unify tyl2 tyr2
    _ -> error "Invalid types for unification"
  where
    -- Follow 'Link's until reaching an unbound variable or a non-variable.
    repr ty@(TVar ref) =
      readRef ref >>= \case
        Link ty' -> repr ty'
        Unbound {} -> pure ty
    repr ty = pure ty
    -- Bind an unbound variable to a type that does not contain it.
    bind tv ty = occurs tv ty >> writeRef tv (Link ty)
-- | Typing environment: an association list from variable names to their
-- types, searched with 'lookup' (so earlier entries shadow later ones).
type Env m = [(Varname, Typ (Ref m))]
-- | Generalize a type: every unbound variable whose level is greater than
-- the current level is replaced by a 'QVar' of the same name. 'Link's are
-- followed, arrows are rebuilt, and everything else is returned unchanged.
gen :: (PrimMonad m, Constr m) => Typ (Ref m) -> m (Typ (Ref m))
gen = \case
  ty@(TVar ref) -> do
    contents <- readRef ref
    case contents of
      Unbound name varLevel -> do
        currLevel <- readPrimVar ?level
        pure $ if varLevel > currLevel then QVar name else ty
      Link target -> gen target
  TArrow dom cod -> TArrow <$> gen dom <*> gen cod
  other -> pure other
-- | Instantiate a type: each distinct 'QVar' name is replaced by one fresh
-- type variable, reused for every occurrence of that name. 'Link's are
-- followed; unbound variables are kept as they are.
inst :: (PrimMonad m, Constr m) => Typ (Ref m) -> m (Typ (Ref m))
inst ty0 = fst <$> walk [] ty0
  where
    -- @seen@ maps the quantified names already met to their fresh variables.
    walk seen = \case
      QVar name
        | Just ty <- lookup name seen -> pure (ty, seen)
        | otherwise -> do
            fresh <- newVar
            pure (fresh, (name, fresh) : seen)
      TVar ref -> do
        contents <- readRef ref
        case contents of
          Link ty -> walk seen ty
          Unbound {} -> pure (TVar ref, seen)
      TArrow dom cod -> do
        (dom', seen') <- walk seen dom
        (cod', seen'') <- walk seen' cod
        pure (TArrow dom' cod', seen'')
-- | Infer the type of an expression in the given environment.
--
-- * 'Var': look the variable up and instantiate its type. An unbound
--   variable raises @error@ with a message that names it.
-- * 'Lam': give the parameter a fresh type variable and infer the body.
-- * 'App': infer function and argument, then unify the function type with
--   @argument -> fresh result@.
-- * 'Let': infer the bound expression one level deeper, generalize the
--   result with 'gen', and infer the body with the extended environment.
typeof :: (PrimMonad m, Constr m) => Env m -> Exp -> m (Typ (Ref m))
typeof env (Var x) = case lookup x env of
  Just ty -> inst ty
  -- Report the offending name instead of an anonymous 'fromJust' failure.
  Nothing -> error $ "Unbound variable: " ++ T.unpack x
typeof env (Lam x e) = do
  tyX <- newVar
  TArrow tyX <$> typeof ((x, tyX) : env) e
typeof env (App fun arg) = do
  tyFun <- typeof env fun
  tyArg <- typeof env arg
  tyRes <- newVar
  tyRes <$ unify tyFun (TArrow tyArg tyRes)
typeof env (Let x e rest) = do
  -- Variables created while typing @e@ get a level above the current one,
  -- which is what 'gen' uses to decide what may be generalized.
  tyE <- enterLevel *> typeof env e <* leaveLevel
  tyE' <- gen tyE
  typeof ((x, tyE') : env) rest
-- | Run an inference action with fresh counters: the name counter starts at
-- 0 and the level counter starts at 1.
runInfer :: (PrimMonad m) => ((Constr m) => m a) -> m a
runInfer action = do
  counter <- newPrimVar 0
  depth <- newPrimVar 1
  let ?gensym = counter
      ?level = depth
   in action
-- | Infer the type of a closed expression (empty environment) inside 'runST'
-- and return it as an immutable snapshot.
testAlg :: Exp -> Typ Identity
testAlg e = runST (toIdentity =<< runInfer (typeof [] e))
-- | Inferred type of the identity function @\x -> x@.
test1 :: Typ Identity
test1 = testAlg (Lam "x" (Var "x"))
-- | Inferred type of the application combinator @\x -> \y -> x y@.
test2 :: Typ Identity
test2 = testAlg (Lam "x" (Lam "y" body))
  where
    body = App (Var "x") (Var "y")
-- | @\y -> y (\z -> y z)@: a term whose inference is expected to fail the
-- occurs check, so forcing this value raises an error.
testOccurs :: Typ Identity
testOccurs = testAlg (Lam "y" (App (Var "y") inner))
  where
    inner = Lam "z" (App (Var "y") (Var "z"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment