Skip to content

Instantly share code, notes, and snippets.

@gelisam
Created October 15, 2020 02:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gelisam/23f487ede1d6e2c53a290e11b6a4c6c9 to your computer and use it in GitHub Desktop.
Save gelisam/23f487ede1d6e2c53a290e11b6a4c6c9 to your computer and use it in GitHub Desktop.
optimizing bubble sort using free monads
-- In response to https://twitter.com/chrislpenner/status/1315494889191161857
-- and https://twitter.com/drezil1985/status/1315495791243472896
--
-- The challenge is to implement bubble sort on lists... which makes no sense
-- at first glance, because bubble sort is defined in terms of swap operations,
-- which make little sense for lists!
{-# LANGUAGE DataKinds, FlexibleContexts, GADTs, MultiWayIf, QuantifiedConstraints, ScopedTypeVariables, TypeApplications, TypeOperators #-}
module Main where
import Test.DocTest
import Criterion.Main
import Control.Monad.Extra
import Control.Monad.Freer
import Control.Monad.Freer.State
import Control.Monad.Freer.Writer
import Data.Traversable
-- $setup
-- >>> inputLength = 100 :: Int
-- >>> input = take inputLength (cycle [1..10]) :: [Int]

-- First, let's write a version of the bubble-sort algorithm which only uses
-- swaps, without assuming a particular representation (a list or an array).
-- | The operations bubble sort needs from its container, independent of how
-- the container is represented. The first type parameter is the element type
-- and the second is the result produced by performing the operation.
data Bubble elem result where
  -- | Read the element stored at the given index.
  ReadAt :: Int -> Bubble elem elem
  -- | Swap the elements at indices @i@ and @i + 1@.
  SwapAt :: Int -> Bubble elem ()
-- | Renders an operation as its constructor name followed by its index, for
-- example @ReadAt 3@. The element type is never shown, so no 'Show'
-- constraint on it is required.
instance Show (Bubble a r) where
  show operation = case operation of
    ReadAt index -> "ReadAt " ++ show index
    SwapAt index -> "SwapAt " ++ show index
-- | Request the element stored at the given index.
readAt :: Int -> Eff (Bubble a ': effs) a
readAt index = send (ReadAt index)
-- | Request that the elements at the given index and at the index right after
-- it be swapped.
swapAt :: forall a effs
        . Int -> Eff (Bubble a ': effs) ()
-- The annotation pins the element type, which 'SwapAt' itself does not
-- mention in its argument or result.
swapAt index = send (SwapAt index :: Bubble a ())
-- | Bubble sort expressed purely in terms of 'readAt' and 'swapAt'.
--
-- The argument is the number of elements to sort. Each pass visits the
-- indices @0@ to @n - 2@ in order, swapping every adjacent pair that is out
-- of order; passes are repeated for as long as the previous pass performed at
-- least one swap.
bubbleSort :: Ord a
           => Int -> Eff (Bubble a ': effs) ()
bubbleSort n = whileM (or <$> traverse compareAndSwap [0 .. n - 2])
  where
    -- Compare the pair starting at index i, swap it if needed, and report
    -- whether a swap happened.
    compareAndSwap i = do
      lower <- readAt i
      upper <- readAt (i + 1)
      if lower > upper
        then swapAt i >> pure True
        else pure False
-- We need to interpret that algorithm so it runs on a list.
--
-- The naive way to do it is to implement a swap operation on lists, but that's
-- quite slow because each swap needs to traverse most of the list.
-- | Swap the elements at positions @i@ and @i + 1@ of a list.
--
-- The list is walked from its head up to position @i@ on every call. When
-- there is no pair to swap (@i@ is negative, or the list has no element at
-- position @i + 1@) the list is returned unchanged; previously those inputs
-- ended in a pattern-match failure.
naiveSwap :: Int -> [a] -> [a]
naiveSwap 0 (x1 : x2 : xs) = x2 : x1 : xs
naiveSwap i (x : xs)
  | i > 0 = x : naiveSwap (i - 1) xs
-- Negative index, or the list ran out before a full pair was reached.
naiveSwap _ xs = xs
-- | Interpret 'Bubble' operations against a plain list kept in a 'State'
-- effect: reads index into the list with '!!' and swaps go through
-- 'naiveSwap'.
--
-- >>> run . execState input . runBubbleNaive $ bubbleSort inputLength
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
runBubbleNaive :: Eff (Bubble a ': effs) r
               -> Eff (State [a] ': effs) r
runBubbleNaive = reinterpret onList
  where
    onList :: forall a effs r
            . Bubble a r -> Eff (State [a] ': effs) r
    onList (ReadAt index) = do
      elements <- get
      pure (elements !! index)
    onList (SwapAt index) = modify (naiveSwap index :: [a] -> [a])
-- If we take a closer look at what's going on (which we can do easily, thanks
-- to free monads!), we will quickly notice a pattern: we read and swap at
-- mostly incrementing indices! Mostly, because
-- 1. Sometimes we ReadAt (i+1) and then SwapAt i
-- (see example below)
-- 2. The next iteration through the whileM will start at 0 again
-- (not shown below, as it's hundreds more entries later)
-- | Log every action of the first effect: each action's 'show' rendering is
-- appended to the 'Writer' log, and the action is then re-sent unchanged so
-- that a later interpreter still handles it.
--
-- >>> mapM_ putStrLn . take 30 . snd . run . runWriter . evalState input . runBubbleNaive . verbosely $ bubbleSort inputLength
-- ReadAt 0
-- ReadAt 1
-- ReadAt 1
-- ReadAt 2
-- ReadAt 2
-- ReadAt 3
-- ReadAt 3
-- ReadAt 4
-- ReadAt 4
-- ReadAt 5
-- ReadAt 5
-- ReadAt 6
-- ReadAt 6
-- ReadAt 7
-- ReadAt 7
-- ReadAt 8
-- ReadAt 8
-- ReadAt 9
-- ReadAt 9
-- ReadAt 10
-- SwapAt 9
-- ReadAt 10
-- ReadAt 11
-- SwapAt 10
-- ReadAt 11
-- ReadAt 12
-- SwapAt 11
-- ReadAt 12
-- ReadAt 13
-- SwapAt 12
verbosely :: forall f r. (forall a. Show (f a))
          => Eff '[f, Writer [String]] r
          -> Eff '[f, Writer [String]] r
verbosely = interpose logThenResend
  where
    logThenResend :: forall a
                   . f a -> Eff '[f, Writer [String]] a
    -- 'tell' targets the Writer sitting below f in the effect list, hence
    -- the 'raise'.
    logThenResend action = raise (tell [show action]) >> send action
-- Can we exploit that incrementing-indices structure to make the algorithm
-- faster? yes! Let's interpret our Bubble actions into Passes actions, in
-- which there is a current cursor which supports efficient reads and swaps
-- near the cursor, and a less efficient Reset action which resets the cursor
-- to the beginning of the structure (the reinterpreted algorithm is still
-- independent of whether we are using a list or an array). This allows us to
-- combine consecutive list traversals into one.
-- | Cursor-based operations: reads and writes happen at or right after the
-- current cursor position. The first type parameter is the element type and
-- the second is the result produced by performing the operation.
data Passes elem result where
  -- | Read the element under the cursor.
  ReadHere :: Passes elem elem
  -- | Read the element right after the cursor.
  ReadNext :: Passes elem elem
  -- | Overwrite the element under the cursor.
  WriteHere :: elem -> Passes elem ()
  -- | Move the cursor one element forward.
  Next :: Passes elem ()
  -- | Move the cursor back to the beginning of the structure.
  Reset :: Passes elem ()
-- | Renders an operation as its constructor name; 'WriteHere' is followed by
-- the 'show' rendering of the value being written, which is why the element
-- type needs a 'Show' instance.
instance Show a => Show (Passes a r) where
  show operation = case operation of
    ReadHere -> "ReadHere"
    ReadNext -> "ReadNext"
    WriteHere value -> "WriteHere " ++ show value
    Next -> "Next"
    Reset -> "Reset"
-- | Request the element under the cursor.
readHere :: Eff (Passes a ': effs) a
readHere = send ReadHere

-- | Request the element right after the cursor.
readNext :: Eff (Passes a ': effs) a
readNext = send ReadNext

-- | Request that the element under the cursor be overwritten with the given
-- value.
writeHere :: a -> Eff (Passes a ': effs) ()
writeHere value = send (WriteHere value)

-- | Request that the cursor move one element forward.
next :: forall a effs
      . Eff (Passes a ': effs) ()
-- The annotation pins the element type, which 'Next' does not mention.
next = send (Next :: Passes a ())

-- | Request that the cursor move back to the beginning.
reset :: forall a effs
       . Eff (Passes a ': effs) ()
-- The annotation pins the element type, which 'Reset' does not mention.
reset = send (Reset :: Passes a ())
-- | Translate index-based 'Bubble' operations into cursor-based 'Passes'
-- operations.
--
-- An 'Int' state, starting at 0, tracks the index the cursor currently sits
-- on. Targets at or right after the cursor are served directly; targets
-- further ahead advance the cursor one step at a time; targets behind the
-- cursor trigger a 'reset' followed by walking forward again.
runBubblePasses :: Eff (Bubble a ': effs) r
                -> Eff (Passes a ': effs) r
runBubblePasses = evalState 0 . reinterpret2 translate
  where
    -- Move the cursor one element forward and record its new index.
    stepForward :: Eff (State Int ': Passes a ': effs) ()
    stepForward = raise next >> modify @Int (+ 1)

    -- Send the cursor back to the start and record that it sits at index 0.
    rewind :: Eff (State Int ': Passes a ': effs) ()
    rewind = raise reset >> put @Int 0

    translate :: Bubble a r -> Eff (State Int ': Passes a ': effs) r
    translate request@(ReadAt target) = do
      cursor <- get
      case () of
        _ | target == cursor -> raise readHere
          | target == cursor + 1 -> raise readNext
          | target > cursor + 1 -> stepForward >> translate request
          | otherwise -> rewind >> translate request
    translate request@(SwapAt target) = do
      cursor <- get
      case () of
        _ | target == cursor -> do
              -- Exchange the two elements and leave the cursor on the
              -- second one.
              first <- raise readHere
              second <- raise readNext
              raise (writeHere second)
              stepForward
              raise (writeHere first)
          | target > cursor -> stepForward >> translate request
          | otherwise -> rewind >> translate request
-- We can now interpret that modified algorithm so it runs on a list. By using
-- a zipper to keep track of the current location, this time we can finally
-- implement all the operations (except Reset) efficiently.
-- | Interpret 'Passes' operations against a list zipper and return the final
-- list. The zipper is a pair holding the elements before the cursor in
-- reverse order, followed by the elements from the cursor onwards.
--
-- >>> run . runPasses input . runBubblePasses $ bubbleSort inputLength
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
runPasses :: [a]
          -> Eff (Passes a ': effs) r
          -> Eff effs [a]
runPasses initial program =
  unzipper <$> execState ([], initial) (reinterpret onZipper program)
  where
    -- Flatten the zipper back into an ordinary list.
    unzipper :: ([a], [a]) -> [a]
    unzipper (behind, ahead) = reverse behind ++ ahead

    -- The lazy (~) patterns defer matching on the list after the cursor
    -- until the matched element is actually demanded.
    onZipper :: forall a effs r
              . Passes a r -> Eff (State ([a], [a]) ': effs) r
    onZipper ReadHere =
      get @([a], [a]) >>= \(_, ~(here : _)) -> pure here
    onZipper ReadNext =
      get @([a], [a]) >>= \(_, ~(_ : following : _)) -> pure following
    onZipper (WriteHere new) =
      get @([a], [a]) >>= \(behind, ~(_ : ahead)) -> put (behind, new : ahead)
    onZipper Next =
      get @([a], [a]) >>= \(behind, ~(here : ahead)) -> put (here : behind, ahead)
    onZipper Reset =
      get @([a], [a]) >>= \zipper -> put @([a], [a]) ([], unzipper zipper)
-- Previously, the number of traversals was proportional to the number of
-- swaps, which was quadratic in the length of the list. This time the number
-- of traversals is the number of times we call Reset (plus 1), which should be
-- about the same as the length of the list (100):
-- | Count the 'Reset' operations a program performs: each one increments the
-- 'Int' state before being re-sent. Every other operation is re-sent
-- untouched.
--
-- >>> run . execState 1 . runPasses input . countPasses . runBubblePasses $ bubbleSort inputLength
-- 82
countPasses :: Eff (Passes a ': State Int ': effs) r
            -> Eff (Passes a ': State Int ': effs) r
countPasses = interpose tally
  where
    tally :: Passes a r -> Eff (Passes a ': State Int ': effs) r
    tally operation = case operation of
      Reset -> modify @Int (+ 1) >> reset
      _ -> send operation
-- The reinterpreted algorithm is faster in practice, too; criterion says it's
-- about 4 times faster. For comparison, I also write versions of the naive and
-- optimized algorithm which do not use freer-simple.
-- | The naive algorithm written directly on lists, without freer-simple.
--
-- The first argument is the number of elements to sort. Each pass compares
-- the pairs at indices @i@ and @i + 1@ for increasing @i@, using '!!' for the
-- reads and 'naiveSwap' for the swaps; passes repeat until one completes
-- without swapping anything.
--
-- >>> naiveNoFreer inputLength input
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
naiveNoFreer :: forall a. Ord a
             => Int -> [a] -> [a]
naiveNoFreer n = untilStable
  where
    -- Run passes until one reports that nothing was swapped.
    untilStable :: [a] -> [a]
    untilStable current = case onePass 0 current of
      (after, True) -> untilStable after
      (after, False) -> after

    -- One pass starting at index i; the flag reports whether this pass
    -- swapped anything from index i onwards.
    onePass :: Int -> [a] -> ([a], Bool)
    onePass i ys
      | i + 1 >= n = (ys, False)
      | ys !! i <= ys !! (i + 1) = onePass (i + 1) ys
      | otherwise = (fst (onePass (i + 1) (naiveSwap i ys)), True)
-- | Bubble sort written by hand as a single traversal per pass.
--
-- Each pass scans for the first adjacent pair that is out of order. If there
-- is none, the input of that pass is returned as-is. Otherwise the larger
-- element is carried towards the end of the list, being exchanged for any
-- larger element met on the way, and a new pass starts on the result.
--
-- >>> handOptimized input
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
handOptimized :: forall a. Ord a
              => [a] -> [a]
handOptimized unsorted = scanPrefix id unsorted
  where
    -- Walk the already-ordered prefix, remembering (as a function) how to
    -- put it back in front of whatever follows.
    scanPrefix :: ([a] -> [a]) -> [a] -> [a]
    scanPrefix restore remaining = case remaining of
      lo : hi : rest
        | lo <= hi -> scanPrefix (restore . (lo :)) (hi : rest)
        | otherwise -> handOptimized (restore (hi : carry lo rest))
      _ -> unsorted

    -- Carry an element along the list, exchanging it for any larger element
    -- encountered, and drop the carried element at the end.
    carry :: a -> [a] -> [a]
    carry held [] = [held]
    carry held (z : zs) =
      if held <= z
        then held : carry z zs
        else z : carry held zs
-- | Build one benchmark per input size (10, 50, 100, 500 and 1000). Each one
-- hands the given sorter the size together with that many elements of the
-- repeating sequence @1..10@, and is labelled with the size.
benchmark :: ((Int, [Int]) -> [Int]) -> [Benchmark]
benchmark f = map forSize [10, 50, 100, 500, 1000]
  where
    forSize :: Int -> Benchmark
    forSize size = bench (show size) (nf f (size, take size (cycle [1 .. 10])))
-- On my machine, the result is that runBubbleNaive is by far the slowest,
-- naiveNoFreer is about 5% faster, runBubblePasses is 4x faster, but
-- handOptimized is about 1000x faster than that. Since the speedup from
-- runBubbleNaive to naiveNoFreer is much smaller, I don't think the 1000x
-- factor is due to freer-simple's overhead, but rather, to the fact that
-- runBubblePasses performs its optimization at runtime. That is,
-- runBubblePasses is effectively a runtime interpreter for runBubbleNaive
-- which happens to do some runtime optimizations, whereas handOptimized is
-- more like the result of performing that optimization on runBubbleNaive at
-- compile-time. Plus, since the result of that optimization is expressed as a
-- Haskell program, that optimized code is then further optimized by ghc, which
-- gives it another performance boost; but compiling with -O0 yielded similar
-- proportions, so that didn't seem to be much of a factor this time.
-- | Benchmark the four implementations on the input sizes chosen by
-- 'benchmark', one benchmark group per implementation.
main :: IO ()
main = defaultMain
  [ bgroup "runBubbleNaive" (benchmark viaNaiveInterpreter)
  , bgroup "runBubblePasses" (benchmark viaPassesInterpreter)
  , bgroup "naiveNoFreer" (benchmark viaNaiveNoFreer)
  , bgroup "handOptimized" (benchmark viaHandOptimized)
  ]
  where
    viaNaiveInterpreter :: (Int, [Int]) -> [Int]
    viaNaiveInterpreter (n, xs) =
      run (execState xs (runBubbleNaive (bubbleSort n)))

    viaPassesInterpreter :: (Int, [Int]) -> [Int]
    viaPassesInterpreter (n, xs) =
      run (runPasses xs (runBubblePasses (bubbleSort n)))

    viaNaiveNoFreer :: (Int, [Int]) -> [Int]
    viaNaiveNoFreer (n, xs) = naiveNoFreer n xs

    -- The hand-written version works out the length itself.
    viaHandOptimized :: (Int, [Int]) -> [Int]
    viaHandOptimized (_, xs) = handOptimized xs
--main :: IO ()
--main = doctest ["FreeBubbleSort.hs"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment