Skip to content

Instantly share code, notes, and snippets.

@iamahuman
Last active April 4, 2021 13:02
Show Gist options
  • Save iamahuman/65aaa57ef6e9dbfb47a26a898d37a082 to your computer and use it in GitHub Desktop.
Save iamahuman/65aaa57ef6e9dbfb47a26a898d37a082 to your computer and use it in GitHub Desktop.
A pure ST monad transformer with coercible references
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
-----------------------------------------------------------------------------
-- |
-- Module : Data.PureRef
-- Copyright : (c) Jinoh Kang, 2019
-- License : MIT
--
-- Maintainer : unmaintained
-- Stability : experimental
-- Portability : portable
--
-- A pure version of the 'ST' Monad (partial).
--
-- Do not use this in production. Use "Data.STRef" instead.
--
-----------------------------------------------------------------------------
module Data.PureRef (
PRT, PR, PRef,
newPRef, readPRef, writePRef, deletePRef,
modifyPRef, modifyPRef',
runPRT, runPR
) where
-- Fallback definition of the base-version test macro, used only when
-- neither the build tool nor the compiler has already defined it.
-- The base version is inferred from the GHC version and patch level;
-- a non-GHC compiler, or a GHC older than 6.10, is treated as older than
-- every base version checked in this file (the macro expands to 0).
-- CMP_VER_3 with arguments x,y,z,a,b,c holds when version x.y.z is at most
-- version a.b.c.
-- NOTE(review): the table stops at GHC 8.6 (base 4.12). Any newer GHC that
-- reaches this fallback is reported as base 4.12, so checks for base 4.13
-- and later would be false -- confirm this path is never taken there.
#ifndef MIN_VERSION_base
#define CMP_VER_3(x,y,z,a,b,c) \
((x)<(a)||((x)==(a)&&((y)<(b)||((y)==(b)&&((z)<=(c))))))
#define CMP_VER_GHC(x,y,z) CMP_VER_3(x,y,z,__GLASGOW_HASKELL__,__GLASGOW_HASKELL_PATCHLEVEL1__,__GLASGOW_HASKELL_PATCHLEVEL2__)
#ifndef __GLASGOW_HASKELL__
#define MIN_VERSION_base(x,y,z) 0
#elif CMP_VER_GHC(806,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,12,0)
#elif CMP_VER_GHC(804,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,11,1)
#elif CMP_VER_GHC(804,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,11,0)
#elif CMP_VER_GHC(802,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,10,1)
#elif CMP_VER_GHC(802,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,10,0)
#elif CMP_VER_GHC(800,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,9,1)
#elif CMP_VER_GHC(800,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,9,0)
#elif CMP_VER_GHC(710,3,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,8,2)
#elif CMP_VER_GHC(710,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,8,1)
#elif CMP_VER_GHC(710,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,8,0)
#elif CMP_VER_GHC(708,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,7,0)
#elif CMP_VER_GHC(706,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,6,0)
#elif CMP_VER_GHC(704,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,5,1)
#elif CMP_VER_GHC(704,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,5,0)
#elif CMP_VER_GHC(702,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,4,1)
#elif CMP_VER_GHC(702,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,4,0)
#elif CMP_VER_GHC(700,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,3,1)
#elif CMP_VER_GHC(700,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,3,0)
#elif CMP_VER_GHC(612,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,2,0)
#elif CMP_VER_GHC(610,2,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,1,0)
#elif CMP_VER_GHC(610,1,0)
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,0,0)
#else
#define MIN_VERSION_base(x,y,z) 0
#endif
#endif
import Control.Applicative (Applicative, Alternative, pure,
(<*>), (<*), (*>), liftA2,
empty, (<|>), some, many)
import Control.Monad (Monad, MonadPlus, return, (>>=),
fail, mzero, mplus)
import Control.Monad.Fix (MonadFix, mfix)
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail (MonadFail, fail)
#endif
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans (MonadTrans, lift)
#if MIN_VERSION_base(4,7,0)
import Data.Coerce (Coercible, coerce)
#endif
import Data.Dynamic (Dynamic, fromDynamic, toDyn)
import Data.Functor (Functor)
import Data.Functor.Identity (Identity, runIdentity)
#if MIN_VERSION_base(4,12,0)
import Data.Functor.Contravariant (Contravariant, contramap)
#endif
import Data.Typeable (Typeable)
import qualified Data.IntMap.Lazy as M
-- | Name of this module, used as the prefix of the error messages
-- produced by this module.
modname :: String
modname = "Data.PureRef"
{-# INLINE modname #-}
-- | The thread state: the value of every reference created so far
-- (deleted ones included), keyed by the reference key.
type St = M.IntMap Dynamic

-- | A pure strict state-transformer monad transformer.
-- A computation of type @'PRT' s m a@ transforms an internal state indexed
-- by @s@, runs effects in @m@, and returns a value of type @a@.
-- The @s@ parameter serves to keep the internal states
-- of different invocations of 'runPRT' and 'runPR' separate from each other.
--
-- The '>>=' and '>>' operations are strict in the state (though not in
-- values stored in the state). For example,
--
-- @'runPR' ('writePRef' _|_ v >>= f) = _|_@
newtype PRT s m a = PRT
  { unPRT :: St -> m (a, St)
  }

-- | A pure strict state-transformer monad: 'PRT' over 'Identity'.
type PR s = PRT s Identity
-- | A key into the thread state, tagged with two phantom types: @a@ is the
-- type seen through the reference and @b@ is the type of the value kept
-- inside the stored 'Dynamic'.
newtype CK a b = CK M.Key

-- | A value of type @PRef s a@ is a mutable variable in state thread @s@,
-- containing a value of type @a@. Internally it is only a 'CK' key into
-- the thread state; the stored representation type is hidden.
#if MIN_VERSION_base(4,7,0)
data PRef s a = forall b. (Coercible a b, Typeable b) =>
  PRef {-# UNPACK #-} !(CK a b)
#else
data PRef s a = Typeable a => PRef {-# UNPACK #-} !(CK a a)
#endif
-- | References compare by key only; keys are never reused within a
-- thread, so a reference is equal only to copies of itself.
instance Eq (PRef s a) where
  PRef (CK i) == PRef (CK j) = i == j
  {-# INLINE (==) #-}
  p /= q = not (p == q)
  {-# INLINE (/=) #-}
#if MIN_VERSION_base(4,7,0)
-- | Convert a user-facing value into its stored representation. The 'CK'
-- argument only fixes the two types and is never inspected.
cab :: Coercible a b => CK a b -> a -> b
cab _ x = coerce x
{-# INLINE cab #-}

-- | Inverse of 'cab': recover the user-facing value from its stored
-- representation.
cba :: Coercible a b => CK a b -> b -> a
cba _ y = coerce y
{-# INLINE cba #-}
#endif

-- | Either an error message ('Left') or a result ('Right').
type MaybeS = Either String
-- | Read-only access to the entry for a key. The callback turns the current
-- entry (if any) into either an error message or a result; the state is
-- passed through unchanged.
--
-- An error message aborts the computation: via 'fail' on base < 4.13,
-- and via 'errorWithoutStackTrace' afterwards, because there 'fail'
-- requires MonadFail, which the plain Monad constraint does not provide.
ro :: Monad m => CK a b
   -> (Maybe Dynamic -> MaybeS a)
   -> PRT s m a
ro (CK k) f = PRT g
  where
    g s = case f (M.lookup k s) of
      Left e ->
#if MIN_VERSION_base(4,13,0)
        errorWithoutStackTrace e
#else
        fail e
#endif
      Right v -> return (v, s)
-- | Read-write access to the entry for a key. The callback receives the
-- current entry (if any) and returns either an error message or the new
-- entry. The updated map is forced to WHNF before it becomes the state.
--
-- An error message aborts the computation: via 'fail' on base < 4.13,
-- and via 'errorWithoutStackTrace' afterwards, because there 'fail'
-- requires MonadFail, which the plain Monad constraint does not provide.
rw :: Monad m => CK a b
   -> (Maybe Dynamic -> MaybeS (Maybe Dynamic))
   -> PRT s m ()
rw (CK k) f = PRT g
  where
    g s = case M.alterF f k s of
      Left e ->
#if MIN_VERSION_base(4,13,0)
        errorWithoutStackTrace e
#else
        fail e
#endif
      Right v -> v `seq` return ((), v)
-- | An error result whose message is prefixed with the qualified name of
-- the reporting function, i.e. @Data.PureRef.<n>: <s>@.
pfail :: String -> String -> MaybeS a
pfail n s = Left (concat [modname, ".", n, ": ", s])
-- | Run an access @mode@ (such as 'ro' or 'rw') on the entry behind @ck@,
-- handing the stored 'Dynamic' to @fn@. A missing entry is reported as a
-- lost reference, tagged with @name@.
modify :: String
       -> (CK x y -> (Maybe Dynamic -> MaybeS b) -> a)
       -> (Dynamic -> MaybeS b)
       -> (CK x y -> a)
modify name mode fn ck@(CK k) = mode ck (maybe lost fn)
  where
    lost = pfail name $ "Lost reference to key " ++ show k
-- | Like 'modify', but first decodes the stored 'Dynamic' at the
-- representation type of @ck@ and converts it to the user-facing type
-- before calling @fn@. A value of any other type is reported as an
-- unexpected type, tagged with @name@.
#if MIN_VERSION_base(4,7,0)
oper :: (Typeable y, Coercible x y) => String
     -> (CK x y -> (Maybe Dynamic -> MaybeS b) -> a)
     -> (x -> b)
     -> (CK x y -> a)
#else
oper :: Typeable x => String
     -> (CK x x -> (Maybe Dynamic -> MaybeS b) -> a)
     -> (x -> b)
     -> (CK x x -> a)
#endif
oper name mode fn ck@(CK k) = modify name mode decode ck
  where
    decode dyn = maybe (mismatch dyn) (Right . fn . fromRep) (fromDynamic dyn)
    mismatch dyn = pfail name $
      "Unexpected type for " ++ shows k (':' : ' ' : show dyn)
#if MIN_VERSION_base(4,7,0)
    fromRep = cba ck
#else
    fromRep = id
#endif
-- | Wrap a user-facing value as a new entry: convert it to the stored
-- representation of @ck@ and box it in a 'Dynamic'.
#if MIN_VERSION_base(4,7,0)
dg :: (Coercible a b, Typeable b) =>
      CK a b -> a -> Maybe Dynamic
dg ck = Just . toDyn . cab ck
#else
dg :: Typeable a => CK a a -> a -> Maybe Dynamic
dg _ = Just . toDyn
#endif
{-# INLINE dg #-}
-- | Replacement entry for a deleted reference. The old value is ignored,
-- so deleting twice never inspects it. The new entry is an error thunk
-- that raises the message below as soon as the Dynamic is inspected.
--
-- errorWithoutStackTrace only exists from base 4.9; older bases fall
-- back to plain error.
stub :: a -> MaybeS (Maybe Dynamic)
stub _ = Right (Just deleted)
  where
    deleted =
#if MIN_VERSION_base(4,9,0)
      errorWithoutStackTrace
#else
      error
#endif
        (modname ++ ": Attempted to use deleted pure reference")
{-# INLINE stub #-}
-- | /O(log n)/. Build a new 'PRef' in the current state thread.
--
-- If the @'Typeable' a@ constraint bothers you,
-- there is always "Data.STRef".
newPRef :: (Monad m, Typeable a) => a -> PRT s m (PRef s a)
newPRef v = PRT alloc
  where
    alloc s
      | key < 0 = error (modname ++ ": out of references")
      | otherwise = ref `seq` s' `seq` return (ref, s')
      where
        -- Keys are handed out in increasing order and entries are never
        -- removed, so the next key is one past the current maximum; a
        -- negative result means the key counter overflowed.
        key | M.null s = 0
            | otherwise = fst (M.findMax s) + 1
        -- The annotation pins the hidden representation type to @a@.
        ref = PRef (CK key :: Typeable a => CK a a)
        s' = M.insert key (toDyn v) s
-- | /O(log n)/. Read the value of a 'PRef'; the state is left untouched.
readPRef :: Monad m => PRef s a -> PRT s m a
readPRef (PRef c) = oper "readPRef" ro id c
-- | /O(log n)/. Write a new value into a 'PRef'. The old entry is still
-- decoded first, so writing to a deleted reference fails.
writePRef :: Monad m => PRef s a -> a -> PRT s m ()
writePRef (PRef c) v = oper "writePRef" rw (const (dg c v)) c
-- | /O(log n)/. Mutate the contents of a 'PRef'. The new value is stored
-- unevaluated; see 'modifyPRef'' for the strict variant.
modifyPRef :: Monad m => PRef s a -> (a -> a) -> PRT s m ()
modifyPRef (PRef c) f = oper "modifyPRef" rw (\ old -> dg c (f old)) c
-- | /O(log n)/. Strict version of 'modifyPRef': the new value is forced
-- to weak head normal form before it is stored.
modifyPRef' :: Monad m => PRef s a -> (a -> a) -> PRT s m ()
modifyPRef' (PRef c) f = oper "modifyPRef'" rw (\ old -> dg c $! f old) c
-- | /O(log n)/. Revoke a 'PRef' so that future uses of it
-- are no longer valid.
--
-- Deleting an already deleted reference does nothing,
-- meaning @('>>' 'deletePRef' a)@ is idempotent.
--
-- Different 'PRef's are never equal to each other regardless of
-- whether any of them is deleted or not.
--
-- All references, even deleted ones, stay (and thus leak)
-- forever in the PRT monad. If this bothers you,
-- there is always "Data.STRef".
deletePRef :: Monad m => PRef s a -> PRT s m ()
deletePRef ref = case ref of
  PRef c -> modify "deletePRef" rw stub c
-- | Return the value computed by a state transformer computation.
-- The @forall@ ensures that the internal state used by the 'PRT'
-- computation is inaccessible to the rest of the program.
-- The computation starts from an empty state and the final state is
-- discarded.
runPRT :: Monad m => (forall s. PRT s m a) -> m a
runPRT (PRT t) = t M.empty >>= \ ~(x, _) -> return x
{-# INLINE runPRT #-}
-- | Return the value computed by a state transformer computation.
-- The @forall@ ensures that the internal state used by the 'PR'
-- computation is inaccessible to the rest of the program.
-- The computation starts from an empty state and the final state is
-- discarded.
runPR :: (forall s. PR s a) -> a
runPR (PRT t) = result
  where
    (result, _) = runIdentity (t M.empty)
{-# INLINE runPR #-}
-- | Maps over the result while passing the state through. The pair is
-- matched lazily, so mapping does not by itself force the
-- @(result, state)@ pair.
instance Functor f => Functor (PRT s f) where
  fmap h (PRT run) = PRT (fmap (\ ~(x, st) -> (h x, st)) . run)
  {-# INLINE fmap #-}
-- The Applicative and Monad instances force the (result, state) pair that
-- each step produces before running the next step (strict patterns below).
-- This is what makes '>>=' strict in the state, as documented on 'PRT'.
-- Result values themselves are never forced.
instance (Functor m, Monad m) => Applicative (PRT s m) where
  pure a = PRT (\ s -> return (a, s))
  {-# INLINE pure #-}
  PRT mf <*> PRT mx = PRT g
    where
      g s = mf s >>= \ (f, s') ->
            mx s' >>= \ (x, s'') ->
            return (f x, s'')
  {-# INLINE (<*>) #-}
  ma *> mb = ma >>= \ _ -> mb
  {-# INLINE (*>) #-}
  PRT ma <* PRT mb = PRT g
    where
      g s = ma s >>= \ (a, s') ->
            mb s' >>= \ (_, s'') ->
            return (a, s'')
  {-# INLINE (<*) #-}
#if MIN_VERSION_base(4,10,0)
  -- liftA2 is a method of Applicative only from base 4.10 onwards.
  liftA2 f (PRT mx) (PRT my) = PRT g
    where
      g s = mx s >>= \ (x, s') ->
            my s' >>= \ (y, s'') ->
            return (f x y, s'')
  {-# INLINE liftA2 #-}
#endif

instance Monad m => Monad (PRT s m) where
#if !MIN_VERSION_base(4,8,0)
  -- The pair must be lifted into m: the field has type St -> m (a, St).
  return a = PRT (\ s -> return (a, s))
  {-# INLINE return #-}
#endif
  PRT m >>= f = PRT g
    where
      g s = m s >>= \ (a, s') -> unPRT (f a) s'
  {-# INLINE (>>=) #-}
#if !MIN_VERSION_base(4,13,0)
  fail s = PRT (\ _ -> fail s)
  {-# INLINE fail #-}
#endif
-- | Delegates to the underlying 'Alternative'; both branches start from
-- the same state.
instance (Monad m, Alternative m) => Alternative (PRT s m) where
  empty = PRT (const empty)
  {-# INLINE empty #-}
  PRT l <|> PRT r = PRT (\ st -> l st <|> r st)
  {-# INLINE (<|>) #-}
#if MIN_VERSION_base(4,9,0)
-- | Failure is delegated to the underlying monad; the state is dropped.
instance Fail.MonadFail m => Fail.MonadFail (PRT s m) where
  fail msg = PRT (const (Fail.fail msg))
  {-# INLINE fail #-}
#endif
-- | Delegates to the underlying 'MonadPlus'; both alternatives start from
-- the same state.
instance MonadPlus m => MonadPlus (PRT s m) where
  mzero = PRT (const mzero)
  {-# INLINE mzero #-}
  mplus (PRT l) (PRT r) = PRT (\ st -> mplus (l st) (r st))
  {-# INLINE mplus #-}
-- | Ties the knot through the underlying monad. The irrefutable pattern is
-- required so that the fixed-point value is not demanded before it exists.
instance MonadFix m => MonadFix (PRT s m) where
  mfix k = PRT go
    where
      go st = mfix (\ ~(x, _) -> unPRT (k x) st)
  {-# INLINE mfix #-}
-- | Runs an action of the underlying monad, leaving the state unchanged.
instance MonadTrans (PRT s) where
  lift act = PRT $ \ st -> do
    x <- act
    return (x, st)
  {-# INLINE lift #-}
-- | Lifts an IO action through the underlying monad.
instance MonadIO m => MonadIO (PRT s m) where
  liftIO io = lift (liftIO io)
  {-# INLINE liftIO #-}
#if MIN_VERSION_base(4,12,0)
-- | Pre-composes @f@ on the result through the underlying 'Contravariant'
-- @m@; the threaded state in the pair is passed along unchanged.
instance Contravariant m => Contravariant (PRT s m) where
  contramap f m = PRT (\ s ->
    contramap (\ (a, s') -> (f a, s')) (unPRT m s))
  {-# INLINE contramap #-}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment