@evertedsphere
Last active December 1, 2021 16:50
monadic mapAccumL adventures
-- a tiny tutorial.
{-# LANGUAGE TypeApplications, ScopedTypeVariables #-}
module Main where
-- read until the end for something that may be a surprise
-- (if you're newer to haskell)
import Control.Monad.State
import Control.Category ((>>>))
-- a coworker of mine asked me for help with this type
--
-- traverseWithMap :: (HashMap a b -> c -> m (d, HashMap a b)) -> [c] -> m [d]
--
-- specifically, how would one implement a function of this type (if it was
-- possible at all)?
--
-- i did also receive an example implementation (a perfectly good one, i think)
-- using a simple fold. here's how i solved the problem before i looked at that
-- solution.
--
-- (excuse the didactic style and the occasional bit of advice; the person i
-- wrote this for is someone i taught haskell earlier this year, and have
-- been mentoring for a while)
--
--------------------------------------------------------------------------------
-- let's first see what parts of this complicated type signature are actually
-- important. generalising type signatures is not always a good idea, but it
-- can help make implementation easier by removing unnecessary details.
--
-- first off, HashMap a b doesn't change throughout the type, so we can replace
-- it with a type variable of its own:
--
-- (h -> c -> m (d, h)) -> [c] -> m [d]
--
-- so this is now traverseWithH :) next, while we're at it, let's take the
-- "traverse" seriously:
--
-- Traversable t => (h -> c -> m (d, h)) -> t c -> m (t d)
--
-- and generalise to any traversable functor, not just lists.
--
-- at this point, we may notice that we'd forgotten to add any constraints on
-- m, and improve this to
--
-- (Monad m, Traversable t) => (h -> c -> m (d, h)) -> t c -> m (t d)
--
-- and if you've seen a certain standard library function...
--
-- mapAccumL :: Traversable t => (s -> a -> (s, b)) -> s -> t a -> (s, t b)
--
-- this looks like a monadic version of mapAccumL!
--
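-- (well, with the final `s` dropped from the result. note also that mapAccumL
-- takes an initial `s`: with a HashMap one could presumably start from an empty
-- map, but a bare type variable h has no such value to start from, so the
-- generalised function needs its starting state as an argument too.)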
--
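-- knowing this, one can implement a function of this new, more polymorphic,
-- type cleanly in terms of a fold (at least for lists; a fold alone can't
-- rebuild an arbitrary traversable `t`), analogous to how the classic list
-- version of the library function is defined: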
--
-- <exercise>
--
-- that's a good place to stop. indeed, while this gist does go on for quite a bit
-- longer, i believe the above really provides the bulk of the
-- widely-applicable lessons one can learn here:
-- (1) simplify types and/or make them more polymorphic when possible
-- (2) try to see which essential properties of its types your code actually
-- depends on: e.g. here we don't need the fact that t = [], just that it's
-- something traversable
-- (3) almost more important than (1) and (2) is that one must do them *within
-- reason*. use your judgment; maximally generic code is often incomprehensible
-- and occasionally ugly.
-- (4) read around in the standard libraries, use hoogle to find similar
-- standard functions when possible: the implementations of a lot of library
-- functions illustrate useful patterns of working with basic primitives like
-- fold(l|r) and traverse
--
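-- spoiler for the exercise above: a sketch of one possible list-only answer
-- using foldr (not necessarily the fold-based implementation mentioned at the
-- top). the name mapAccumMList is made up for this example. the fold builds a
-- function that is waiting for the current state, so the state is threaded
-- from left to right; the step function has the h -> c -> m (d, h) shape, plus
-- an explicit starting state.
mapAccumMList :: Monad m => (s -> a -> m (b, s)) -> s -> [a] -> m [b]
mapAccumMList k s0 xs = foldr step (\_ -> pure []) xs s0
  where
    -- run k on the current element and state, then hand the new state to
    -- the rest of the fold
    step a rest s = do
      (b, s') <- k s a
      bs <- rest s'
      pure (b : bs)
--
-- e.g. running sums in Maybe:
--   mapAccumMList (\s a -> Just (s + a, s + a)) 0 [1, 2, 3] == Just [1, 3, 6]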
--------------------------------------------------------------------------------
-- however, if you're lucky, you might happen to notice, as i did, that
-- mapAccumL is sort of "traverse for State". what does this mean?
--
-- let's reorder the arguments a bit:
--
-- mapAccumL :: Traversable t => (a -> s -> (b, s) ) -> t a -> s -> (t b, s )
--
-- and add some suggestive parentheses:
--
-- mapAccumL :: Traversable t => (a -> (s -> (b, s))) -> t a -> (s -> (t b, s))
--
-- this is *exactly*
--
-- mapAccumL :: Traversable t => (a -> State s b) -> t a -> State s (t b)
--
-- just with some wrapping and unwrapping (in transformers/mtl, `State s` is
-- really `StateT s Identity`, and StateT is a newtype). it should be clear how
-- to make this "monadic": replace `State s` by `StateT s m`!
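-- to make "just some wrapping and unwrapping" concrete, here's a sketch of the
-- pure mapAccumL written as traverse in State (pureMapAccumL is a made-up name;
-- as far as i know, base does essentially the same thing internally, with its
-- own small state applicative rather than mtl's State):
pureMapAccumL :: Traversable t => (s -> a -> (s, b)) -> s -> t a -> (s, t b)
pureMapAccumL g s0 ta = swapPair (runState (traverse step ta) s0)
  where
    -- adapt g's (s, b) result into the (b, s) shape that State expects
    step a = state (\s -> let (s', b) = g s a in (b, s'))
    -- mapAccumL returns the final state first
    swapPair (x, y) = (y, x)
--
-- e.g. pureMapAccumL (\s a -> (s + a, s * a)) 0 [1, 2, 3] == (6, [0, 2, 9]),
-- the same answer mapAccumL gives.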
--
-- let's work through this.
-- we start with this, which is similar to the type we ended up with above:
f0, f00
  :: forall t m s a b. (Monad m, Traversable t)
  => (a -> StateT s m b) -> t a -> StateT s m (t b)
-- the explicit forall is required here: with ScopedTypeVariables, it brings
-- t, s and m into scope in the body, so we can use them in the type
-- applications below (without it, ghc will complain that it can't see t, s or m)
f0 = traverse @t @(StateT s m)
-- to be fair, you don't need any of this:
--
f00 = traverse
--
-- is just fine, but explicitly showing which traversable functor and which
-- applicative it's using makes things easier to understand and really comes
-- in handy when the types get more complex than this. it's a very good habit
-- to have in my experience
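-- a tiny hypothetical example of the same habit on a simpler call: the type
-- applications say up front that we traverse a list (t = []) in the Maybe
-- applicative (f = Maybe), so the reader doesn't have to work that out from
-- context. halveAll is a made-up name:
halveAll :: [Int] -> Maybe [Int]
halveAll = traverse @[] @Maybe halve
  where
    -- fail the whole traversal as soon as one element is odd
    halve n = if even n then Just (n `div` 2) else Nothing
--
-- e.g. halveAll [2, 4, 6] == Just [1, 2, 3], while halveAll [2, 3] == Nothing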
--
-- (quick question: why is the constraint above `Monad m` and not `Applicative
-- m`, even though `Traversable` only uses an `Applicative f` constraint?)
-- okay, now let's clean this up and turn it into a function of the exact type
-- we started out with.
-- eliding the explicit forall from now on, first let's remove the StateT from
-- the type:
f1, f2
  :: (Monad m, Traversable t)
  => (a -> s -> m (b, s)) -> t a -> s -> m (t b, s)
f1 k ta s = runStateT (f0 (\a -> StateT (k a)) ta) s
f2 k ta s = runStateT (f0 (StateT . k) ta) s
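-- to see the wrapping and unwrapping in action: a hypothetical step function
-- (withIndex, a made-up name) in the unwrapped `a -> s -> m (b, s)` shape,
-- which tags each element with a running index. indexed wraps it with StateT,
-- traverses, and unwraps with runStateT, which is exactly what f2 does.
withIndex :: Char -> Int -> Maybe ((Int, Char), Int)
withIndex c i = Just ((i, c), i + 1)
indexed :: Maybe ([(Int, Char)], Int)
indexed = runStateT (traverse (StateT . withIndex) "abc") 0
--
-- indexed == Just ([(0, 'a'), (1, 'b'), (2, 'c')], 3)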
-- and reorder the arguments to bring it in line with what we want:
f3, f4, f5, f6
  :: (Monad m, Traversable t)
  => (s -> a -> m (b, s)) -> s -> t a -> m (t b, s)
f3 k s ta = f2 (\a s' -> k s' a) ta s
f4 k = flip (f2 (flip k)) -- this is hard to read
f5 = flip . f2 . flip -- but we can always do worse!
f6 k s ta = f2 (flip k) ta s -- this is what i'd write personally,
-- and what i feel is easiest to read.
-- striking a balance is important!
-- now drop the final s, since our original function doesn't have it:
f7 :: (Monad m, Traversable t)
   => (s -> a -> m (b, s)) -> s -> t a -> m (t b)
f7 k s ta = fst <$> f2 (flip k) ta s
-- inline f2:
f8, f
  :: (Monad m, Traversable t)
  => (s -> a -> m (b, s)) -> s -> t a -> m (t b)
f8 k s ta = fst <$> runStateT (traverse (StateT . flip k) ta) s
-- and, for our final touch: using fst after runStateT like that is a common
-- enough operation that there's a name for it, evalStateT:
f k s ta = evalStateT (traverse (StateT . flip k) ta) s
-- that's it!
iter :: Int -> Int -> [(Int, Int)]
iter s a = [(s, a + 1), (s + 1, a)]
main :: IO ()
main = print (f4 iter 0 [1..3])
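-- for something closer to the original HashMap question, here's a
-- hypothetical example (internAll and the association-list "table" are
-- stand-ins, so this needs no extra packages): give each new string the next
-- free number, reuse the number for repeats, and fail on empty strings. the
-- body is just f's definition inlined, with an empty table as the initial
-- state.
internAll :: [String] -> Maybe [Int]
internAll xs = evalStateT (traverse (StateT . flip step) xs) []
  where
    -- same argument order as the k passed to f: state first, then element
    step :: [(String, Int)] -> String -> Maybe (Int, [(String, Int)])
    step tbl x
      | null x = Nothing
      | otherwise = case lookup x tbl of
          Just n -> Just (n, tbl)
          Nothing -> let n = length tbl in Just (n, (x, n) : tbl)
--
-- e.g. internAll ["a", "b", "a", "c"] == Just [0, 1, 0, 2], and
-- internAll ["a", ""] == Nothing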
--------------------------------------------------------------------------------
-- epilogue
--
-- last step: remember what i said about striking a balance?
-- remember flip . f2 . flip?
--
-- since my mind works like this, i'm going to go off on a long tangent now:
-- let's, erm, forget all about balance and writing comprehensible code and
-- do the same thing with our final function, and, yknow,
f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19, f20, f21, f22
  :: (Monad m, Traversable t) => (s -> a -> m (b, s)) -> s -> t a -> m (t b)
-- make our code look more functional (tm)!
--
-- by this, i mean "pointfree style" (google it):
--
-- f x = g (h x) --> f = g . h
-- f x y = g (h x y) --> f x = g . h x --> f = (g .) . h
--
-- that is, refactoring our functions to take no explicit arguments.
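-- a small worked example of the second rule, using made-up names: a dot
-- product has exactly the shape f x y = g (h x y), with g = sum and
-- h = zipWith (*)
dotPointful, dotPointfree :: [Int] -> [Int] -> Int
dotPointful xs ys = sum (zipWith (*) xs ys)
-- one step in: dotPointful xs = sum . zipWith (*) xs; then one more step:
dotPointfree = (sum .) . zipWith (*)
--
-- both give 32 for [1, 2, 3] and [4, 5, 6]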
--
-- let's begin:
f9 k s ta = evalStateT (traverse (StateT . flip k) ta) s
f10 k s ta = evalStateT (traverse (StateT . flip k) $ ta) s
f11 k s ta = flip evalStateT s (traverse (StateT . flip k) $ ta)
-- we're now just applying two functions to the last argument, so we can
-- eta-reduce (google that term):
f12 k s = flip evalStateT s . traverse (StateT . flip k)
-- here (>>>) is just (.) with the arguments reversed; it comes from Control.Category
f13 k s = traverse (StateT . flip k) >>> flip evalStateT s
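-- a quick sanity check of the claim that (>>>) is (.) with its arguments
-- reversed, with made-up names: for plain functions, g >>> h is h . g, i.e.
-- "first g, then h"
incThenDouble, incThenDouble' :: Int -> Int
incThenDouble = (+ 1) >>> (* 2)
incThenDouble' = (* 2) . (+ 1)
-- both send 3 to 8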
f14 k = (traverse (StateT . flip k) >>>) . flip evalStateT
-- this is a lot of fun
f15 k = flip evalStateT >>> (traverse (StateT . flip k) >>>)
f16 k = flip evalStateT >>> (. traverse (StateT . flip k))
-- *through gritted teeth* really fun
f17 k = flip evalStateT >>> (. ((traverse . (StateT .) . flip) k))
-- really, uh, teaches you...
f18 k = flip evalStateT >>> ((flip (.) . ((traverse . (StateT .) . flip))) k)
-- ...a lot...
f19 k = ((flip evalStateT >>>) . ((flip (.) . ((traverse . (StateT .) . flip))))) k
-- whew. now eta-reduce k away, then replace the remaining (>>>) by a normal (.)
f20 = ((flip evalStateT >>>) . ((flip (.) . ((traverse . (StateT .) . flip)))))
f21 = ((. flip evalStateT) . ((flip (.) . ((traverse . (StateT .) . flip)))))
-- and remove all the unnecessary parens
f22 = (. flip evalStateT) . flip (.) . traverse . (StateT .) . flip
-- whew.
--
-- been quite a few years since i last did that.
--
-- okay, jokes aside, this kind of thing does make for a nice puzzle and can be
-- pretty educational and so on (in terms of teaching you to follow the types
-- of small expressions inside a large complicated expression) if you enjoy it.
--
-- if you're curious, google "pointfree style". i think there's also a website
-- that does this for you. (it can always be done: combinatory logic shows that
-- any lambda term can be rewritten without variables using just a couple of
-- combinators. you can get a taste for what an algorithm that does this would
-- look like by following the steps above: repeatedly move the last argument to
-- the end of the body of the function, then eta-reduce it away.)
--
-- just ... don't do this in code anyone else will read (which includes you in two
-- months)