Skip to content

Instantly share code, notes, and snippets.

@bicycle1885
Created June 10, 2014 14:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bicycle1885/25d7f2ae1c709aa69e64 to your computer and use it in GitHub Desktop.
Minimal stream fusion in Haskell
{-# LANGUAGE ExistentialQuantification #-}
module Main (main) where
import Prelude hiding (map, filter, foldl, sum)
import System.Environment (getArgs, getProgName)
import System.Exit (exitFailure)
import System.IO (hPutStrLn, stderr)
import Text.Read (readMaybe)
-- | A stream of @a@s: a step function together with its starting state.
-- The state type @s@ is existentially hidden, so the only way to make
-- progress is to feed the state back into the step function.
data Stream a = forall s. Stream (s -> Step a s) s

-- | The outcome of running a stream's step function once.
data Step a s
  = Done       -- ^ no more elements
  | Yield a s  -- ^ produce an element, then continue from the new state
  | Skip s     -- ^ produce nothing this time, continue from the new state
-- | Apply @f@ to every element a 'Stream' yields. 'Skip' and 'Done'
-- steps pass through unchanged, so the underlying state is reused as-is.
mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream step seed) = Stream advance seed
  where
    advance s = relabel (step s)

    -- Rewrite only the element carried by a 'Yield'.
    relabel Done        = Done
    relabel (Skip s)    = Skip s
    relabel (Yield x s) = Yield (f x) s
-- | Keep only the elements of a 'Stream' that satisfy @p@.
-- A rejected element is turned into a 'Skip' step instead of searching
-- ahead for the next match, so this step function itself is not recursive.
filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream step seed) = Stream advance seed
  where
    advance s = keep (step s)

    keep Done     = Done
    keep (Skip s) = Skip s
    keep (Yield x s)
      | p x       = Yield x s
      | otherwise = Skip s
-- | Left fold over a 'Stream' with a strict accumulator.
--
-- The accumulator is forced to weak head normal form before every step
-- (the same strictness as 'Data.List.foldl''). Without this, folding a
-- long stream with e.g. @(+)@ first builds a chain of unevaluated
-- @f (f (f z x1) x2) x3 ...@ thunks and only collapses it at the end,
-- costing memory proportional to the stream length.
--
-- In this file the module exports only 'main', and the only caller is
-- 'sum' (via 'foldl') with @(+)@, for which forcing the accumulator does
-- not change the result.
foldlS :: (b -> a -> b) -> b -> Stream a -> b
foldlS f z0 (Stream next0 s0) = go z0 s0
  where
    -- 'seq' rather than a bang pattern: BangPatterns is not enabled here.
    go z s = z `seq` case next0 s of
      Done       -> z
      Skip s'    -> go z s'
      Yield x s' -> go (z `f` x) s'
-- | View a list as a 'Stream'; the hidden state is the remaining suffix.
-- The INLINE [1] pragma keeps 'stream' from being inlined before phase 1,
-- so the @stream/unstream@ rule (active from phase 2) can still match it.
stream :: [a] -> Stream a
{-# INLINE [1] stream #-}
stream = Stream step
  where
    step remaining = case remaining of
      []       -> Done
      (y : ys) -> Yield y ys
-- | Run a 'Stream' to completion and collect its elements into a list.
-- 'Skip' steps contribute nothing; each 'Yield' contributes one cons cell,
-- and the list is built lazily as it is consumed.
-- INLINE [1] keeps it un-inlined until phase 1, leaving phase 2 for the
-- @stream/unstream@ rule.
unstream :: Stream a -> [a]
{-# INLINE [1] unstream #-}
unstream (Stream step seed) = drain seed
  where
    drain s = emit (step s)

    emit Done        = []
    emit (Skip s)    = drain s
    emit (Yield x s) = x : drain s
-- Fusion rule: turning a stream into a list and immediately back into a
-- stream is the identity. When two list functions defined here are
-- composed, the 'unstream' ending one stage meets the 'stream' starting the
-- next, and this rule deletes that round trip, so no intermediate list is
-- built.
--
-- Phase control: the rule is active from simplifier phase 2 onward, while
-- 'stream' and 'unstream' are marked INLINE [1] (not inlined before
-- phase 1). During phase 2 they are therefore still visible as calls, and
-- the rule can match @stream (unstream s)@ before their bodies are unfolded.
{-# RULES
      "stream/unstream" [2] forall s. stream (unstream s) = s
    #-}
-- | List map implemented as @stream@, then 'mapS', then 'unstream'.
-- Marked INLINE so that, at a call site, its 'stream'/'unstream' ends are
-- exposed to the @stream/unstream@ rule and neighbouring stages can fuse.
map :: (a -> b) -> [a] -> [b]
{-# INLINE map #-}
map f = \xs -> unstream (mapS f (stream xs))
-- | List filter implemented as @stream@, then 'filterS', then 'unstream'.
-- Marked INLINE so that, at a call site, its 'stream'/'unstream' ends are
-- exposed to the @stream/unstream@ rule and neighbouring stages can fuse.
filter :: (a -> Bool) -> [a] -> [a]
{-# INLINE filter #-}
filter p = \xs -> unstream (filterS p (stream xs))
-- | List left fold: convert the list with 'stream' and consume it directly
-- with 'foldlS'. No list is rebuilt afterwards. Marked INLINE so an
-- 'unstream' feeding this fold can fuse with its 'stream'.
foldl :: (b -> a -> b) -> b -> [a] -> b
{-# INLINE foldl #-}
foldl f z = \xs -> foldlS f z (stream xs)
-- | Sum of a list, computed as a left fold with @(+)@ starting from 0.
--
-- Marked INLINE like 'map', 'filter' and 'foldl'. Inlining 'sum' at its call
-- site exposes the 'stream' inside it to the 'unstream' of the preceding
-- stage (e.g. in @sum . map f@), so the @stream/unstream@ rule can remove
-- the intermediate list. Without the pragma, this fusion depends on GHC
-- deciding on its own to inline 'sum' early enough.
sum :: Num a => [a] -> a
{-# INLINE sum #-}
sum = foldl (+) 0
-- | Benchmark driver.
--
-- Reads two integers @N@ and @D@ from the command line and builds
-- @[k * m | k <- [1..N], m <- [1..k]]@. It then prints the result of the
-- fused pipeline @sum . map (\\x -> x - D) . filter (> 0) . map (+ D)@
-- over that list.
--
-- Arguments after the first two are ignored. Only the first two are parsed,
-- because the list-building 'map' is lazy. If fewer than two arguments are
-- given, or either of the first two does not parse as an 'Int', a usage
-- message goes to stderr and the program exits with a failure status. This
-- replaces the previous pattern-match / @Prelude.read: no parse@ crash.
main :: IO ()
main = do
  args <- getArgs
  case map readMaybe args of
    Just n : Just d : _ -> do
      let xs = [k * m | k <- [1 .. n], m <- [1 .. k]] :: [Int]
      -- boring computation
      print $ sum . map (\x -> x - d) . filter (> 0) . map (+ d) $ xs
    _ -> do
      prog <- getProgName
      hPutStrLn stderr ("usage: " ++ prog ++ " N D   (both integers)")
      exitFailure
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment