@gelisam
Last active November 4, 2023 03:44
A recursion scheme for mutually-recursive types
-- Defining a custom recursion scheme to manipulate two mutually-recursive
-- types, in the context of a toy bidirectional type checker.
{-# LANGUAGE DerivingStrategies, GeneralizedNewtypeDeriving, ScopedTypeVariables #-}
module Main where
import Test.DocTest
import Control.Monad (when)
import Data.Bifunctor (Bifunctor(bimap))
import Data.Bifoldable (Bifoldable(bifoldMap), bitraverse_)
import qualified Data.List as List
data Type
  = Arr Type Type
  | UnitType
  | IntType
  deriving (Eq, Show)
type Context = [(String, Type)]
-- Let's begin by defining our language's terms.
--
--     data Term
--       = Lam String Term
--       | Var String
--       | App Term Term
--       | Ann Term Type
--       | UnitTerm
--       | IntTerm Int
--       | Negate Term
--       | Add Term Term
--       | Sub Term Term
--       | Mul Term Term
--
-- The definition is commented out because there are two changes I want to
-- make. I want to partition it into "normal" and "neutral" forms:
--
--     data Normal
--       = Lam String Normal
--       | Neu Neutral
--
--     data Neutral
--       = Var String
--       | App Neutral Normal
--       | Ann Normal Type
--       | UnitTerm
--       | IntTerm Int
--       | Negate Normal
--       | Add Normal Normal
--       | Sub Normal Normal
--       | Mul Normal Normal
--
-- And I want to turn it into a base functor:
--
--     data TermF term
--       = Lam String term
--       | Var String
--       | App term term
--       | Ann term Type
--       | UnitTerm
--       | IntTerm Int
--       | Negate term
--       | Add term term
--       | Sub term term
--       | Mul term term
-- Let's discuss the partition first. This is a relatively common trick which
-- avoids the two catch-all cases you'd otherwise find at the end of the
-- typechecker's 'check' and 'infer' functions:
--
--     check :: Context -> Term -> Type -> Either String ()
--     check ctx (Lam x body) tp = ...
--     check ...
--     check ctx term expectedType = do
--       actualType <- infer ctx term
--       when (actualType /= expectedType) $ do
--         Left $ "expected " ++ show expectedType
--             ++ ", found " ++ show actualType
--
--     infer :: Context -> Term -> Either String Type
--     infer ctx (Var x) = ...
--     infer ...
--     infer ctx term = do
--       Left $ "ambiguous type, please add a type annotation "
--           ++ "around " ++ show term
--
-- Thanks to the partitioned representation, the 'check' catch-all case is now
-- a regular case (the 'Neu' case), and the 'infer' catch-all case is dropped
-- entirely because the error condition is now unrepresentable. Cool!
-- Next, the base functor representation. A single recursive type can be
-- turned into a base functor by adding a type parameter and replacing all the
-- recursive occurrences of the type by that type parameter. That trick doesn't
-- work here, because the partitioned representation means we now have _two_
-- mutually-recursive types, and thus two different recursive occurrences to
-- replace. We must thus add _two_ type parameters, replacing each recursive
-- occurrence with the appropriate one. The result is two base _bifunctors_!
data NormalF normal neutral
  = Lam String normal
  | Neu neutral
  deriving Show

data NeutralF normal neutral
  = Var String
  | App neutral normal
  | Ann normal Type
  | UnitTerm
  | IntTerm Int
  | Negate normal
  | Add normal normal
  | Sub normal normal
  | Mul normal normal
  deriving Show
instance Bifunctor NormalF where
  bimap f _ (Lam name x)
    = Lam name (f x)
  bimap _ g (Neu y)
    = Neu (g y)

instance Bifunctor NeutralF where
  bimap _ _ (Var name)
    = Var name
  bimap f g (App y x)
    = App (g y) (f x)
  bimap f _ (Ann x tp)
    = Ann (f x) tp
  bimap _ _ UnitTerm
    = UnitTerm
  bimap _ _ (IntTerm n)
    = IntTerm n
  bimap f _ (Negate x)
    = Negate (f x)
  bimap f _ (Add x1 x2)
    = Add (f x1) (f x2)
  bimap f _ (Sub x1 x2)
    = Sub (f x1) (f x2)
  bimap f _ (Mul x1 x2)
    = Mul (f x1) (f x2)
instance Bifoldable NormalF where
  bifoldMap f _ (Lam _ x)
    = f x
  bifoldMap _ g (Neu y)
    = g y

instance Bifoldable NeutralF where
  bifoldMap _ _ (Var _)
    = mempty
  bifoldMap f g (App y x)
    = g y <> f x
  bifoldMap f _ (Ann x _)
    = f x
  bifoldMap _ _ UnitTerm
    = mempty
  bifoldMap _ _ (IntTerm _)
    = mempty
  bifoldMap f _ (Negate x)
    = f x
  bifoldMap f _ (Add x1 x2)
    = f x1 <> f x2
  bifoldMap f _ (Sub x1 x2)
    = f x1 <> f x2
  bifoldMap f _ (Mul x1 x2)
    = f x1 <> f x2
newtype Normal = Normal
  { unNormal :: NormalF Normal Neutral }
  deriving newtype Show

newtype Neutral = Neutral
  { unNeutral :: NeutralF Normal Neutral }
  deriving newtype Show
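
-- Building terms by hand means alternating the newtype wrappers with the
-- bifunctor constructors, which gets noisy quickly. As a sketch, a few smart
-- constructors can hide that noise. The names 'lam', 'neu' and 'var' are
-- hypothetical helpers; nothing else in this file relies on them.
lam :: String -> Normal -> Normal
lam x body = Normal (Lam x body)

-- Wrap one layer of neutral term and embed it into the normal forms.
neu :: NeutralF Normal Neutral -> Normal
neu = Normal . Neu . Neutral

var :: String -> Neutral
var = Neutral . Var

-- With these helpers, the term \s z -> s (s z) could be written as:
--
--     lam "s" $ lam "z" $ neu $ App (var "s") $ neu $ App (var "s") $ neu $ Var "z"
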
-- Next, I want to define a recursion scheme which captures the mutual
-- recursion between 'check' and 'infer'. The usual recursion scheme for
-- capturing mutual recursion is a mutumorphism:
--
--     mutu
--       :: (f (a,b) -> a)  -- first algebra
--       -> (f (a,b) -> b)  -- second algebra
--       -> ( Fix f -> a    -- first mutually-recursive function
--          , Fix f -> b    -- second mutually-recursive function
--          )
--
-- The idea is that each algebra has access to the recursive results for both
-- of the mutually-recursive functions, which is the recursion-schemes
-- equivalent of both functions being able to call each other.
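--
-- For reference, here is a minimal sketch of how a function with that
-- signature could be written. It assumes the usual fixpoint type 'Fix', which
-- this file does not define, and a 'Functor' constraint on 'f'; the helper
-- name 'go' is purely illustrative.
--
--     newtype Fix f = Fix { unFix :: f (Fix f) }
--
--     mutu
--       :: Functor f
--       => (f (a,b) -> a)
--       -> (f (a,b) -> b)
--       -> (Fix f -> a, Fix f -> b)
--     mutu algA algB = (fst . go, snd . go)
--       where
--         -- compute both results at every node, so that each algebra can
--         -- see both of them for every sub-term
--         go = (\x -> (algA x, algB x)) . fmap go . unFix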
--
-- Because of our partitioned representation, however, we need a variant of
-- 'mutu' in which each algebra manipulates a different 'f'. This variant is
-- pretty easy to define:
myMutu
  :: forall a b
   . (NormalF a b -> a)
  -> (NeutralF a b -> b)
  -> ( Normal -> a
     , Neutral -> b
     )
myMutu algA algB
  = (fA, fB)
  where
    fA :: Normal -> a
    fA = algA . bimap fA fB . unNormal

    fB :: Neutral -> b
    fB = algB . bimap fA fB . unNeutral
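
-- Before moving on to the type checker, here is a sketch of a smaller pair of
-- mutually-recursive folds written with 'myMutu': collecting the free
-- variables of a term. The names 'freeVarsNormal' and 'freeVarsNeutral' are
-- hypothetical; nothing else in this file relies on them.
freeVarsNormal :: Normal -> [String]
freeVarsNeutral :: Neutral -> [String]
(freeVarsNormal, freeVarsNeutral)
  = myMutu algNormal algNeutral
  where
    -- A lambda binds its variable, so drop it from the body's free variables.
    algNormal :: NormalF [String] [String] -> [String]
    algNormal (Lam x bodyVars) = filter (/= x) bodyVars
    algNormal (Neu vars)       = vars

    -- A variable is free where it occurs. Every other constructor merges the
    -- free variables of its sub-terms, whichever kind of sub-term they are.
    algNeutral :: NeutralF [String] [String] -> [String]
    algNeutral (Var x) = [x]
    algNeutral other   = List.nub (bifoldMap id id other)

-- For the term \z -> s z, that is,
--
--     Normal $ Lam "z" $ Normal $ Neu $ Neutral
--       $ App (Neutral $ Var "s") (Normal $ Neu $ Neutral $ Var "z")
--
-- 'freeVarsNormal' returns ["s"]: the bound "z" is removed by the 'Lam' case.
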
-- |
-- We are now ready to start implementing the type checker. Using it to check
-- that the expression
--
-- \s z -> s (s z)
--
-- has (among others) the type
--
-- (Int -> Int) -> Int -> Int
--
-- will look like this:
--
-- >>> :{
-- check []
--   ( Normal $ Lam "s"
--   $ Normal $ Lam "z"
--   $ Normal $ Neu
--   $ Neutral $ App (Neutral $ Var "s")
--   $ Normal $ Neu
--   $ Neutral $ App (Neutral $ Var "s")
--   $ Normal $ Neu
--   $ Neutral $ Var "z"
--   )
--   (Arr (Arr IntType IntType) (Arr IntType IntType))
-- :}
-- Right ()
check
  :: Context
  -> Normal
  -> Type
  -> Either String ()
infer
  :: Context
  -> Neutral
  -> Either String Type
(check, infer)
  = -- I want my 'check' and 'infer' functions to have the standard API in
    -- which the context is given before the term, just like in the typing
    -- judgement "Γ ⊢ e : 𝜏". However, 'myMutu' requires the term to be the
    -- first argument, so a bit of argument juggling is needed.
    ( \ctx normal tp -> check' normal ctx tp
    , \ctx neutral -> infer' neutral ctx
    )
  where
    -- The term must come first because 'myMutu' requires the two algebras to
    -- have these types:
    --
    --     NormalF a b -> a
    --     NeutralF a b -> b
    --
    -- Thus all the remaining arguments must be obtained by specializing 'a'
    -- and 'b' to function types:
    --
    --     a ~ (Context -> Type -> Either String ())
    --     b ~ (Context -> Either String Type)
    check'
      :: Normal
      -> Context
      -> Type
      -> Either String ()
    infer'
      :: Neutral
      -> Context
      -> Either String Type
    (check', infer')
      = -- I still want to use the natural parameter order in my
        -- implementation of the two algebras though, so more argument
        -- juggling is needed.
        myMutu
          (\normal ctx tp -> checkF ctx normal tp)
          (\neutral ctx -> inferF ctx neutral)

    -- We are finally ready to implement the two algebras! Depending on
    -- whether a recursive position normally contains a normal or a neutral
    -- term, that position in the base bifunctor either contains the result of
    -- partially applying 'check' or partially applying 'infer' to that
    -- sub-term. Thus, we need to supply the remaining arguments in order to
    -- get the resulting @Either String ()@ or @Either String Type@.
    checkF
      :: Context
      -> NormalF
           (Context -> Type -> Either String ())
           (Context -> Either String Type)
      -> Type
      -> Either String ()
    checkF ctx (Lam x checkBody) tp = do
      case tp of
        Arr tArg tOut -> do
          checkBody ((x,tArg) : ctx) tOut
        _ -> do
          Left $ "expected " ++ show tp ++ ", found lambda"
    checkF ctx (Neu inferNeutral) tp = do
      tp' <- inferNeutral ctx
      when (tp /= tp') $ do
        Left $ "expected " ++ show tp ++ ", found " ++ show tp'

    inferF
      :: Context
      -> NeutralF
           (Context -> Type -> Either String ())
           (Context -> Either String Type)
      -> Either String Type
    inferF ctx (Var name) = do
      case List.lookup name ctx of
        Nothing -> do
          Left $ "variable " ++ show name ++ " not in scope"
        Just tp -> do
          pure tp
    inferF ctx (App inferFun checkArg) = do
      tp <- inferFun ctx
      case tp of
        Arr tArg tOut -> do
          checkArg ctx tArg
          pure tOut
        _ -> do
          Left $ show tp ++ " is not a function"
    inferF ctx (Ann checkTerm tp) = do
      checkTerm ctx tp
      pure tp
    inferF _ UnitTerm = do
      pure UnitType
    inferF ctx numericTerm = do
      -- And now, the moment we are all waiting for: the payoff! All of this
      -- was a lot more verbose than the general-recursion version, so we
      -- had better get something in return.
      --
      -- What we get is the ability to write a single generic case for
      -- 'IntTerm', 'Negate', 'Add', 'Sub', and 'Mul'. We can do that because
      -- we can use the base bifunctor's Bifoldable instance to traverse all
      -- of the sub-terms, without having to know how many sub-terms there are
      -- nor which ones are normal or neutral.
      let onNormal checkSubTerm = do
            checkSubTerm ctx IntType
          onNeutral inferSubTerm = do
            actualType <- inferSubTerm ctx
            when (actualType /= IntType) $ do
              Left $ "numeric operation requires Int argument, "
                  ++ "found " ++ show actualType
      bitraverse_ onNormal onNeutral numericTerm
      pure IntType
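
-- To see the generic numeric case at work, here is a small usage sketch.
-- 'exampleBadAdd' is a hypothetical name introduced only for illustration.
-- It infers the type of an addition whose second operand is the unit value.
exampleBadAdd :: Either String Type
exampleBadAdd
  = infer []
  $ Neutral $ Add
      (Normal $ Neu $ Neutral $ IntTerm 1)
      (Normal $ Neu $ Neutral UnitTerm)

-- Both operands sit in normal positions, so the generic case checks each of
-- them against 'IntType'. The first one succeeds, the second one fails, and
-- 'exampleBadAdd' evaluates to
--
--     Left "expected IntType, found UnitType"
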
-- So, was all of that worth it, just to save on a few easy cases at the end
-- of the type checker? It depends! Some languages have a very large number of
-- terms, so the cost might be worth paying in that case. Some teams value
-- simplicity over succinctness and so will never be ready to pay the cost.
--
-- The real benefit is the knowledge gained along the way: whether
-- you're writing a typechecker or something else, you now have another tool
-- in your toolbox, waiting for you to encounter a challenge it solves well.
main :: IO ()
main = do
  putStrLn "typechecks."

test :: IO ()
test = do
  doctest ["src/Main.hs"]