Skip to content

Instantly share code, notes, and snippets.

@snoyberg
Created February 7, 2014 04:21
Show Gist options
  • Save snoyberg/8857344 to your computer and use it in GitHub Desktop.
Save snoyberg/8857344 to your computer and use it in GitHub Desktop.
{-# LANGUAGE RankNTypes #-}
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.ST (runST)
import Control.Monad.Trans.State.Strict (execState, put)
import Criterion.Main (bench, bgroup, defaultMain,
whnf, whnfIO)
import qualified Data.ByteString as S
import Data.ByteString.Internal (ByteString (PS),
inlinePerformIO)
import Data.ByteString.Unsafe (unsafeIndex)
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as VM
import Data.Word (Word8)
import Foreign.ForeignPtr (touchForeignPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Ptr (plusPtr)
import Foreign.Storable (peek, peekByteOff)
-- | Mutable byte histogram: slot @i@ holds the number of times byte value
-- @i@ has been counted (see 'addFreqWord8').
type Freq m = VM.MVector (PrimState m) Int

-- | Allocate a histogram with one zero-initialised slot for every possible
-- 'Word8' value, i.e. 256 slots.
newFreq :: PrimMonad m => m (Freq m)
newFreq = VM.replicate slotCount 0
  where
    -- 255 is converted to Int before adding 1, so this is 256 (no Word8 wrap).
    slotCount = fromIntegral (maxBound :: Word8) + 1
-- | Increment the histogram count stored for byte @w@.
addFreqWord8 :: PrimMonad m => Freq m -> Word8 -> m ()
addFreqWord8 f w = VM.read f slot >>= VM.write f slot . (+ 1)
  where
    -- A Word8 lies in 0..255, so it is always a valid index into a
    -- histogram created by 'newFreq' (256 slots).
    slot = fromIntegral w
-- | Count every byte of a 'S.ByteString' into the histogram.
--
-- The first argument is the traversal that feeds bytes to a callback. It is
-- a parameter so that the different traversal strategies can be compared.
addFreqBS :: PrimMonad m
          => ((Word8 -> m ()) -> S.ByteString -> m ())
          -> Freq m -> S.ByteString -> m ()
addFreqBS traverseBytes freq bytes = traverseBytes (addFreqWord8 freq) bytes
-- | Build the byte histogram of @bs@ in any 'PrimMonad', visiting the bytes
-- with the supplied traversal.
--
-- The mutable vector is frozen with 'V.unsafeFreeze'. This is sound here
-- because the mutable handle is never touched again after freezing.
calcFreq :: PrimMonad m
         => ((Word8 -> m ()) -> S.ByteString -> m ())
         -> S.ByteString
         -> m (V.Vector Int)
calcFreq mapM_BS bs =
  newFreq >>= \freq ->
    addFreqBS mapM_BS freq bs >> V.unsafeFreeze freq
-- | Pure counterpart of 'calcFreq': builds the byte histogram inside 'runST'.
--
-- The traversal must work in every 'Monad' (a rank-2 argument) so it can be
-- instantiated at @ST s@. An IO-only traversal such as 'mapM_DirectIO'
-- therefore cannot be passed here.
calcFreq' :: (forall m. Monad m => (Word8 -> m ()) -> S.ByteString -> m ())
          -> S.ByteString
          -> V.Vector Int
calcFreq' mapM_BS bs = runST (calcFreq mapM_BS bs)
-- | Run a traversal in the strict 'State' monad, overwriting the state with
-- every byte it visits. The result is the last byte visited, or 0 if the
-- traversal never calls its callback.
--
-- For a reverse traversal (e.g. 'mapM_IndexReverse') the "last visited"
-- byte is the first byte of the string.
lastByte :: (forall m. Monad m => (Word8 -> m ()) -> S.ByteString -> m ())
         -> S.ByteString -> Word8
lastByte mapM_BS bs = execState (mapM_BS put bs) 0
-- | Read the file @random@ from the current directory and benchmark every
-- byte-traversal strategy in three settings:
--
--   * "IO": build the byte histogram with 'calcFreq' in IO;
--   * "ST": build it purely with 'calcFreq'' (runST);
--   * "StateT": track the last visited byte with 'lastByte', which runs in
--     the strict State monad from transformers.
main :: IO ()
main = S.readFile "random" >>= \input -> defaultMain
  [ bgroup "IO"
      [ bench "indexing" (whnfIO (calcFreq mapM_Index input))
      , bench "indexing unsafe" (whnfIO (calcFreq mapM_IndexUnsafe input))
      , bench "indexing/reverse" (whnfIO (calcFreq mapM_IndexReverse input))
      , bench "direct" (whnfIO (calcFreq mapM_Direct input))
      , bench "directOff" (whnfIO (calcFreq mapM_DirectOff input))
      , bench "direct IO" (whnfIO (calcFreq mapM_DirectIO input))
      , bench "mapM_ . unpack" (whnfIO (calcFreq mapM_Unpack input))
      ]
  , bgroup "ST"
      [ bench "indexing" (whnf (calcFreq' mapM_Index) input)
      , bench "indexing unsafe" (whnf (calcFreq' mapM_IndexUnsafe) input)
      , bench "indexing/reverse" (whnf (calcFreq' mapM_IndexReverse) input)
      , bench "direct" (whnf (calcFreq' mapM_Direct) input)
      , bench "directOff" (whnf (calcFreq' mapM_DirectOff) input)
        -- "direct IO" is omitted: mapM_DirectIO is fixed to IO, while
        -- calcFreq' needs a traversal that works in every Monad.
        -- , bench "direct IO" (whnf (calcFreq' mapM_DirectIO) input)
      , bench "mapM_ . unpack" (whnf (calcFreq' mapM_Unpack) input)
      ]
  , bgroup "StateT"
      [ bench "indexing" (whnf (lastByte mapM_Index) input)
      , bench "indexing unsafe" (whnf (lastByte mapM_IndexUnsafe) input)
      , bench "indexing/reverse" (whnf (lastByte mapM_IndexReverse) input)
      , bench "direct" (whnf (lastByte mapM_Direct) input)
      , bench "directOff" (whnf (lastByte mapM_DirectOff) input)
        -- "direct IO" is omitted for the same reason: lastByte also needs a
        -- traversal that works in every Monad.
        -- , bench "direct IO" (whnf (lastByte mapM_DirectIO) input)
      , bench "mapM_ . unpack" (whnf (lastByte mapM_Unpack) input)
      ]
  ]
-- | Baseline traversal: unpack the ByteString into a @[Word8]@ list and run
-- the callback on each element, front to back.
mapM_Unpack :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_Unpack f bytes = mapM_ f (S.unpack bytes)
-- | Visit each byte by advancing a raw pointer from the first byte of the
-- slice to one past its last byte. Reads go through 'inlinePerformIO', so
-- this works in any 'Monad', not just IO.
--
-- Each byte is forced with '$!' before it is passed to @f@. Without this,
-- a lazy consumer (e.g. 'put' in a State monad) would keep an unevaluated
-- @peek@ thunk that holds only the raw 'Ptr', not the 'ForeignPtr'. That
-- read could then run after the final 'touchForeignPtr', when the buffer
-- may already have been released.
--
-- NOTE(review): 'inlinePerformIO' is unsafe and was deprecated in later
-- bytestring releases — confirm against the bytestring version in use.
mapM_Direct :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_Direct f (PS fptr offset len) = loop start
  where
    -- First byte of this slice, and one past its last byte.
    start = unsafeForeignPtrToPtr fptr `plusPtr` offset
    end = start `plusPtr` len
    loop ptr
      -- Keep the ForeignPtr alive until the traversal action is evaluated.
      | ptr >= end = inlinePerformIO (touchForeignPtr fptr) `seq` return ()
      | otherwise = (f $! inlinePerformIO (peek ptr)) >> loop (ptr `plusPtr` 1)
-- | Like 'mapM_Direct', but keeps a fixed base pointer and reads with an
-- integer byte offset ('peekByteOff') instead of advancing the pointer.
--
-- The base pointer already includes the slice's @offset@, so the relative
-- index runs over @[0 .. len - 1]@. The previous bound, @offset + len@, read
-- @offset@ bytes past the end of any sliced ByteString.
--
-- As in 'mapM_Direct', each byte is forced with '$!' so the read happens
-- while the 'ForeignPtr' is still kept alive, not later from a thunk.
--
-- NOTE(review): 'inlinePerformIO' is unsafe and was deprecated in later
-- bytestring releases — confirm against the bytestring version in use.
mapM_DirectOff :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_DirectOff f (PS fptr offset len) = loop 0
  where
    -- First byte of this slice; loop indices are relative to it.
    base = unsafeForeignPtrToPtr fptr `plusPtr` offset
    loop i
      -- Keep the ForeignPtr alive until the traversal action is evaluated.
      | i >= len = inlinePerformIO (touchForeignPtr fptr) `seq` return ()
      | otherwise = (f $! inlinePerformIO (peekByteOff base i)) >> loop (succ i)
-- | IO-only pointer walk. Each byte is read with a real IO 'peek' (no
-- 'inlinePerformIO'), so each read is sequenced before the callback runs.
-- The 'ForeignPtr' is touched once the end of the slice is reached.
mapM_DirectIO :: (Word8 -> IO ()) -> ByteString -> IO ()
mapM_DirectIO f (PS fptr offset len) = go base
  where
    -- First byte of the slice, and one past its last byte.
    base = unsafeForeignPtrToPtr fptr `plusPtr` offset
    limit = base `plusPtr` len
    go p
      | p < limit = do
          byte <- peek p
          f byte
          go (p `plusPtr` 1)
      | otherwise = touchForeignPtr fptr
-- | Visit bytes front to back using the bounds-checked 'S.index'.
-- Each byte and each next index is forced with '$!' before use.
mapM_Index :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_Index f bs = go 0
  where
    total = S.length bs
    go ix =
      if ix >= total
        then return ()
        else do
          f $! S.index bs ix
          -- ix < total, so ix + 1 cannot overflow (same as succ here).
          go $! ix + 1
-- | Visit bytes front to back with 'unsafeIndex', which skips the bounds
-- check. This is safe here because the position always stays below
-- 'S.length'.
mapM_IndexUnsafe :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_IndexUnsafe f bs = walk 0
  where
    size = S.length bs
    walk pos
      | pos < size =
          let byte = unsafeIndex bs pos
          in (f $! byte) >> (walk $! pos + 1)
      | otherwise = return ()
-- | Visit bytes back to front (last byte first) using the bounds-checked
-- 'S.index'. The loop counts down the number of bytes still to visit.
mapM_IndexReverse :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_IndexReverse f bs = go (S.length bs)
  where
    go remaining
      | remaining == 0 = return ()
      | otherwise =
          let j = remaining - 1
          in (f $! S.index bs j) >> go j
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment