@vshabanov
Last active April 30, 2024 18:46
One billion lines challenge
{-# LANGUAGE GHC2021, LambdaCase, PatternSynonyms, RecordWildCards,
             ViewPatterns, OverloadedStrings, GADTs, NoMonoLocalBinds,
             UnboxedTuples, MagicHash #-}
{-# OPTIONS_GHC -O2 -fspec-constr-count=100 -fllvm #-}
-- need at least 8 to completely remove all allocations
-- (I don't know what change led to this)
{-# OPTIONS_GHC -Wall -Wno-gadt-mono-local-binds -Wno-type-defaults #-}
module Main (main) where
import Control.Concurrent
import Control.Concurrent.Async
import Control.Monad
import Data.Bits
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import Data.Function (on, fix)
import qualified Data.HashMap.Strict as HM
import Data.List (sortBy, tails, foldl')
import Data.Maybe (catMaybes, fromMaybe)
import Data.Ord (comparing)
import Data.Word
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import System.IO
import Text.Printf
import GHC.Base hiding (O)
import Data.Primitive

nProcessThreads :: Int
nProcessThreads = 8

chunkSize :: Num a => a
chunkSize = 1024*1024

maxStationNameLength :: Int
maxStationNameLength = 100

maxKeysCount :: Num a => a
maxKeysCount = 10_000

maxLineLength :: Int
maxLineLength = maxStationNameLength + length (";-12.3\n" :: String)

data Totals
  = Totals
    { tCount :: !Int
    , tMin :: !Double
    , tMax :: !Double
    , tSum :: !Double
    }
  deriving Show

instance Semigroup Totals where
  a <> b =
    Totals
      { tCount = o (+) tCount
      , tMin = o min tMin
      , tMax = o max tMax
      , tSum = o (+) tSum
      }
    where
      o op f = (op `on` f) a b

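-- Example of the merge (illustrative, not from the original gist):
--   Totals 1 2.0 2.0 2.0 <> Totals 1 4.0 4.0 4.0
--     = Totals { tCount = 2, tMin = 2.0, tMax = 4.0, tSum = 6.0 }
-- counts and sums add up while min/max keep the extremes, which is how
-- per-thread partial results get combined in 'main' below.
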
main :: IO ()
main = do
  setNumCapabilities nProcessThreads
  reader <- mkChunkReader "measurements.txt"
  results <- mapConcurrently stations $ replicate nProcessThreads reader
  let sorted = sortBy (comparing fst)
        $ HM.toList $ foldl' (HM.unionWith (<>)) mempty results
      isLast = map null $ drop 1 $ tails sorted
  putStr "{"
  forM_ (zip sorted isLast) $ \ ((n, Totals{..}), lastEntry) -> do
    B.putStr n
    printf "=%.1f/%.1f/%.1f" tMin (tSum / fromIntegral tCount) tMax
    unless lastEntry $ putStr ", "
  putStrLn "}"
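
-- For reference, the expected 1BRC output shape is
-- "{Name1=min/mean/max, Name2=min/mean/max, ...}" with stations sorted by
-- name and one decimal place per value, which is what the loop above prints.
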
-- | Iterate over file chunks in an unspecified order.
-- Can be run from several threads and will split chunks between them
-- without duplication.
type ChunkReader = (B.ByteString -> IO ()) -> IO ()
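-- A usage sketch (mirrors 'main'; 'processChunk' is a hypothetical callback):
--
--   reader <- mkChunkReader "measurements.txt"
--   mapConcurrently_ (const $ reader processChunk) [1..nProcessThreads]
--
-- Every worker runs the same loop, but the shared MVar of chunk ids inside
-- the reader guarantees each chunk is handed to exactly one of them.
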
mkChunkReader :: FilePath -> IO ChunkReader
mkChunkReader fn = do
  s <- withF hFileSize
  let nChunks = (s + chunkSize - 1) `div` chunkSize
  chunkIds <- newMVar [0..nChunks-1]
  -- return a chunk reader that can be used by processing threads
  pure $ \ handleChunk -> do
    buf <- B.mallocByteString bufSize
    withF $ \ h -> fix $ \ loopHandle ->
      join $ modifyMVar chunkIds $ pure . \ case
        [] -> ([], pure ()) -- no more chunks
        (i:is) -> (is, do
          bs <- readChunk h i buf
          handleChunk bs
          loopHandle)
  where
    withF = withFile fn ReadMode
    readChunk h i buf = do
      -- PreviousSta[tion;1.2\n...LastSta|tion;3.4\nAnotherStation...]
      --            [      chunkSize     |       maxLineLength       ]
      --            ^ hSeek + hGet
      --  trimPreviousLine |
      --                   [...LastSta|tion;3.4\nAnotherStation...]
      --                                 trimLastLine |
      --                   [...LastSta|tion;3.4\n]
      --                                        ^ imaginary separator
      hSeek h AbsoluteSeek (offset i)
      trimPreviousLine i . trimLastLine i . B.fromForeignPtr0 buf
        <$> withForeignPtr buf (\ p -> hGetBuf h p bufSize)
    bufSize = chunkSize+maxLineLength+8
    -- +8 for null termination and possible extra reads in hashPtr
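    -- with the defaults above: bufSize = 1024*1024 + 107 + 8 = 1048691 bytes
    -- per worker (maxLineLength = 100 + length ";-12.3\n" = 107)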
    trimLastLine i s
      | B.length s > chunkSize
      = B.take (chunkSize + overflowNewlineIndex i s + 1) s
      | otherwise = s
    overflowNewlineIndex i s =
      fromMaybe
        (error $ printf "line end not found for chunk #%d (offset=%d)"
          i (offset i))
        (B.findIndex (== B.c2w '\n') $ B.drop chunkSize s)
    trimPreviousLine i s
      | i == 0 = s -- no previous line for the first chunk
      | otherwise = B.tail $ B.dropWhile (/= B.c2w '\n') s
    offset i = i * chunkSize

-- | Phantom type for typed pointers into the flat hash table
data Entry

-- | Fields in our hash table
data Field a where
  KeyPtr :: Field (Ptr Word8)
  KeyLen :: Field Int
  Count :: Field Int
  Min :: Field Double
  Max :: Field Double
  Sum :: Field Double

-- | Field offset with a 'Storable' instance
data StorableAndOffset a where
  O :: Storable a => Int -> StorableAndOffset a

storableAndOffset :: Field a -> StorableAndOffset a
storableAndOffset = \ case
  KeyPtr -> O 0
  KeyLen -> O 8
  Count -> O $ 2*8
  Min -> O $ 3*8
  Max -> O $ 4*8
  Sum -> O $ 5*8

entrySize :: Int
entrySize = 6*8
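
-- Resulting bucket layout (from 'storableAndOffset' above), one flat
-- 48-byte record per station:
--
--   byte:  0        8        16       24     32     40
--          KeyPtr   KeyLen   Count    Min    Max    Sum
--
-- The table starts zero-filled, so KeyLen == 0 marks an empty bucket.
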
indexRead :: Ptr Entry -> Int -> Field a -> IO a
indexRead p i f = case storableAndOffset f of
  O o -> peekByteOff p (i*entrySize + o)

indexWrite :: Ptr Entry -> Int -> Field a -> a -> IO ()
indexWrite p i f x = case storableAndOffset f of
  O o -> pokeByteOff p (i*entrySize + o) x
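
-- The GADT keeps reads and writes of the raw buffer type-safe, e.g.
-- (illustrative, not part of the original gist):
--
--   indexWrite table i Min 12.3       -- ok: Min :: Field Double
--   indexWrite table i Min (1 :: Int) -- rejected at compile time
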
-- | Allocate 'len' bytes and pass them to the callback both as a
-- 'B.ByteString' and as a raw pointer to the same memory.
withArena :: Int -> (B.ByteString -> Ptr a -> IO b) -> IO b
withArena len f = do
  fp <- B.mallocByteString len
  withForeignPtr fp $ \ p -> f (B.fromForeignPtr0 fp len) (castPtr p)

stations :: ChunkReader -> IO (HM.HashMap B.ByteString Totals)
stations readChunks =
  let tableSize = 2 ^ (ceiling (logBase 2 maxKeysCount) + 2) -- 2^16 for 10k
      sizeOfTable = entrySize * tableSize
      sizeOfKeysArena = maxKeysCount * maxStationNameLength
  in
  withArena sizeOfTable $ \ _ table ->         -- linear probe hash table
  withArena sizeOfKeysArena $ \ keysFP keys -> -- arena to copy keys
  alloca $ \ keyPtr -> do                      -- the last key location
    fillBytes table 0 sizeOfTable
    poke keyPtr keys
    let
      station start = do
        c0 <- peek start
        when (c0 /= 0) $ do
          sc <- B.memchr start (B.c2w ';') (toEnum maxStationNameLength)
          let hash = hashPtr start len
              len = sc `minusPtr` start
              (!n, !xs) = parseDegrees $ plusPtr sc 1
              -- 13s->10s (for 100M rows) from !n, fewer allocs from !xs
              -- 10s->1.7s when no HashMap is involved
              -- 10s->7s HashMap -> linear scan table
              --   (unboxed-containers are quite fast)
              -- 7s->5s array -> C-like flat data structure
              --   (any memory indirection matters)
              loop bucket = do
                let r = indexRead table bucket
                    w = indexWrite table bucket
                    m c f = w c . f =<< r c
                l <- r KeyLen
                if l == 0 then do
                  -- empty bucket: insert a new key
                  key <- peek keyPtr
                  poke keyPtr (key `plusPtr` len)
                  copyBytes key start len -- store the new key in the arena
                  w KeyPtr key
                  w KeyLen len
                  w Count 1
                  w Min n
                  w Max n
                  w Sum n
                  station xs
                else if l == len then do
                  entryKey <- r KeyPtr
                  c <- B.memcmp entryKey start len
                  if c == 0 then do
                    -- found an existing key: update its stats
                    m Count succ
                    m Min (min n)
                    m Max (max n)
                    m Sum (+n)
                    station xs
                  else
                    loop (succ bucket .&. (tableSize-1))
                else
                  loop (succ bucket .&. (tableSize-1))
          loop (hash .&. (tableSize-1))
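      -- Temperatures always have the shape "-?\d?\d\.\d\n" (per the 1BRC
      -- rules), so only the four cases below can occur; e.g. "-12.3\n..."
      -- parses to (-(10*1 + 2 + 0.1*3), pointer past the '\n') = (-12.3, ...)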
      parseDegrees = \ case
        '-' :. d1 :. d0 :. '.' :. f :. '\n' :. xs ->
          (-(10*d d1 + d d0 + 0.1*d f), xs)
        '-' :. d0 :. '.' :. f :. '\n' :. xs ->
          (-(d d0 + 0.1*d f), xs)
        d1 :. d0 :. '.' :. f :. '\n' :. xs -> (10*d d1 + d d0 + 0.1*d f, xs)
        d0 :. '.' :. f :. '\n' :. xs -> (d d0 + 0.1*d f, xs)
        _ -> error "bad number"
      d x = fromIntegral $ ord x - ord '0'
    readChunks $ \ (B.BS fp len) -> withForeignPtr fp $ \ s -> do
      pokeByteOff s len (0::Word8)
      station s
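    -- the poke above NUL-terminates the chunk so 'station'/'unconsPtr' know
    -- where to stop; for well-formed input it lands at most
    -- chunkSize+maxLineLength bytes into the buffer, i.e. no further than
    -- the first of the 8 spare bytes in 'bufSize'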
    fmap (HM.fromList . catMaybes) $ forM [0..tableSize-1] $ \ bucket -> do
      let r = indexRead table bucket
      key <- r KeyPtr
      if key /= nullPtr then do
        len <- r KeyLen
        fmap (Just . (B.take len $ B.drop (minusPtr key keys) keysFP,))
          $ Totals <$> r Count <*> r Min <*> r Max <*> r Sum
      else
        pure Nothing

pattern (:.) :: Char -> Ptr Word8 -> Ptr Word8
pattern x :. xs <- (unconsPtr -> Just (B.w2c -> x, xs))
infixr 5 :.
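
-- The synonym lets NUL-terminated memory be matched like a list of 'Char's,
-- e.g. a hypothetical check for a leading minus sign:
--
--   startsWithMinus :: Ptr Word8 -> Bool
--   startsWithMinus ('-' :. _) = True
--   startsWithMinus _ = False
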
unconsPtr :: Ptr Word8 -> Maybe (Word8, Ptr Word8)
unconsPtr p = case B.accursedUnutterablePerformIO $ peek p of
  0 -> Nothing
  -- the buffer needs to be NUL-terminated for this, but it works ~10% faster
  -- than tracking a separate length
  x -> Just (x, plusPtr p 1)

------------------------------------------------------------------------------
-- András Kovács' hash
-- see https://discourse.haskell.org/t/one-billion-row-challenge-in-hs/8946/141
-- and https://gist.github.com/AndrasKovacs/e156ae66b8c28b1b84abe6b483ea20ec
-- Helps to shave off another 15-20% of runtime thanks to vectorization
-- (memchr + a 64-bit hash is faster than byte-by-byte iteration).
-- Note that on older machines it could actually be 10-20% slower
-- than the previous version.
-- This function is extremely unsafe as it can perform out-of-bounds reads,
-- but the read buffer has padding at the end to make that safe.
{-# INLINE hashPtr #-}
hashPtr :: Ptr Word8 -> Int -> Int
hashPtr ptr@(Ptr p) l
  | l <= 8 = h3 0 0 (mr l 0)
  | l <= 16 = h3 0 (mr (l - 8) 1) (r 0)
  | l <= 24 = h3 (mr (l - 16) 2) (r 1) (r 0)
  | otherwise = h3 0 (hashPtr ptr 24) (hashPtr (ptr `plusPtr` 24) (l - 24))
  where
    r (I# i) = I# (indexIntOffAddr# p i)
    mr len i = r i .&. mask len
    mask len = unsafeIShiftRL ((-1) :: Int) (64 - unsafeShiftL len 3)
    h3 (I# a) (I# b) (I# c) =
      I# (word2Int# (int2Word# a `combine` int2Word# b `combine` int2Word# c))

foldedMul :: Word# -> Word# -> Word#
foldedMul x y = case timesWord2# x y of (# hi, lo #) -> xor# hi lo

combine :: Word# -> Word# -> Word#
combine x y = foldedMul (xor# x y) 11400714819323198549##

unsafeIShiftRL :: Int -> Int -> Int
unsafeIShiftRL (I# x) (I# y) = I# (uncheckedIShiftRL# x y)
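
-- How the hash works (my reading of the code above, added for reference):
-- the key is consumed as up to three 64-bit words, with the last partial
-- word masked down to its valid bytes; the words are mixed by 'combine',
-- i.e. an xor followed by a full 64x64 -> 128-bit multiply ('timesWord2#')
-- whose high and low halves are folded together with another xor. The
-- multiplier is a large odd constant close to 2^64 / phi, as in Fibonacci
-- hashing. Keys longer than 24 bytes are hashed recursively in 24-byte
-- blocks.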