Skip to content

Instantly share code, notes, and snippets.

@edsko
Last active September 10, 2018 15:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edsko/53c17347f4bac5828d4be0a099773f77 to your computer and use it in GitHub Desktop.
Save edsko/53c17347f4bac5828d4be0a099773f77 to your computer and use it in GitHub Desktop.
Applicative-only, spaceleak-free version of 'WriterT'
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where
import Data.Functor.Identity
import Data.Monoid
import Data.Traversable
import Control.Arrow (first, second)
import Control.Monad.Writer
{-------------------------------------------------------------------------------
Applicative-only, spaceleak-free version of 'WriterT'
-------------------------------------------------------------------------------}
-- | A computation in @f@ that produces, next to its result, a log of type
-- @w@. The pair stores the result first and the log second.
newtype Collect w f a = Collect
  { runCollect :: f (a, w)
  }

deriving instance Show (f (a, w)) => Show (Collect w f a)
-- | Mapping only touches the result component. The log is passed through
-- unchanged and is not forced (the pair is matched lazily).
instance Functor f => Functor (Collect w f) where
  fmap g (Collect inner) = Collect (fmap (\ ~(x, logged) -> (g x, logged)) inner)
-- | 'pure' logs nothing ('mempty').
--
-- '<*>' runs the two @f@ effects in order and appends their logs left to
-- right. When the combined pair is demanded, both incoming logs and their
-- concatenation are forced to WHNF. As a result, repeated '<*>' does not pile
-- up a chain of unevaluated 'mappend' thunks.
instance (Applicative f, Monoid w) => Applicative (Collect w f) where
  pure x = Collect (pure (x, mempty))
  Collect fs <*> Collect xs = Collect (fmap combine fs <*> xs)
    where
      -- 'w' is the instance's log type, in scope via ScopedTypeVariables.
      -- The 'Monoid w' constraint from the instance context is therefore
      -- usable here.
      combine :: (a -> b, w) -> (a, w) -> (b, w)
      combine (g, !logL) (x, !logR) =
        let !logLR = logL `mappend` logR
        in  (g x, logLR)
-- | Traverse a structure with an action that yields both a replacement
-- element and an extra result.
--
-- The structure is rebuilt from the replacement elements. The extra results
-- are gathered into a list, in traversal order, using the 'Collect'
-- applicative.
traverseCollect :: forall t f a b c. (Traversable t, Applicative f)
                => (a -> f (b, c)) -> t a -> f (t b, [c])
traverseCollect f xs = runCollect (traverse step xs)
  where
    -- Wrap each element's extra result as a singleton log. The lazy pattern
    -- matches the laziness of 'Control.Arrow.second' on pairs.
    step :: a -> Collect [c] f b
    step x = Collect (fmap (\ ~(y, extra) -> (y, [extra])) (f x))
-- | Prepend a log entry to a computation. The entry @w@ is appended before
-- whatever the wrapped computation logs, and its result is left untouched.
collect :: (Applicative f, Monoid w) => w -> Collect w f a -> Collect w f a
collect w = (Collect (pure (id, w)) <*>)
{-------------------------------------------------------------------------------
Usage example
-------------------------------------------------------------------------------}
-- 'traverse' plays the role of 'modifyMVar_' for any 'Traversable':
--
-- >   modifyMVar_ :: MVar a -> (a -> IO a) -> IO ()
-- > flip traverse :: t a -> (a -> f b) -> f (t b)
modify_ :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
modify_ step container = traverse step container
-- The puzzle: can we define an analogue of 'modifyMVar' for traversals?
--
-- >  modifyMVar :: MVar a -> (a -> IO (a, b)) -> IO b
-- >  flip ???   :: t a -> (a -> f (b, c)) -> f (t b, [c])
--
-- Yes, and it is exactly 'traverseCollect'. Each element is replaced by the
-- @b@ the action produces, and the @c@ side results are collected into a
-- list in traversal order. Delegating avoids keeping a verbatim copy of that
-- implementation here.
modify :: forall t f a b c. (Traversable t, Applicative f)
       => (a -> f (b, c)) -> t a -> f (t b, [c])
modify = traverseCollect
{-------------------------------------------------------------------------------
Test
-------------------------------------------------------------------------------}
-- Reported max residency: 31 kB
testCollect :: Collect (Sum Int) Identity ()
testCollect = nTimes iterations (collect (Sum 1)) (pure ())
  where
    iterations :: Int
    iterations = 10000000
-- Reported max residency: 1.3 GB
testWriterT :: WriterT (Sum Int) Identity ()
testWriterT = nTimesM 10000000 logOne ()
  where
    logOne :: () -> WriterT (Sum Int) Identity ()
    logOne () = tell (Sum 1)
-- | Run and print the 'WriterT' variant of the benchmark. 'testCollect' is
-- defined alongside it but is not run here.
main :: IO ()
main = putStrLn (show testWriterT)
{-------------------------------------------------------------------------------
Auxiliary
-------------------------------------------------------------------------------}
-- | @nTimes n f a@ applies @f@ to @a@ @n@ times.
--
-- The accumulated value is forced to WHNF at every step, so no chain of
-- thunks builds up. A non-positive count returns the (forced) initial value
-- unchanged. Previously a negative count never reached the base case and
-- recursed forever.
nTimes :: Int -> (a -> a) -> (a -> a)
nTimes n f !a
  | n <= 0    = a
  | otherwise = nTimes (n - 1) f (f a)
-- | @nTimesM n f a@ runs the Kleisli arrow @f@ @n@ times in sequence,
-- starting from @a@.
--
-- Each intermediate value is forced to WHNF before it is passed on. A
-- non-positive count just returns the (forced) initial value. Previously a
-- negative count never reached the base case and recursed forever.
nTimesM :: Monad m => Int -> (a -> m a) -> (a -> m a)
nTimesM n f !a
  | n <= 0    = return a
  | otherwise = f a >>= nTimesM (n - 1) f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment