Last active
April 4, 2021 13:02
-
-
Save iamahuman/65aaa57ef6e9dbfb47a26a898d37a082 to your computer and use it in GitHub Desktop.
A pure ST monad transformer with coercible references
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_HADDOCK show-extensions #-} | |
{-# LANGUAGE CPP #-} | |
{-# LANGUAGE Rank2Types #-} | |
{-# LANGUAGE ExistentialQuantification #-} | |
----------------------------------------------------------------------------- | |
-- | | |
-- Module : Data.PureRef | |
-- Copyright : (c) Jinoh Kang, 2019 | |
-- License : MIT | |
-- | |
-- Maintainer : unmaintained | |
-- Stability : experimental | |
-- Portability : portable | |
-- | |
-- A pure version of the 'ST' Monad (partial). | |
-- | |
-- Do not use this in production. Use 'Data.STRef' instead. | |
-- | |
----------------------------------------------------------------------------- | |
module Data.PureRef ( | |
PRT, PR, PRef, | |
newPRef, readPRef, writePRef, deletePRef, | |
modifyPRef, modifyPRef', | |
runPRT, runPR | |
) where | |
-- Fallback definition of MIN_VERSION_base, used only when the macro is
-- not already defined by the build environment.
--
-- CMP_VER_3(x,y,z,a,b,c) is true when the triple (x,y,z) is
-- lexicographically less than or equal to (a,b,c).
-- CMP_VER_GHC(x,y,z) applies that test against the running compiler
-- version taken from the __GLASGOW_HASKELL__ macros.
--
-- The #elif chain below is ordered from the newest compiler to the
-- oldest; the first matching branch fixes the base version that
-- MIN_VERSION_base compares against. Without GHC, or with a GHC older
-- than every listed version, MIN_VERSION_base always evaluates to 0.
-- NOTE(review): the GHC -> base pairs are presumably the base versions
-- bundled with each GHC release; verify before extending the table.
#ifndef MIN_VERSION_base | |
#define CMP_VER_3(x,y,z,a,b,c) \ | |
((x)<(a)||((x)==(a)&&((y)<(b)||((y)==(b)&&((z)<=(c)))))) | |
#define CMP_VER_GHC(x,y,z) CMP_VER_3(x,y,z,__GLASGOW_HASKELL__,__GLASGOW_HASKELL_PATCHLEVEL1__,__GLASGOW_HASKELL_PATCHLEVEL2__) | |
#ifndef __GLASGOW_HASKELL__ | |
#define MIN_VERSION_base(x,y,z) 0 | |
#elif CMP_VER_GHC(806,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,12,0) | |
#elif CMP_VER_GHC(804,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,11,1) | |
#elif CMP_VER_GHC(804,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,11,0) | |
#elif CMP_VER_GHC(802,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,10,1) | |
#elif CMP_VER_GHC(802,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,10,0) | |
#elif CMP_VER_GHC(800,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,9,1) | |
#elif CMP_VER_GHC(800,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,9,0) | |
#elif CMP_VER_GHC(710,3,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,8,2) | |
#elif CMP_VER_GHC(710,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,8,1) | |
#elif CMP_VER_GHC(710,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,8,0) | |
#elif CMP_VER_GHC(708,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,7,0) | |
#elif CMP_VER_GHC(706,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,6,0) | |
#elif CMP_VER_GHC(704,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,5,1) | |
#elif CMP_VER_GHC(704,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,5,0) | |
#elif CMP_VER_GHC(702,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,4,1) | |
#elif CMP_VER_GHC(702,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,4,0) | |
#elif CMP_VER_GHC(700,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,3,1) | |
#elif CMP_VER_GHC(700,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,3,0) | |
#elif CMP_VER_GHC(612,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,2,0) | |
#elif CMP_VER_GHC(610,2,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,1,0) | |
#elif CMP_VER_GHC(610,1,0) | |
#define MIN_VERSION_base(x,y,z) CMP_VER_3(x,y,z,4,0,0) | |
#else | |
#define MIN_VERSION_base(x,y,z) 0 | |
#endif | |
#endif | |
import Control.Applicative (Applicative, Alternative, pure, | |
(<*>), (<*), (*>), liftA2, | |
empty, (<|>), some, many) | |
import Control.Monad (Monad, MonadPlus, return, (>>=), | |
fail, mzero, mplus) | |
import Control.Monad.Fix (MonadFix, mfix) | |
#if MIN_VERSION_base(4,9,0) | |
import qualified Control.Monad.Fail as Fail (MonadFail, fail) | |
#endif | |
import Control.Monad.IO.Class (MonadIO, liftIO) | |
import Control.Monad.Trans (MonadTrans, lift) | |
#if MIN_VERSION_base(4,7,0) | |
import Data.Coerce (Coercible, coerce) | |
#endif | |
import Data.Dynamic (Dynamic, fromDynamic, toDyn) | |
import Data.Functor (Functor) | |
import Data.Functor.Identity (Identity, runIdentity) | |
#if MIN_VERSION_base(4,12,0) | |
import Data.Functor.Contravariant (Contravariant, contramap) | |
#endif | |
import Data.Typeable (Typeable) | |
import qualified Data.IntMap.Lazy as M | |
-- | Name of this module; it is prepended to the error messages built
-- in this module.
modname :: String | |
modname = "Data.PureRef" | |
{-# INLINE modname #-} | |
-- | Internal store threaded through a 'PRT' computation: each reference
-- key maps to its current contents, boxed as a 'Dynamic'.
type St = M.IntMap Dynamic | |
-- | A pure strict state-transformer monad transformer.
-- A computation of type @'PRT' s m a@ transforms an internal state
-- indexed by @s@ inside the underlying monad @m@, and returns a value
-- of type @a@.
-- The @s@ parameter serves to keep the internal states
-- of different invocations of 'runPRT' / 'runPR' separate from each other.
--
-- The '>>=' and '>>' operations are strict in the state (though not in
-- values stored in the state). For example,
--
-- @'runPR' ('writePRef' _|_ v >>= f) = _|_@
--
-- Representation: a function from the current store to an @m@-action
-- yielding the result paired with the updated store. The @s@ parameter
-- does not occur on the right-hand side (it is a phantom).
newtype PRT s m a = PRT { unPRT :: St -> m (a, St) } | |
-- | A pure strict state-transformer monad: 'PRT' specialised to the
-- 'Identity' base monad. Run it with 'runPR'.
type PR s = PRT s Identity | |
-- | A store key tagged with two phantom types: @a@, the type seen by
-- the user of the reference, and @b@, the type at which the value is
-- boxed into the store (see 'dg' and 'oper').
newtype CK a b = CK M.Key | |
-- | A value of type @PRef s a@ is a mutable variable
-- in state thread @s@, containing a value of type @a@.
--
-- The constructor wraps the key ('CK') under which the value lives in
-- the internal store. With base >= 4.7 the stored representation type
-- @b@ is existentially hidden and only has to be 'Coercible' with @a@
-- and 'Typeable'; otherwise the value is stored at type @a@ itself.
#if MIN_VERSION_base(4,7,0) | |
data PRef s a = forall b. (Coercible a b, Typeable b) => | |
PRef {-# UNPACK #-}!(CK a b) | |
#else | |
data PRef s a = Typeable a => PRef {-# UNPACK #-}!(CK a a) | |
#endif | |
-- | References are compared by their store keys only.
instance Eq (PRef s a) where
    lhs == rhs = case (lhs, rhs) of
        (PRef (CK i), PRef (CK j)) -> i == j
    {-# INLINE (==) #-}
    lhs /= rhs = case (lhs, rhs) of
        (PRef (CK i), PRef (CK j)) -> i /= j
    {-# INLINE (/=) #-}
#if MIN_VERSION_base(4,7,0)
-- | Coerce a user-facing value to its stored representation. The key
-- argument is ignored; it only pins the two types involved.
cab :: Coercible a b => CK a b -> a -> b
cab = const coerce
{-# INLINE cab #-}

-- | Coerce a stored representation back to the user-facing type. The
-- key argument is ignored; it only pins the two types involved.
cba :: Coercible a b => CK a b -> b -> a
cba = const coerce
{-# INLINE cba #-}
#endif
-- | Result of an internal store operation: 'Left' carries an error
-- message (see 'pfail'), 'Right' carries the successful result.
type MaybeS = Either String | |
-- | Read-only access to one slot of the store. The callback receives
-- the result of looking the key up; a 'Left' is turned into 'fail' in
-- the base monad, a 'Right' is returned with the store left untouched.
--
-- NOTE(review): 'fail' is used under a plain 'Monad' constraint; on
-- base >= 4.13 'fail' belongs to 'MonadFail' -- confirm this still
-- type-checks with the compilers you target.
ro :: Monad m => CK a b
   -> (Maybe Dynamic -> MaybeS a)
   -> PRT s m a
ro (CK key) inspect = PRT $ \ store ->
    either fail (\ val -> return (val, store)) (inspect (M.lookup key store))
-- | Read-write access to one slot of the store. The callback is run
-- through 'M.alterF' on the key; a 'Left' is turned into 'fail' in the
-- base monad, a 'Right' yields the updated store, which is forced to
-- weak head normal form before being returned.
--
-- NOTE(review): 'fail' is used under a plain 'Monad' constraint; on
-- base >= 4.13 'fail' belongs to 'MonadFail' -- confirm this still
-- type-checks with the compilers you target.
rw :: Monad m => CK a b
   -> (Maybe Dynamic -> MaybeS (Maybe Dynamic))
   -> PRT s m ()
rw (CK key) update = PRT $ \ store ->
    either fail commit (M.alterF update key store)
  where
    commit store' = store' `seq` return ((), store')
-- | Build a failure of the form @\<module\>.\<function\>: \<message\>@.
pfail :: String -> String -> MaybeS a
pfail n s = Left (concat [modname, ".", n, ": ", s])
-- | Run an accessor ('ro' or 'rw') on a key whose slot must exist.
-- When the slot is present its boxed contents are handed to the given
-- function; when it is absent the operation fails with a
-- \"Lost reference\" message naming the calling function and the key.
modify :: String
       -> (CK x y -> (Maybe Dynamic -> MaybeS b) -> a)
       -> (Dynamic -> MaybeS b)
       -> (CK x y -> a)
modify name mode fn ck@(CK k) = mode ck (maybe missing fn)
  where
    missing = pfail name ("Lost reference to key " ++ show k)
-- | Like 'modify', but unboxes the slot first: the 'Dynamic' is
-- converted back to the user-facing type and the given pure function
-- is applied to it. A slot holding a value of another type makes the
-- operation fail with an \"Unexpected type\" message.
#if MIN_VERSION_base(4,7,0)
oper :: (Typeable y, Coercible x y) => String
     -> (CK x y -> (Maybe Dynamic -> MaybeS b) -> a)
     -> (x -> b)
     -> (CK x y -> a)
#else
oper :: Typeable x => String
     -> (CK x x -> (Maybe Dynamic -> MaybeS b) -> a)
     -> (x -> b)
     -> (CK x x -> a)
#endif
oper name mode fn ck@(CK k) = modify name mode unbox ck
  where
    unbox d = maybe (mismatch d) (\ bv -> Right (fn (view bv))) (fromDynamic d)
    mismatch d =
        pfail name ("Unexpected type for " ++ show k ++ ": " ++ show d)
    -- Convert the stored representation to the user-facing type.
#if MIN_VERSION_base(4,7,0)
    view bv = cba ck bv
#else
    view bv = bv
#endif
-- | Box a user-facing value for storage: convert it to the stored
-- representation (where applicable), wrap it in a 'Dynamic' and in
-- 'Just' so it can be used directly as the result of a slot update.
#if MIN_VERSION_base(4,7,0)
dg :: (Coercible a b, Typeable b) =>
      CK a b -> a -> Maybe Dynamic
dg ck x = (Just . toDyn) (cab ck x)
#else
dg :: Typeable a => CK a a -> a -> Maybe Dynamic
dg _ x = (Just . toDyn) x
#endif
{-# INLINE dg #-}
-- | Slot update used by 'deletePRef': whatever the slot held, replace
-- it with a placeholder that raises an error as soon as it is
-- evaluated. The slot itself stays in the store.
stub :: a -> MaybeS (Maybe Dynamic)
stub _ = Right (Just tombstone)
  where
    tombstone = errorWithoutStackTrace
        (modname ++ ": Attempted to use deleted pure reference")
{-# INLINE stub #-}
-- | /O(log n)/. Allocate a fresh 'PRef' in the current state thread,
-- initialised with the given value.
--
-- The new key is one past the largest key currently in the store (or
-- 0 for an empty store); a negative key (the counter wrapped around)
-- aborts with an \"out of references\" error.
--
-- If the @'Typeable' a@ constraint bothers you,
-- there is always 'Data.STRef'.
newPRef :: (Monad m, Typeable a) => a -> PRT s m (PRef s a)
newPRef v = PRT alloc
  where
    alloc store
        | key < 0   = error $ modname ++ ": out of references"
        | otherwise = ref `seq` store' `seq` return (ref, store')
      where
        key
            | M.null store = 0
            | otherwise    = 1 + fst (M.findMax store)
        ref    = PRef (CK key :: Typeable a => CK a a)
        store' = M.insert key (toDyn v) store
-- | /O(log n)/. Fetch the current contents of a 'PRef'.
readPRef :: Monad m => PRef s a -> PRT s m a
readPRef ref = case ref of
    PRef c -> oper "readPRef" ro id c
-- | /O(log n)/. Replace the contents of a 'PRef' with a new value.
-- The previous contents are still unboxed (and type-checked) first,
-- but otherwise ignored.
writePRef :: Monad m => PRef s a -> a -> PRT s m ()
writePRef ref v = case ref of
    PRef c -> oper "writePRef" rw (const (dg c v)) c
-- | /O(log n)/. Apply a function to the contents of a 'PRef'.
-- The function application itself is not forced; see 'modifyPRef''
-- for the strict variant.
modifyPRef :: Monad m => PRef s a -> (a -> a) -> PRT s m ()
modifyPRef ref f = case ref of
    PRef c -> oper "modifyPRef" rw (\ old -> dg c (f old)) c
-- | /O(log n)/. Strict version of 'modifyPRef': the new value is
-- evaluated to weak head normal form before it is boxed.
modifyPRef' :: Monad m => PRef s a -> (a -> a) -> PRT s m ()
modifyPRef' ref f = case ref of
    PRef c -> oper "modifyPRef'" rw (\ old -> dg c $! f old) c
-- | /O(log n)/. Revoke a 'PRef' so that future references to it
-- are no longer valid.
--
-- Deleting an already deleted reference does nothing,
-- meaning @('>>' 'deletePRef' a)@ is idempotent.
--
-- Different 'PRef's are never equal to each other regardless of
-- whether any of them is deleted or not.
--
-- The slot of a deleted reference is overwritten with a placeholder
-- rather than removed, so all references, even deleted ones, stay
-- (and thus leak) forever in the PRT monad. If this bothers you,
-- there is always 'Data.STRef'.
deletePRef :: Monad m => PRef s a -> PRT s m ()
deletePRef ref = case ref of
    PRef c -> modify "deletePRef" rw stub c
-- | Run a state transformer computation on an empty store and return
-- its result in the base monad, discarding the final store.
-- The @forall@ ensures that the internal state used by the 'PRT'
-- computation is inaccessible to the rest of the program.
runPRT :: Monad m => (forall s. PRT s m a) -> m a
runPRT act = unPRT act M.empty >>= \ outcome -> return (fst outcome)
{-# INLINE runPRT #-}
-- | Run a pure state transformer computation on an empty store and
-- return its result, discarding the final store.
-- The @forall@ ensures that the internal state used by the 'PR'
-- computation is inaccessible to the rest of the program.
runPR :: (forall s. PR s a) -> a
runPR act = fst . runIdentity $ unPRT act M.empty
{-# INLINE runPR #-}
-- | Map over the result of a computation; the store is passed through
-- unchanged. The result pair is matched lazily.
instance Functor f => Functor (PRT s f) where
    fmap f (PRT m) = PRT (fmap onResult . m)
      where
        onResult ~(a, s') = (f a, s')
    {-# INLINE fmap #-}
-- | Applicative instance: the store is threaded left to right, i.e.
-- the store produced by the left operand is fed to the right operand.
-- All result pairs are matched lazily.
instance (Functor m, Monad m) => Applicative (PRT s m) where
    -- Return a value without touching the store.
    pure a = PRT (\ s -> return (a, s))
    {-# INLINE pure #-}
    PRT mf <*> PRT mx = PRT g where
        g s = mf s >>= \ ~(f, s') ->
            fmap (\ ~(x, s'') -> (f x, s'')) (mx s')
    {-# INLINE (<*>) #-}
    ma *> mb = ma >>= \ _ -> mb
    {-# INLINE (*>) #-}
    PRT ma <* PRT mb = PRT g where
        g s = ma s >>= \ ~(a, s') ->
            fmap (\ ~(_, s'') -> (a, s'')) (mb s')
    {-# INLINE (<*) #-}
#if MIN_VERSION_base(4,10,0)
    -- 'liftA2' only became a method of 'Applicative' in base 4.10;
    -- defining it in the instance on older versions is a compile
    -- error, so there the default (derived from '<*>') is used.
    liftA2 f (PRT mx) (PRT my) = PRT g where
        g s = mx s >>= \ ~(x, s') ->
            fmap (\ ~(y, s'') -> (f x y, s'')) (my s')
    {-# INLINE liftA2 #-}
#endif
-- | Monad instance: '>>=' runs the left computation, then feeds its
-- result and the updated store to the continuation. The intermediate
-- result pair is matched lazily.
instance Monad m => Monad (PRT s m) where
#if !MIN_VERSION_base(4,8,0)
    -- The pair has to be injected into the base monad; the wrapped
    -- function has type @St -> m (a, St)@.
    return a = PRT (\ s -> return (a, s))
    {-# INLINE return #-}
#endif
    PRT m >>= f = PRT g where
        g s = m s >>= \ ~(a, s') -> unPRT (f a) s'
    {-# INLINE (>>=) #-}
#if !MIN_VERSION_base(4,13,0)
    -- Delegate failure to the base monad, dropping the store.
    fail s = PRT (\ _ -> fail s)
    {-# INLINE fail #-}
#endif
-- | Alternative instance lifted from the base monad: both branches of
-- '<|>' are started from the same incoming store.
instance (Monad m, Alternative m) => Alternative (PRT s m) where
    empty = PRT (const empty)
    {-# INLINE empty #-}
    lhs <|> rhs = PRT $ \ s -> unPRT lhs s <|> unPRT rhs s
    {-# INLINE (<|>) #-}
#if MIN_VERSION_base(4,9,0)
-- | Failure is delegated to the base monad; the store is dropped.
instance Fail.MonadFail m => Fail.MonadFail (PRT s m) where
    fail msg = PRT (const (Fail.fail msg))
    {-# INLINE fail #-}
#endif
-- | MonadPlus instance lifted from the base monad: both branches of
-- 'mplus' are started from the same incoming store.
instance MonadPlus m => MonadPlus (PRT s m) where
    mzero = PRT (const mzero)
    {-# INLINE mzero #-}
    lhs `mplus` rhs = PRT $ \ s -> unPRT lhs s `mplus` unPRT rhs s
    {-# INLINE mplus #-}
-- | Fixed point via the base monad: only the result component of the
-- pair is fed back (lazily); every iteration starts from the same
-- incoming store.
instance MonadFix m => MonadFix (PRT s m) where
    mfix f = PRT knot
      where
        knot s = mfix (\ pair -> unPRT (f (fst pair)) s)
    {-# INLINE mfix #-}
-- | Lift a base-monad action: run it and pair its result with the
-- unchanged store.
instance MonadTrans (PRT s) where
    lift m = PRT $ \ s -> do
        a <- m
        return (a, s)
    {-# INLINE lift #-}
-- | Lift an 'IO' action through the base monad, then into 'PRT'.
instance MonadIO m => MonadIO (PRT s m) where
    liftIO io = lift (liftIO io)
    {-# INLINE liftIO #-}
#if MIN_VERSION_base(4,12,0)
-- | Contravariant instance lifted from the base functor: the mapping
-- function is applied to the result component of the pair, and the
-- store component is passed through unchanged.
instance Contravariant m => Contravariant (PRT s m) where
    contramap f m = PRT (\ s ->
        -- The second component must be the store bound by the
        -- pattern (s'), not an unbound name.
        contramap (\ (a, s') -> (f a, s')) (unPRT m s))
    {-# INLINE contramap #-}
#endif
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment