Skip to content

Instantly share code, notes, and snippets.

@Tarmean
Last active January 11, 2023 16:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Tarmean/c8c986f6c1723be10b7454b53288e989 to your computer and use it in GitHub Desktop.
Save Tarmean/c8c986f6c1723be10b7454b53288e989 to your computer and use it in GitHub Desktop.
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use const" #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
-- | Helpers for AST traversals
module OpenRec where
import Data.Data hiding (gmapM)
import Control.Monad.Writer.Strict (Writer, MonadWriter (tell), execWriter, WriterT, execWriterT)
import Control.Monad ((<=<))
import Debug.Trace (trace, traceM)
import Control.Monad.Trans (lift)
import qualified Data.HashSet as S
import HitTest (Answer(..), hitTest, Oracle(..), typeRepOf)
import Data.Functor.Identity (runIdentity, Identity)
import GHC.Base (oneShot)
import Util (prettyS)
import Prettyprinter (Pretty)
-- | A monadic rewrite that can be applied at every type with a 'Data' instance.
-- 'Data' gives us a way to get the type of a value, and to traverse its children.
type Trans1 m = forall x. Data x => x -> m x
-- | VTable for our traversal: the continuations handed to the @withCtx@
-- function of a 'Trans'.
data Ctx m = Ctx {
-- | Continuation to call when the current step matched
onSuccess :: Trans1 m,
-- | Continuation to call when the current step did not match
onFailure :: Trans1 m,
-- | Top-level vtable for recursion: re-enters the whole traversal
-- ('runT' passes its own visiting function here)
onRecurse :: Trans1 m
}
-- | How a transformation behaves with respect to recursion.
-- The derived 'Ord' instance follows constructor order
-- (@NoRec < BasicRec < ComplexRec@), which the 'Semigroup' instance relies on.
data RecType
  = NoRec
  | BasicRec
  | ComplexRec
  deriving (Eq, Ord, Show)

-- | Combining two values keeps the larger one in constructor order.
instance Semigroup RecType where
  lhs <> rhs = max lhs rhs
-- | A traversal bundles the set of relevant types to visit
-- with the visitor function to apply to values of those types.
data Trans m = T {
-- | Types for which @withCtx@ may succeed
relevant :: !(S.HashSet TypeRep),
-- | What should happen when the current type is not in @relevant@, but contains @relevant@ types?
-- 'runT' leaves the value alone for 'NoRec', maps over its immediate children for 'BasicRec',
-- and hands the value to @withCtx@ for 'ComplexRec'.
toplevelRecursion :: RecType,
-- | Actual transformation function
withCtx :: Ctx m -> Trans1 m
}
-- | 'runT' for effect-free transformations: runs in 'Identity' and unwraps
-- the result.
runT' :: Data a => Trans Identity -> a -> a
runT' trans = runIdentity . runT trans
-- | The core run function.
--
-- An oracle is built from the root value and the transformation's relevant
-- types. Every visited value is classified by it:
--
-- * 'Hit': the transformation is applied to the value.
-- * 'Miss': the value is returned unchanged.
-- * 'Follow': the behaviour depends on 'toplevelRecursion': 'BasicRec' maps
--   over the immediate children, 'ComplexRec' applies the transformation,
--   and 'NoRec' returns the value unchanged.
runT :: forall m a. (Monad m, Data a) => Trans m -> a -> m a
runT trans a0 = visit a0
  where
    Oracle oracle = hitTest a0 (relevant trans)

    -- Entry point for every value; also what 'onRecurse' points to.
    visit :: forall x. Data x => x -> m x
    visit x = case oracle x of
      Hit _ -> apply x
      Miss -> pure x
      Follow -> case toplevelRecursion trans of
        NoRec -> pure x
        BasicRec -> gmapM visit x
        ComplexRec -> apply x

    -- The transformation itself; success and failure both just return the value.
    apply :: forall x. Data x => x -> m x
    apply = withCtx trans (Ctx pure pure visit)
-- | A transformation with no relevant types that immediately takes the
-- failure continuation.
failed :: Trans m
failed = T mempty NoRec (\ctx x -> onFailure ctx x)
-- | Debugging aid: runs the given action, emits @tag@ followed by the
-- pretty-printed result via 'traceM', then continues with the success
-- continuation on the unchanged value.
loggingM :: (Pretty a, Monad m) => String -> m a -> Trans m
loggingM tag logs = T mempty NoRec $ \ctx a ->
  logs >>= \msg ->
    traceM (tag <> prettyS msg) >> onSuccess ctx a
-- [Note: Oracle]
-- Types form a tree. We have a root type @a@, which is the type of the value we are transforming.
--
-- - Each type @t@ has a set of sub-types @reachable(t)@, which are accessible from its Data.Data instance.
-- - Each transformation has a set of relevant types
--
-- So our logic is:
-- - If a type @t@ is relevant, apply the transformation (duh)
-- - If we could reach a relevant type from @t@, recurse into @t@
-- - Otherwise, do nothing
--
-- There is a subtlety: If the transformation won't recurse, we shouldn't either!
-- So we also track if the transformation would recurse, and only recurse if it would.
-- | 'runT' specialised to 'Writer': runs the traversal and returns only the
-- accumulated output.
runQ :: forall a o. (Monoid o, Data a) => Trans (Writer o) -> a -> o
runQ t = execWriter . runT t
-- | 'runT' specialised to 'WriterT': runs the traversal in the underlying
-- monad and returns only the accumulated output.
runQT :: forall a o m. (Monad m, Monoid o, Data a) => Trans (WriterT o m) -> a -> m o
runQT t = execWriterT . runT t
-- | Like @gmapM@ from "Data.Data", but needing only an 'Applicative'
-- constraint: applies @f@ to each immediate child and rebuilds the value.
gmapM :: forall m a. (Data a, Applicative m) => (forall d. Data d => d -> m d) -> a -> m a
gmapM f = gfoldl (\rebuilt child -> rebuilt <*> f child) pure
-- | Alternative composition of transformations.
-- In @a ||| b@, @b@ runs only when @a@ takes its failure continuation.
-- The relevant types are the union of both sides; the recursion type is the
-- larger of the two.
infixl 1 |||
(|||) :: forall m. Monad m => Trans m -> Trans m -> Trans m
l ||| r =
  T
    { relevant = S.union (relevant l) (relevant r)
    , toplevelRecursion = toplevelRecursion l <> toplevelRecursion r
    , withCtx = \ctx -> withCtx l (ctx { onFailure = withCtx r ctx })
    }
-- | Sequential composition of transformations.
-- In @a >>> b@, @b@ runs only when @a@ takes its success continuation.
-- The relevant types are the union of both sides; the recursion type is
-- taken from the left operand alone.
infixl 1 >>>
(>>>) :: forall m. Monad m => Trans m -> Trans m -> Trans m
l >>> r =
  T
    { relevant = S.union (relevant l) (relevant r)
    , toplevelRecursion = toplevelRecursion l
    , withCtx = \ctx -> withCtx l (ctx { onSuccess = withCtx r ctx })
    }
-- | Definite composition of transformations.
-- In @a &&& b@, we always run @a@ and then @b@, whether @a@ succeeded or
-- failed. The relevant types are the union of both sides; the recursion
-- type is the larger of the two.
infixl 1 &&&
(&&&) :: forall m. Monad m => Trans m -> Trans m -> Trans m
l &&& r =
  T
    { relevant = S.union (relevant l) (relevant r)
    , toplevelRecursion = toplevelRecursion l <> toplevelRecursion r
    , withCtx = \ctx ->
        withCtx l (ctx { onSuccess = withCtx r ctx, onFailure = withCtx r ctx })
    }
-- | Core recursion operator: sends every immediate child of the current
-- value back through the top-level traversal, then takes the success
-- continuation on the rebuilt value.
--
-- Usually, we either want top down recursion
--
-- @
-- tryTrans_ @Expr \case
--   Minus x y | x == y -> Just (Const 0)
--   _ -> Nothing
-- ||| recurse
-- @
--
-- Or bottom up recursion:
--
-- @
-- recurse >>>
--   tryTrans_ @Expr \case
--     Minus x y | x == y -> Just (Const 0)
--     _ -> Nothing
-- @
recurse :: Monad m => Trans m
recurse = T mempty BasicRec $ \ctx x ->
  gmapM (onRecurse ctx) x >>= onSuccess ctx
-- | Monadic query at type @a@. The callback receives a function that
-- collects the output of the top-level traversal for any sub-value.
-- When the callback returns @Just@, its action runs in the underlying monad,
-- the result is appended to the output, and the traversal succeeds.
-- A @Nothing@ result or a value of another type fails.
tryQueryM :: forall a o m. (Monad m, Monoid o, Data a) => ((forall x. Data x => x -> m o) -> a -> Maybe (m o)) -> Trans (WriterT o m)
tryQueryM f = T (onlyRel @a) NoRec $ \Ctx{..} (x :: x) ->
  case eqT @a @x of
    Nothing -> onFailure x
    Just Refl ->
      case f (execWriterT . onRecurse) x of
        Just act -> do
          out <- lift act
          tell out
          onSuccess x
        Nothing -> onFailure x
-- | The singleton set holding just the 'TypeRep' of @a@, where @a@ is
-- supplied by type application.
onlyRel :: forall a. Typeable a => S.HashSet TypeRep
onlyRel = S.singleton rep
  where
    rep = typeRepOf @a
{-# INLINE onlyRel #-}
-- | Pure query at type @a@. The callback receives a function that collects
-- the output of the top-level traversal for any sub-value. A @Just@ result
-- is appended to the output and the traversal succeeds; a @Nothing@ result
-- or a value of another type fails.
tryQuery :: forall a o. (Monoid o, Data a) => ((forall x. Data x => x -> o) -> a -> Maybe o) -> Trans (Writer o)
tryQuery f = T (onlyRel @a) NoRec $ \Ctx{..} (x :: x) ->
  case eqT @a @x of
    Nothing -> onFailure x
    Just Refl ->
      case f (execWriter . onRecurse) x of
        Just out -> tell out >> onSuccess x
        Nothing -> onFailure x
-- | Query at type @a@ that does not need access to the recursion function.
-- A @Just@ result is appended to the output and the traversal succeeds;
-- a @Nothing@ result or a value of another type fails.
tryQuery_ :: forall a o m. (Monad m, Monoid o, Data a) => (a -> Maybe o) -> Trans (WriterT o m)
tryQuery_ f = T (onlyRel @a) NoRec $ \Ctx{..} (x :: x) ->
  case eqT @a @x of
    Nothing -> onFailure x
    Just Refl ->
      case f x of
        Just out -> tell out *> onSuccess x
        Nothing -> onFailure x
-- | Pure rewrite at type @a@. A @Just@ result is passed to the success
-- continuation; a @Nothing@ result or a value of another type passes the
-- original value to the failure continuation.
tryTrans :: forall a m. (Monad m, Data a) => (a -> Maybe a) -> Trans m
tryTrans f = T (onlyRel @a) NoRec $ \Ctx{..} (x :: x) ->
  case eqT @a @x of
    Nothing -> onFailure x
    Just Refl ->
      case f x of
        Just rewritten -> onSuccess rewritten
        Nothing -> onFailure x
-- | Monadic rewrite at type @a@. The callback receives the top-level
-- recursion function. When it returns @Just@, the action is run and its
-- result passed to the success continuation; a @Nothing@ result or a value
-- of another type passes the original value to the failure continuation.
tryTransM :: forall a m. (Monad m, Data a) => (Trans1 m -> a -> Maybe (m a)) -> Trans m
tryTransM f = T (onlyRel @a) NoRec $ \Ctx{..} (x :: x) ->
  case eqT @a @x of
    Nothing -> onFailure x
    Just Refl ->
      case f onRecurse x of
        Just act -> act >>= onSuccess
        Nothing -> onFailure x
-- | 'tryTransM' for callbacks that do not need the recursion function.
tryTransM_ :: forall a m. (Monad m, Data a) => (a -> Maybe (m a)) -> Trans m
-- NOTE(review): an explicit lambda is used rather than 'const'; presumably
-- 'const' would not typecheck against the higher-rank first argument
-- (the file also silences HLint's "Use const" hint).
tryTransM_ f = tryTransM (\_recur x -> f x)
-- | Same behaviour as 'tryTrans', but only requires 'Applicative':
-- a @Just@ result goes to the success continuation, while a @Nothing@ result
-- or a value of another type sends the original value to the failure
-- continuation.
tryTrans_ :: forall a m. (Applicative m, Data a) => (a -> Maybe a) -> Trans m
tryTrans_ f = T (onlyRel @a) NoRec $ \Ctx{..} (x :: x) ->
  case eqT @a @x of
    Nothing -> onFailure x
    Just Refl -> maybe (onFailure x) onSuccess (f x)
-- | Run a transformation repeatedly for as long as it keeps succeeding.
--
-- Each success feeds the result back into the same transformation. Once the
-- transformation fails, the enclosing success continuation is taken if at
-- least one round succeeded, and the enclosing failure continuation
-- otherwise. The recursion continuation is passed through unchanged.
--
-- No debug output is emitted on each successful round; library code should
-- not write to stderr as a side effect of a rewrite.
completelyTrans' :: forall m. (Monad m) => Trans m -> Trans m
completelyTrans' f = T (relevant f) (toplevelRecursion f) $ \Ctx{..} a0 ->
  let
    -- @suc@ records whether some earlier round has already succeeded.
    fixCtx suc =
      Ctx
        { onSuccess = fixpoint True
        , onFailure = if suc then onSuccess else onFailure
        , onRecurse = onRecurse
        }
    fixpoint :: Data a => Bool -> a -> m a
    fixpoint suc = withCtx f (fixCtx suc)
  in fixpoint False a0
-- | Apply a pure rewrite at type @a@ until it no longer applies.
-- Succeeds with the final value if the rewrite applied at least once, and
-- fails if it did not apply to the initial value.
completelyTrans :: forall a m. (Monad m, Data a) => (a -> Maybe a) -> Trans m
completelyTrans f = tryTrans firstStep
  where
    -- The first application decides between success and failure.
    firstStep a = f a >>= exhaust
    -- After one success, keep rewriting until @f@ gives up.
    exhaust a = case f a of
      Just a' -> exhaust a'
      Nothing -> Just a
-- | Stop recursion here: @t@ is run as its own traversal via 'runT', so
-- nested 'recurse' statements jump back to the block rather than to the
-- enclosing traversal. The result is passed to the enclosing success
-- continuation.
block :: forall m. Monad m => Trans m -> Trans m
block t = T (relevant t) ComplexRec $ \ctx x -> do
  x' <- runT t x
  onSuccess ctx x'
-- | Compose a one-argument function after a two-argument function.
(.:) :: (b -> c) -> (a1 -> a2 -> b) -> a1 -> a2 -> c
(f .: g) x y = f (g x y)
{-# LANGUAGE MagicHash, UnboxedTuples #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
-- | In a Data.Data traversal, we want to visit a `target` set of types.
-- For every type reachable from a `source` value, cache whether it contains a `target`.
-- Adapted from the lens package to support multiple target types.
module HitTest where
import Data.Data
import qualified Data.Proxy as X (Proxy (..))
import qualified Data.Typeable as X (typeRep)
import qualified Data.HashSet as S
import Data.HashSet (HashSet)
import Data.HashMap.Strict (HashMap, (!))
import qualified Data.HashMap.Strict as M
import GHC.IO (IO(..), unsafePerformIO)
import Data.IORef
import qualified Control.Exception as E
import Data.Maybe (fromMaybe)
import GHC.Base (realWorld#)
-------------------------------------------------------------------------------
-- Data Box
-------------------------------------------------------------------------------
-- | A value of any 'Data' type packed together with a 'TypeRep' key.
-- The 'dataBox' function fills the key with the type of the boxed value.
data DataBox = forall a. Data a => DataBox
{ dataBoxKey :: TypeRep
, _dataBoxVal :: a
}
-- | The 'TypeRep' of a type supplied by type application,
-- e.g. @typeRepOf \@Int@.
typeRepOf :: forall a. Typeable a => TypeRep
typeRepOf = X.typeRep (X.Proxy :: X.Proxy a)
{-# INLINE typeRepOf #-}
-- | Box a value together with the 'TypeRep' of its static type.
{-# INLINE dataBox #-}
dataBox :: forall a. Data a => a -> DataBox
dataBox = DataBox key
  where
    key = typeRepOf @a
-- partial, caught elsewhere
-- | Boxed immediate children of every constructor of the argument's type.
-- For algebraic data types each constructor is instantiated with
-- 'fromConstr' and its fields are collected with 'gmapQ'; any other type
-- yields no children.
--
-- We could grab the children types from the Typeable instance, or GHC generics.
-- I think lens uses a Data.Data "traversal" so that sub-types which
-- hand-written Data.Data instances elide are not counted.
-- This may be overkill, and slower in practice? Most handwritten instances
-- only avoid sub-types which are irrelevant anyway.
sybChildren :: Data a => a -> [DataBox]
sybChildren x
  | isAlgType dt = concatMap fieldsOf (dataTypeConstrs dt)
  | otherwise = []
  where
    dt = dataTypeOf x
    -- `asTypeOf` pins the freshly built value to the argument's type.
    fieldsOf c = gmapQ dataBox (fromConstr c `asTypeOf` x)
{-# INLINE sybChildren #-}
-------------------------------------------------------------------------------
-- HitMap
-------------------------------------------------------------------------------
-- | For each known type, the set of types recorded as reachable from it.
-- 'insertHitMap' seeds an entry with the direct children reported by
-- 'sybChildren' and then extends it until nothing changes.
type HitMap = HashMap TypeRep (HashSet TypeRep)
-- | Initial hit map with two hand-written entries: 'Rational' reaches
-- 'Integer', and 'Integer' reaches nothing.
emptyHitMap :: HitMap
emptyHitMap =
  M.insert rationalRep (S.singleton integerRep) (M.singleton integerRep S.empty)
  where
    integerRep = X.typeRep (X.Proxy :: X.Proxy Integer)
    rationalRep = X.typeRep (X.Proxy :: X.Proxy Rational)
-- | Extend a hit map with everything reachable from the boxed value's type.
--
-- First the not-yet-known types are collected together with their direct
-- children. Each entry is then repeatedly widened with the entries of its
-- members until nothing changes any more, and the result is merged with the
-- existing map.
insertHitMap :: DataBox -> HitMap -> HitMap
insertHitMap box hit = fixEq close (populate box) `mappend` hit
  where
    populate :: DataBox -> HitMap
    populate root = visit root M.empty

    -- Record the direct children of one type, unless it is already known
    -- either from the existing map or from this walk.
    visit :: DataBox -> HitMap -> HitMap
    visit (DataBox key val) seen
      | M.member key hit || M.member key seen = seen
      | otherwise =
          let kids = sybChildren val
          in visitAll kids (M.insert key (S.fromList (map dataBoxKey kids)) seen)

    visitAll :: [DataBox] -> HitMap -> HitMap
    visitAll [] seen = seen
    visitAll (b : bs) seen = visitAll bs (visit b seen)

    -- One widening step: add to every entry the entries of its members,
    -- looking first in the new map and falling back to the existing one.
    close :: HitMap -> HitMap
    close m = M.map widen m
      where
        widen tys = tys `mappend` foldMap reach tys
        reach ty = fromMaybe (hit ! ty) (M.lookup ty m)
-- | Iterate a function until two consecutive values compare equal, and
-- return that value.
fixEq :: Eq a => (a -> a) -> a -> a
fixEq step = loop
  where
    loop cur =
      let next = step cur
      in if cur == next then next else loop next
{-# INLINE fixEq #-}
-- | inlineable 'unsafePerformIO'
--
-- Runs the action directly on the 'realWorld#' token and returns its result,
-- discarding the resulting state token.
-- NOTE(review): this presumably has the usual caveats of unsafe IO (the
-- action may be duplicated or reordered by the compiler) -- confirm that
-- every caller tolerates that.
inlinePerformIO :: IO a -> a
inlinePerformIO (IO m) = case m realWorld# of
(# _, r #) -> r
{-# INLINE inlinePerformIO #-}
-- | Memo state: the 'HitMap' built so far, plus previously computed
-- followers, keyed first by the 'TypeRep' of the boxed source value and then
-- by the set of target types (see 'readCacheFollower').
data Cache = Cache HitMap (HashMap TypeRep (HashMap (HashSet TypeRep) (Maybe Follower)))
-- | Global mutable cell holding the 'Cache'. It starts out with
-- 'emptyHitMap' and no stored followers.
-- NOTE(review): created with 'unsafePerformIO'; the NOINLINE pragma is
-- presumably there so that only a single 'IORef' is ever created.
cache :: IORef Cache
cache = unsafePerformIO $ newIORef $ Cache emptyHitMap M.empty
{-# NOINLINE cache #-}
-- | Look up, or compute and store, the follower for a boxed source value
-- (first argument) and a set of target types (second argument).
--
-- On a cache hit the stored entry is returned. Otherwise the hit map is
-- extended with 'insertHitMap', a follower is built from it with 'follower',
-- and both are written back to the global 'cache'.
-- If building the hit map throws, the exception is re-raised via 'error'
-- with its 'show' text.
readCacheFollower :: DataBox -> S.HashSet TypeRep -> Maybe Follower
readCacheFollower b@(DataBox kb _) ka = inlinePerformIO $
readIORef cache >>= \ (Cache hm m) -> case M.lookup kb m >>= M.lookup ka of
Just a -> return a
-- Force the extended hit map inside 'E.try' so that an exception raised
-- while building it is caught here.
Nothing -> E.try (return $! insertHitMap b hm) >>= \case
Left err@E.SomeException{} -> error (show err) -- atomicModifyIORef cache $ \(Cache hm' n) -> (Cache hm' (insert2 kb ka Nothing n), Nothing)
-- Store the new hit map and the follower together, and return the follower.
Right hm' | fol <- Just (follower kb ka hm') -> atomicModifyIORef cache $ \(Cache _ n) -> (Cache hm' (insert2 kb ka fol n), fol)
-- | Insert a value into a two-level map under outer key @x@ and inner key
-- @y@. An existing inner map gets the new binding added (replacing any old
-- binding for @y@); otherwise a fresh one-element inner map is created.
insert2 :: TypeRep -> HashSet TypeRep -> a -> HashMap TypeRep (HashMap (HashSet TypeRep) a) -> HashMap TypeRep (HashMap (HashSet TypeRep) a)
insert2 x y v =
  M.insertWith (\_fresh existing -> M.insert y v existing) x (M.singleton y v)
{-# INLINE insert2 #-}
-------------------------------------------------------------------------------
-- Answers
-------------------------------------------------------------------------------
-- | Verdict of an 'Oracle' for one visited value.
-- The type parameter @b@ is not used by any constructor.
data Answer b
-- The value's type is one of the target types; carries that type.
= Hit TypeRep
-- The value's type is not a target, but it was not ruled out either.
| Follow
-- The follower ruled this type out.
| Miss
-------------------------------------------------------------------------------
-- Oracles
-------------------------------------------------------------------------------
-- | A classifier that can be asked about a value of any 'Typeable' type and
-- replies with an 'Answer'. Built by 'hitTest'.
newtype Oracle = Oracle { fromOracle :: forall t. Typeable t => t -> Answer t }
-- | Build an 'Oracle' for a traversal that starts at @a@ and targets the
-- set of types @b@.
--
-- For a queried value the oracle answers:
--
-- * 'Hit' with the value's type, when that type is a member of @b@;
-- * 'Miss', when a follower is available and rejects the type;
-- * 'Follow' otherwise.
--
-- The follower depends only on @a@ and @b@, so it is looked up once per
-- oracle and shared between queries, rather than being looked up again for
-- every visited value. It is still computed lazily, i.e. only once a queried
-- type turns out not to be a target.
hitTest :: forall a. (Data a) => a -> HashSet TypeRep -> Oracle
hitTest a b = Oracle answer
  where
    -- Shared between all queries made through this oracle.
    mFollower = readCacheFollower (dataBox a) b

    answer :: forall t. Typeable t => t -> Answer t
    answer c
      | ty `S.member` b = Hit ty
      | otherwise = case mFollower of
          Just p | not (p ty) -> Miss
          _ -> Follow
      where
        ty = typeOf c
-------------------------------------------------------------------------------
-- Traversals
-------------------------------------------------------------------------------
-- biplateData :: forall f s a. (Applicative f, Data s) => (forall c. Typeable c => c -> Answer c a) -> (a -> f a) -> s -> f s
-- biplateData o f = go2 where
-- go :: Data d => d -> f d
-- go = gfoldl (\x y -> x <*> go2 y) pure
-- go2 :: Data d => d -> f d
-- go2 s = case o s of
-- Hit a -> f a
-- Follow -> go s
-- Miss -> pure s
-- {-# INLINE biplateData #-}
-- uniplateData :: forall f s a. (Applicative f, Data s) => (forall c. Typeable c => c -> Answer c a) -> (a -> f a) -> s -> f s
-- uniplateData o f = go where
-- go :: Data d => d -> f d
-- go = gfoldl (\x y -> x <*> go2 y) pure
-- go2 :: Data d => d -> f d
-- go2 s = case o s of
-- Hit a -> f a
-- Follow -> go s
-- Miss -> pure s
-- {-# INLINE uniplateData #-}
-------------------------------------------------------------------------------
-- Follower
-------------------------------------------------------------------------------
-- | Split a set into the elements that satisfy the predicate (first
-- component) and those that do not (second component).
part :: (a -> Bool) -> HashSet a -> (HashSet a, HashSet a)
part p s = (accepted, rejected)
  where
    accepted = S.filter p s
    rejected = S.filter (not . p) s
{-# INLINE part #-}
-- | Predicate on types produced by 'follower'. 'hitTest' answers 'Miss' for
-- a non-target type on which the predicate is 'False'.
type Follower = TypeRep -> Bool
-- | Build the 'Follower' for source type @a@ and target set @b@ from a hit
-- map @m@.
--
-- The candidates are @a@ itself plus everything @m@ records as reachable
-- from @a@. A candidate is "reaching" when its own entry in @m@ shares at
-- least one type with @b@, and "barren" otherwise.
--
-- * No reaching candidates: nothing is followed.
-- * No barren candidates: everything is followed.
-- * Otherwise: everything except the barren candidates is followed.
--
-- A shortcut that would test membership in the reaching set when it is the
-- smaller one (@S.size hit < S.size miss = (\`S.member\` hit)@) is
-- deliberately skipped, so unknown types are always visited.
follower :: TypeRep -> HashSet TypeRep -> HitMap -> Follower
follower a b m
  | S.null reaching = const False
  | S.null barren = const True
  | otherwise = not . (`S.member` barren)
  where
    candidates = S.insert a (m ! a)
    reachesTarget ty = not $ S.null (S.intersection b (m ! ty))
    (reaching, barren) = part reachesTarget candidates
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment