Created February 7, 2014 04:21
-
-
Save snoyberg/8857344 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RankNTypes #-} | |
import Control.Monad.Primitive (PrimMonad, PrimState) | |
import Control.Monad.ST (runST) | |
import Control.Monad.Trans.State.Strict (execState, put) | |
import Criterion.Main (bench, bgroup, defaultMain, | |
whnf, whnfIO) | |
import qualified Data.ByteString as S | |
import Data.ByteString.Internal (ByteString (PS), | |
inlinePerformIO) | |
import Data.ByteString.Unsafe (unsafeIndex) | |
import qualified Data.Vector.Unboxed as V | |
import qualified Data.Vector.Unboxed.Mutable as VM | |
import Data.Word (Word8) | |
import Foreign.ForeignPtr (touchForeignPtr) | |
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) | |
import Foreign.Ptr (plusPtr) | |
import Foreign.Storable (peek, peekByteOff) | |
-- | A mutable table of 'Int' counters living in monad @m@, indexed by byte
-- value.
type Freq m = VM.MVector (PrimState m) Int

-- | Allocate a zero-filled counter table with one slot per possible 'Word8'
-- value, i.e. 256 slots numbered 0 through 255.
newFreq :: PrimMonad m => m (Freq m)
newFreq = VM.replicate slotCount 0
  where
    slotCount = fromEnum (maxBound :: Word8) + 1
-- | Increment by one the counter stored at index @w@ of the table.
addFreqWord8 :: PrimMonad m => Freq m -> Word8 -> m ()
addFreqWord8 f w =
  VM.read f slot >>= VM.write f slot . (+ 1)
  where
    -- A Word8 always lies in [0, 255], matching the table built by newFreq.
    slot = fromIntegral w :: Int
-- | Count every byte of a 'S.ByteString' into the table. The first argument
-- is the traversal strategy (one of the @mapM_*@ functions) used to visit
-- the bytes.
addFreqBS :: PrimMonad m
          => ((Word8 -> m ()) -> S.ByteString -> m ())
          -> Freq m -> S.ByteString -> m ()
addFreqBS visitBytes = visitBytes . addFreqWord8
-- | Build the byte-value histogram of @bs@ in any 'PrimMonad', visiting the
-- bytes with the supplied traversal, and return it as an immutable vector.
-- 'V.unsafeFreeze' is used because the mutable table is not touched again
-- after it is frozen.
calcFreq :: PrimMonad m
         => ((Word8 -> m ()) -> S.ByteString -> m ())
         -> S.ByteString
         -> m (V.Vector Int)
calcFreq mapM_BS bs =
  newFreq >>= \counts ->
    addFreqBS mapM_BS counts bs >> V.unsafeFreeze counts
-- | Pure counterpart of calcFreq: builds the byte-value histogram inside
-- 'runST'. The traversal must be polymorphic in its monad so that it can be
-- instantiated at @ST s@.
calcFreq' :: (forall m. Monad m => (Word8 -> m ()) -> S.ByteString -> m ())
          -> S.ByteString
          -> V.Vector Int
calcFreq' mapM_BS bs =
  runST (newFreq >>= \counts -> addFreqBS mapM_BS counts bs >> V.unsafeFreeze counts)
-- | Run the traversal in a strict 'State' whose state is overwritten ('put')
-- with every byte the traversal hands over. The result is the last byte
-- visited, or 0 when nothing is visited (e.g. an empty 'S.ByteString').
lastByte :: (forall m. Monad m => (Word8 -> m ()) -> S.ByteString -> m ())
         -> S.ByteString -> Word8
lastByte mapM_BS bs = execState (mapM_BS put bs) 0
-- | Benchmark every byte-traversal strategy over the contents of the file
-- @random@ (resolved relative to the current working directory).
--
-- Three groups run the traversals in different monads:
--
--   * "IO"     - histogram via calcFreq in 'IO', measured with 'whnfIO';
--   * "ST"     - histogram via calcFreq' in 'ST', measured with 'whnf';
--   * "StateT" - last visited byte via lastByte in a strict 'State'.
--
-- mapM_DirectIO is monomorphic in 'IO', so it cannot be passed where a
-- @forall m. Monad m => ...@ traversal is required. Its "direct IO" entries
-- in the "ST" and "StateT" groups are therefore left commented out.
main :: IO ()
main = do
  bs <- S.readFile "random"
  defaultMain
    [ bgroup "IO"
        [ bench "indexing" $ whnfIO $ calcFreq mapM_Index bs
        , bench "indexing unsafe" $ whnfIO $ calcFreq mapM_IndexUnsafe bs
        , bench "indexing/reverse" $ whnfIO $ calcFreq mapM_IndexReverse bs
        , bench "direct" $ whnfIO $ calcFreq mapM_Direct bs
        , bench "directOff" $ whnfIO $ calcFreq mapM_DirectOff bs
        , bench "direct IO" $ whnfIO $ calcFreq mapM_DirectIO bs
        , bench "mapM_ . unpack" $ whnfIO $ calcFreq mapM_Unpack bs
        ]
    , bgroup "ST"
        [ bench "indexing" $ whnf (calcFreq' mapM_Index) bs
        , bench "indexing unsafe" $ whnf (calcFreq' mapM_IndexUnsafe) bs
        , bench "indexing/reverse" $ whnf (calcFreq' mapM_IndexReverse) bs
        , bench "direct" $ whnf (calcFreq' mapM_Direct) bs
        , bench "directOff" $ whnf (calcFreq' mapM_DirectOff) bs
        -- , bench "direct IO" $ whnf (calcFreq' mapM_DirectIO) bs
        , bench "mapM_ . unpack" $ whnf (calcFreq' mapM_Unpack) bs
        ]
    , bgroup "StateT"
        [ bench "indexing" $ whnf (lastByte mapM_Index) bs
        , bench "indexing unsafe" $ whnf (lastByte mapM_IndexUnsafe) bs
        , bench "indexing/reverse" $ whnf (lastByte mapM_IndexReverse) bs
        , bench "direct" $ whnf (lastByte mapM_Direct) bs
        , bench "directOff" $ whnf (lastByte mapM_DirectOff) bs
        -- , bench "direct IO" $ whnf (lastByte mapM_DirectIO) bs
        , bench "mapM_ . unpack" $ whnf (lastByte mapM_Unpack) bs
        ]
    ]
-- | Baseline traversal: unpack the bytes into a list and run the action on
-- each element, front to back.
mapM_Unpack :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_Unpack f bs = foldr (\w rest -> f w >> rest) (return ()) (S.unpack bs)
-- | Visit each byte of the slice, front to back, by walking a raw pointer
-- from the slice start to one-past-the-end.
--
-- Each byte is forced with '$!' before it reaches @f@, as in mapM_Index.
-- Without that, @f@ would receive an unevaluated
-- @inlinePerformIO (peek ptr)@ thunk that reads through a raw 'Ptr'. A raw
-- 'Ptr' does not keep the 'ForeignPtr' alive, so a consumer that merely
-- stores its argument (e.g. 'put') could trigger the read after the
-- 'touchForeignPtr' below.
--
-- NOTE(review): relies on 'inlinePerformIO', which is unsafe and was removed
-- from newer bytestring releases; confirm the bytestring version in use.
mapM_Direct :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_Direct f (PS fptr offset len) = loop start
  where
    start = unsafeForeignPtrToPtr fptr `plusPtr` offset
    end = start `plusPtr` len
    loop ptr
      -- Touching the ForeignPtr at the end is meant to keep the buffer
      -- alive until the traversal has finished.
      | ptr >= end = inlinePerformIO (touchForeignPtr fptr) `seq` return ()
      | otherwise = do
          f $! inlinePerformIO (peek ptr)
          loop (ptr `plusPtr` 1)
-- | Visit each byte of the slice, front to back, by reading at increasing
-- byte offsets from a fixed base pointer.
--
-- @base@ already includes the slice's @offset@, so the offsets passed to
-- 'peekByteOff' must range over @[0, len)@. Using @offset + len@ as the
-- bound would read @offset@ bytes past the end of any slice with a non-zero
-- offset.
--
-- Each byte is forced with '$!' before it reaches @f@, so no thunk that
-- reads through the raw pointer escapes the traversal.
--
-- NOTE(review): relies on 'inlinePerformIO', which is unsafe and was removed
-- from newer bytestring releases; confirm the bytestring version in use.
mapM_DirectOff :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_DirectOff f (PS fptr offset len) = loop 0
  where
    base = unsafeForeignPtrToPtr fptr `plusPtr` offset
    loop i
      -- Touching the ForeignPtr at the end is meant to keep the buffer
      -- alive until the traversal has finished.
      | i >= len = inlinePerformIO (touchForeignPtr fptr) `seq` return ()
      | otherwise = do
          f $! inlinePerformIO (peekByteOff base i)
          loop $! i + 1
-- | IO-only pointer walk: read each byte of the slice with 'peek', hand it
-- to @f@, and advance one byte. Once the end is reached, the 'ForeignPtr'
-- is touched.
mapM_DirectIO :: (Word8 -> IO ()) -> ByteString -> IO ()
mapM_DirectIO f (PS fptr offset len) = walk base
  where
    base = unsafeForeignPtrToPtr fptr `plusPtr` offset
    stop = base `plusPtr` len
    walk p
      | p < stop = peek p >>= f >> walk (p `plusPtr` 1)
      | otherwise = touchForeignPtr fptr
-- | Visit each byte front to back using the bounds-checked 'S.index'. Both
-- the byte and the next position are forced before use.
mapM_Index :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_Index f bs = go 0
  where
    size = S.length bs
    go i
      | i < size = (f $! S.index bs i) >> (go $! i + 1)
      | otherwise = return ()
-- | Same traversal as mapM_Index, but reads with 'unsafeIndex' (no bounds
-- check). This is safe because @i@ stays within @[0, S.length bs)@.
mapM_IndexUnsafe :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_IndexUnsafe f bs = step 0
  where
    total = S.length bs
    step i =
      if i < total
        then do
          f $! unsafeIndex bs i
          step $! i + 1
        else return ()
-- | Visit each byte back to front using the bounds-checked 'S.index',
-- counting down from @S.length bs@ to zero.
mapM_IndexReverse :: Monad m => (Word8 -> m ()) -> ByteString -> m ()
mapM_IndexReverse f bs = go (S.length bs)
  where
    go remaining
      | remaining == 0 = return ()
      | otherwise = do
          let j = remaining - 1
          f $! S.index bs j
          go j
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.