Created
October 15, 2020 02:59
-
-
Save gelisam/23f487ede1d6e2c53a290e11b6a4c6c9 to your computer and use it in GitHub Desktop.
optimizing bubble sort using free monads
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- In response to https://twitter.com/chrislpenner/status/1315494889191161857 | |
-- and https://twitter.com/drezil1985/status/1315495791243472896 | |
-- | |
-- The challenge is to implement bubble sort on lists... which makes no sense | |
-- at first glance, because bubble sort is defined in terms of swap operations, | |
-- which make little sense for lists! | |
{-# LANGUAGE DataKinds, FlexibleContexts, GADTs, MultiWayIf, QuantifiedConstraints, ScopedTypeVariables, TypeApplications, TypeOperators #-} | |
module Main where | |
import Test.DocTest | |
import Criterion.Main | |
import Control.Monad.Extra | |
import Control.Monad.Freer | |
import Control.Monad.Freer.State | |
import Control.Monad.Freer.Writer | |
import Data.Traversable | |
-- $setup | |
-- >>> inputLength = 100 :: Int | |
-- >>> input = take inputLength (cycle [1..10]) :: [Int] | |
-- First, let's write a version of the bubble-sort algorithm which only uses | |
-- swaps, without assuming a particular representation (a list or an array). | |
-- | Primitive operations needed by the representation-agnostic bubble sort.
-- The first type parameter is the element type, the second is the result
-- type of the individual operation.
data Bubble elem result where
  -- | Read the element stored at the given index.
  ReadAt :: Int -> Bubble elem elem
  -- | Swap the elements at indices @i@ and @i+1@.
  SwapAt :: Int -> Bubble elem ()
-- | Renders an action as its constructor name followed by its index, e.g.
-- @ReadAt 3@.  No 'Show' constraint on the element type is required, since
-- neither action carries an element.
instance Show (Bubble a r) where
  show action = case action of
    ReadAt ix -> unwords ["ReadAt", show ix]
    SwapAt ix -> unwords ["SwapAt", show ix]
-- | Request the element at index @i@ by sending a 'ReadAt' action.
readAt :: Int -> Eff (Bubble a ': effs) a
readAt i = send (ReadAt i)
-- | Request a swap of the elements at indices @i@ and @i+1@ by sending a
-- 'SwapAt' action.  The type application pins the element type, which the
-- result type @()@ alone would not determine.
swapAt :: forall a effs
        . Int -> Eff (Bubble a ': effs) ()
swapAt i = send (SwapAt @a i)
-- | Bubble sort expressed purely in terms of 'readAt' and 'swapAt'.
--
-- The argument @n@ is the number of elements.  Each pass visits the indices
-- @0 .. n-2@ in order, swapping every adjacent pair that is out of order;
-- passes are repeated (via 'whileM') for as long as the previous pass
-- performed at least one swap.
bubbleSort :: Ord a
           => Int -> Eff (Bubble a ': effs) ()
bubbleSort n = whileM (or <$> traverse compareAndSwap [0 .. n - 2])
  where
    -- Compare the pair starting at index i; swap it when out of order and
    -- report whether a swap happened.
    compareAndSwap :: Ord a => Int -> Eff (Bubble a ': effs) Bool
    compareAndSwap i = do
      left  <- readAt i
      right <- readAt (i + 1)
      if left > right
        then True <$ swapAt i
        else pure False
-- We need to interpret that algorithm so it runs on a list. | |
-- | |
-- The naive way to do it is to implement a swap operation on lists, but that's | |
-- quite slow because each swap needs to traverse most of the list. | |
-- | Swap the elements at positions @i@ and @i+1@ of a list.
--
-- The list is walked from the front until position @i@ is reached, so the
-- cost grows with @i@.  When there is no pair at that position (the index is
-- negative, or @i+1@ is past the end of the list) the list is returned
-- unchanged instead of failing with a pattern-match error.
naiveSwap :: Int -> [a] -> [a]
naiveSwap 0 (x1 : x2 : xs) = x2 : x1 : xs
naiveSwap i (x : xs)
  | i > 0 = x : naiveSwap (i - 1) xs
-- Nothing to swap: out-of-range index or too few elements left.
naiveSwap _ xs = xs
-- | Interpret 'Bubble' actions against a plain list held in 'State':
-- 'ReadAt' indexes into the list with '!!', and 'SwapAt' rewrites the list
-- with 'naiveSwap'.
--
-- >>> run . execState input . runBubbleNaive $ bubbleSort inputLength
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
runBubbleNaive :: Eff (Bubble a ': effs) r
               -> Eff (State [a] ': effs) r
runBubbleNaive = reinterpret step
  where
    step :: forall a effs r
          . Bubble a r -> Eff (State [a] ': effs) r
    step (ReadAt i) = do
      elems <- get
      pure (elems !! i)
    step (SwapAt i) = modify @[a] (naiveSwap i)
-- If we take a closer look at what's going on (which we can do easily, thanks
-- to free monads!), we will quickly notice a pattern: we read and swap at
-- mostly incrementing indices! Mostly, because
-- 1. Sometimes we ReadAt (i+1) and then SwapAt i | |
-- (see example below) | |
-- 2. The next iteration through the whileM will start at 0 again | |
-- (not shown below, as it's hundreds more entries later) | |
-- | Trace a program's @f@ actions: for every action, the result of 'show' is
-- appended to the 'Writer' log, and then the same action is sent again
-- unchanged.
--
-- >>> mapM_ putStrLn . take 30 . snd . run . runWriter . evalState input . runBubbleNaive . verbosely $ bubbleSort inputLength
-- ReadAt 0
-- ReadAt 1
-- ReadAt 1
-- ReadAt 2
-- ReadAt 2
-- ReadAt 3
-- ReadAt 3
-- ReadAt 4
-- ReadAt 4
-- ReadAt 5
-- ReadAt 5
-- ReadAt 6
-- ReadAt 6
-- ReadAt 7
-- ReadAt 7
-- ReadAt 8
-- ReadAt 8
-- ReadAt 9
-- ReadAt 9
-- ReadAt 10
-- SwapAt 9
-- ReadAt 10
-- ReadAt 11
-- SwapAt 10
-- ReadAt 11
-- ReadAt 12
-- SwapAt 11
-- ReadAt 12
-- ReadAt 13
-- SwapAt 12
verbosely :: forall f r. (forall a. Show (f a))
          => Eff '[f, Writer [String]] r
          -> Eff '[f, Writer [String]] r
verbosely = interpose logThenForward
  where
    -- The log entry is written in the Writer layer below @f@ (hence the
    -- 'raise'), then the original action is forwarded.
    logThenForward :: forall x
                    . f x -> Eff '[f, Writer [String]] x
    logThenForward op = raise (tell [show op]) >> send op
-- Can we exploit that incrementing-indices structure to make the algorithm
-- faster? Yes! Let's interpret our Bubble actions into Passes actions, in
-- which there is a current cursor which supports efficient reads and swaps | |
-- near the cursor, and a less efficient Reset action which resets the cursor | |
-- to the beginning of the structure (the reinterpreted algorithm is still | |
-- independent of whether we are using a list or an array). This allows us to | |
-- combine consecutive list traversals into one. | |
-- | Cursor-based operations that 'Bubble' actions are translated into.
-- The first type parameter is the element type, the second is the result
-- type of the individual operation.
data Passes elem result where
  -- | Read the element under the cursor.
  ReadHere  :: Passes elem elem
  -- | Read the element just after the cursor.
  ReadNext  :: Passes elem elem
  -- | Overwrite the element under the cursor.
  WriteHere :: elem -> Passes elem ()
  -- | Advance the cursor by one position.
  Next      :: Passes elem ()
  -- | Move the cursor back to the beginning of the structure.
  Reset     :: Passes elem ()
-- | Renders an action as its constructor name; 'WriteHere' additionally
-- shows the value being written, which is why elements must be showable.
instance Show a => Show (Passes a r) where
  show action = case action of
    ReadHere      -> "ReadHere"
    ReadNext      -> "ReadNext"
    WriteHere val -> "WriteHere " ++ show val
    Next          -> "Next"
    Reset         -> "Reset"
-- | Send a 'ReadHere' action: read the element under the cursor.
readHere :: Eff (Passes a ': effs) a
readHere = send ReadHere
-- | Send a 'ReadNext' action: read the element just after the cursor.
readNext :: Eff (Passes a ': effs) a
readNext = send ReadNext
-- | Send a 'WriteHere' action: overwrite the element under the cursor with
-- the given value.
writeHere :: a -> Eff (Passes a ': effs) ()
writeHere value = send (WriteHere value)
-- | Send a 'Next' action: advance the cursor by one position.  The type
-- application pins the element type, which the result type @()@ alone would
-- not determine.
next :: forall a effs
      . Eff (Passes a ': effs) ()
next = send (Next @a)
-- | Send a 'Reset' action: move the cursor back to the beginning.  The type
-- application pins the element type, which the result type @()@ alone would
-- not determine.
reset :: forall a effs
       . Eff (Passes a ': effs) ()
reset = send (Reset @a)
-- | Translate index-based 'Bubble' actions into cursor-based 'Passes'
-- actions.
--
-- The current cursor index is tracked in a local @State Int@ that starts at
-- 0.  A request at or just after the cursor is served directly; a request
-- further ahead first advances the cursor one step at a time; a request
-- behind the cursor issues a 'Reset' and starts again from index 0.
runBubblePasses :: Eff (Bubble a ': effs) r
                -> Eff (Passes a ': effs) r
runBubblePasses = evalState 0 . reinterpret2 translate
  where
    -- Advance the real cursor and the tracked index together.
    stepForward :: Eff (State Int ': Passes a ': effs) ()
    stepForward = raise next >> modify @Int (+ 1)

    -- Send the real cursor back to the start and zero the tracked index.
    rewind :: Eff (State Int ': Passes a ': effs) ()
    rewind = raise reset >> put @Int 0

    translate :: Bubble a r -> Eff (State Int ': Passes a ': effs) r
    translate op@(ReadAt target) = do
      cursor <- get
      if | target == cursor     -> raise readHere
         | target == cursor + 1 -> raise readNext
         | target > cursor + 1  -> stepForward >> translate op
         | otherwise            -> rewind >> translate op
    translate op@(SwapAt target) = do
      cursor <- get
      if | target == cursor -> do
             -- Exchange the two elements; the cursor ends one step further
             -- along, on the second element of the swapped pair.
             here  <- raise readHere
             after <- raise readNext
             raise (writeHere after)
             stepForward
             raise (writeHere here)
         | target > cursor  -> stepForward >> translate op
         | otherwise        -> rewind >> translate op
-- We can now interpret that modified algorithm so it runs on a list. By using
-- a zipper to keep track of the current location, this time we can finally
-- implement all the operations (except Reset) efficiently.
-- | Interpret 'Passes' actions against a list zipper and return the final
-- list.
--
-- The state is a pair: the elements before the cursor, in reverse order,
-- and the elements from the cursor onwards.  The result of the interpreted
-- program itself is discarded; only the final list is returned.
--
-- >>> run . runPasses input . runBubblePasses $ bubbleSort inputLength
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
runPasses :: [a]
          -> Eff (Passes a ': effs) r
          -> Eff effs [a]
runPasses initial = fmap recombine . execState ([], initial) . reinterpret step
  where
    -- Flatten the zipper back into an ordinary list.
    recombine :: ([a], [a]) -> [a]
    recombine (visited, upcoming) = reverse visited <> upcoming

    -- The list patterns below are bound lazily with @let@, so they are only
    -- forced if the bound element is actually demanded.
    step :: forall a effs r
          . Passes a r -> Eff (State ([a], [a]) ': effs) r
    step ReadHere = do
      (_, upcoming) <- get @([a], [a])
      let current : _ = upcoming
      pure current
    step ReadNext = do
      (_, upcoming) <- get @([a], [a])
      let _ : following : _ = upcoming
      pure following
    step (WriteHere new) = do
      (visited, upcoming) <- get @([a], [a])
      let _ : later = upcoming
      put (visited, new : later)
    step Next = do
      (visited, upcoming) <- get @([a], [a])
      let current : later = upcoming
      put (current : visited, later)
    step Reset =
      get @([a], [a]) >>= \zipper -> put @([a], [a]) ([], recombine zipper)
-- Previously, the number of traversals was proportional to the number of
-- swaps, which was quadratic in the length of the list. This time the number
-- of traversals is the number of times we call Reset (plus 1), which should be
-- about the same as the length of the list (100):
-- | Count 'Reset' actions: each time the program issues a 'Reset', the
-- @State Int@ counter is incremented and the 'Reset' is then re-sent.  Every
-- other action is forwarded untouched.
--
-- >>> run . execState 1 . runPasses input . countPasses . runBubblePasses $ bubbleSort inputLength
-- 82
countPasses :: Eff (Passes a ': State Int ': effs) r
            -> Eff (Passes a ': State Int ': effs) r
countPasses = interpose tally
  where
    tally :: Passes a r -> Eff (Passes a ': State Int ': effs) r
    tally Reset = modify @Int (+ 1) >> reset
    tally other = send other
-- The reinterpreted algorithm is faster in practice, too; criterion says it's
-- about 4 times faster. For comparison, I also wrote versions of the naive and
-- optimized algorithms which do not use freer-simple.
-- | The naive algorithm (index with '!!', swap with 'naiveSwap') written
-- directly on lists, without freer-simple.
--
-- The first argument @n@ is the number of elements to sort; a sweep visits
-- the indices @0 .. n-2@.  Sweeps are repeated until one completes without
-- performing a swap.
--
-- >>> naiveNoFreer inputLength input
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
naiveNoFreer :: forall a. Ord a
             => Int -> [a] -> [a]
naiveNoFreer n = loop
  where
    -- Keep sweeping for as long as the last sweep changed something.
    loop :: [a] -> [a]
    loop current = case sweep 0 current of
      (swept, True)  -> loop swept
      (swept, False) -> swept

    -- One left-to-right sweep starting at index i; the flag reports whether
    -- any swap was performed from that index onwards.
    sweep :: Int -> [a] -> ([a], Bool)
    sweep i items
      | i + 1 >= n                     = (items, False)
      | items !! i <= items !! (i + 1) = sweep (i + 1) items
      | otherwise                      = (fst (sweep (i + 1) (naiveSwap i items)), True)
-- | Bubble sort written directly on lists, traversing the list once per
-- pass instead of once per swap.
--
-- A pass skips over the already-ordered prefix; at the first out-of-order
-- adjacent pair, the larger element is carried rightwards through the rest
-- of the list and the whole procedure restarts on the rebuilt list.  When a
-- pass finds no out-of-order pair, the list it was given is returned as is.
--
-- >>> handOptimized input
-- [1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9,10,10,10,10,10,10,10,10,10,10]
handOptimized :: forall a. Ord a
              => [a] -> [a]
handOptimized whole = scan [] whole
  where
    -- Walk past the ordered prefix, remembering it in reverse in @seen@.
    scan :: [a] -> [a] -> [a]
    scan seen (p : q : rest) =
      if p <= q
        then scan (p : seen) (q : rest)
        else handOptimized (reverse seen ++ q : carry p rest)
    scan _ _ = whole

    -- Push the held element rightwards, always keeping hold of the larger
    -- of the held element and the one it is compared with.
    carry :: a -> [a] -> [a]
    carry held items = case items of
      [] -> [held]
      z : zs
        | held <= z -> held : carry z zs
        | otherwise -> z : carry held zs
-- | Build one criterion benchmark per input size for the given sorting
-- function.  For each size @n@ the function receives @n@ together with the
-- first @n@ elements of the repeating sequence @1..10@, and its result is
-- evaluated with criterion's 'nf'.  Each benchmark is named after its size.
benchmark :: ((Int, [Int]) -> [Int]) -> [Benchmark]
benchmark f = map forSize [10, 50, 100, 500, 1000]
  where
    forSize :: Int -> Benchmark
    forSize n = bench (show n) (nf f (n, take n (cycle [1 .. 10])))
-- On my machine, the result is that runBubbleNaive is by far the slowest, | |
-- naiveNoFreer is about 5% faster, runBubblePasses is 4x faster, but | |
-- handOptimized is about 1000x faster than that. Since the speedup from | |
-- runBubbleNaive to naiveNoFreer is much smaller, I don't think the 1000x | |
-- factor is due to freer-simple's overhead, but rather, to the fact that | |
-- runBubblePasses performs its optimization at runtime. That is, | |
-- runBubblePasses is effectively a runtime interpreter for runBubbleNaive | |
-- which happens to do some runtime optimizations, whereas handOptimized is | |
-- more like the result of performing that optimization on runBubbleNaive at | |
-- compile-time. Plus, since the result of that optimization is expressed as a | |
-- Haskell program, that optimized code is then further optimized by ghc, which | |
-- gives it another performance boost; but compiling with -O0 yielded similar | |
-- proportions, so that didn't seem to be much of a factor this time. | |
-- | Run the criterion benchmarks: one group per implementation, each
-- exercised on the input sizes produced by 'benchmark'.
main :: IO ()
main = defaultMain
  [ bgroup "runBubbleNaive"  (benchmark viaNaive)
  , bgroup "runBubblePasses" (benchmark viaPasses)
  , bgroup "naiveNoFreer"    (benchmark plainNaive)
  , bgroup "handOptimized"   (benchmark plainOptimized)
  ]
  where
    viaNaive, viaPasses, plainNaive, plainOptimized :: (Int, [Int]) -> [Int]
    -- Effectful algorithm, list-in-State interpreter.
    viaNaive (n, xs) = run (execState xs (runBubbleNaive (bubbleSort n)))
    -- Effectful algorithm, translated to cursor passes over a zipper.
    viaPasses (n, xs) = run (runPasses xs (runBubblePasses (bubbleSort n)))
    -- Direct list implementations without freer-simple.
    plainNaive (n, xs) = naiveNoFreer n xs
    plainOptimized (_, xs) = handOptimized xs
--main :: IO () | |
--main = doctest ["FreeBubbleSort.hs"] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment