-- Last active April 30, 2024 18:49
{-# language
, LambdaCase
, MagicHash
, PatternSynonyms
, Strict
, TypeApplications
, UnboxedTuples
, ViewPatterns
{-# options_ghc
{- cabal:
build-depends: base >= 4.19, bytestring, mmap, async
default-language: GHC2021
ghc-options: -Wall -O2 -fllvm -rtsopts -threaded -split-sections
-- more debugging:
-- ghc -O2 -fllvm -rtsopts -threaded -split-sections -ddump-simpl -dsuppress-all
-- -dno-suppress-type-signatures -ddump-to-file -fforce-recomp
-- display output
-- TABLE_SIZE is the number of hash table slots; it should be a power of 2, minimum 16384
-- #define TABLE_SIZE 131072
-- #define TABLE_SIZE 65536
#define TABLE_SIZE 32768
-- #define TABLE_SIZE 16384
import Control.Concurrent
import Control.Exception (SomeException, try)
import Control.Monad
import Data.Bits
import Data.List
import Foreign.Marshal.Alloc
import GHC.Exts
import GHC.IO
import GHC.Word
import System.IO hiding (withFile)
import System.IO.MMap
import Text.Printf
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as LC8
-- Random common functions
-- | Run @f@ on every element of @xs@ in parallel and return the results in
-- input order. The i-th action is pinned to capability @i@ with forkOn.
--
-- The list must contain exactly as many elements as there are capabilities;
-- otherwise the call fails with "wrong number of capabilities".
--
-- Each result is forced to weak head normal form inside the worker thread.
-- An exception raised by a worker is captured and rethrown in the calling
-- thread when that result is collected, so a failing worker no longer leaves
-- the caller blocked on an MVar that is never filled.
mapConcurrently :: (a -> IO b) -> [a] -> IO [b]
mapConcurrently f xs = do
  caps <- getNumCapabilities
  unless (caps == length xs) $ error "wrong number of capabilities"
  vs <- forM (zip [0..] xs) \(i, x) -> do
    v <- newEmptyMVar
    v <$ forkOn i do
      -- capture both synchronous failures of f and failures while forcing
      r <- try @SomeException (f x >>= evaluate)
      putMVar v r
  forM vs \v -> takeMVar v >>= either throwIO pure
-- | Short alias for fromIntegral, used for all integral conversions here.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# inline fi #-}
-- | Shift left by @n@ bits using unsafeShiftL (the shift amount is not
-- range-checked).
sl :: Bits a => a -> Int -> a
sl x n = x `unsafeShiftL` n
-- | Shift right by @n@ bits using unsafeShiftR (the shift amount is not
-- range-checked). On signed types this is an arithmetic shift.
sr :: Bits a => a -> Int -> a
sr x n = x `unsafeShiftR` n
-- | Logical (zero-filling) right shift on Int. The shift amount is passed to
-- uncheckedIShiftRL# unchanged, so it must be in the range 0..63.
isrl :: Int -> Int -> Int
isrl (I# value) (I# bits) = I# (value `uncheckedIShiftRL#` bits)
-- | Branch-free maximum of two Ints. The sign of @a - b@ is spread over all
-- 64 bits and used as a mask; the result is only meaningful when @a - b@
-- does not overflow.
max' :: Int -> Int -> Int
max' a b = a - (delta .&. signMask)
  where
    delta    = a - b
    -- all ones when a < b, zero otherwise
    signMask = sr delta 63
-- | Branch-free minimum of two Ints. The sign of @a - b@ is spread over all
-- 64 bits and used as a mask; the result is only meaningful when @a - b@
-- does not overflow.
min' :: Int -> Int -> Int
min' a b = b + (delta .&. signMask)
  where
    delta    = a - b
    -- all ones when a < b, zero otherwise
    signMask = sr delta 63
-- | Advance a raw address by a boxed byte offset.
plusAddr :: Addr# -> Int -> Addr#
plusAddr base (I# off#) = base `plusAddr#` off#
-- | Reinterpret a boxed Int as a raw address.
int2Addr :: Int -> Addr#
int2Addr n = case n of
  I# n# -> int2Addr# n#
-- | Reinterpret a raw address as a boxed Int.
addr2Int :: Addr# -> Int
addr2Int addr = I# (addr2Int# addr)
-- | Equality on Int that yields 1 for equal and 0 for different operands
-- instead of a Bool, so several tests can be combined with bitwise and.
eqI :: Int -> Int -> Int
eqI (I# lhs) (I# rhs) = I# (lhs ==# rhs)
-- | Read one machine-word Int stored at the given address.
readI :: Addr# -> IO Int
readI addr = IO \s0 ->
  case readIntOffAddr# addr 0# s0 of
    (# s1, n# #) -> (# s1, I# n# #)
-- | Store one machine-word Int at the given address.
writeI :: Addr# -> Int -> IO ()
writeI addr (I# n#) = IO \s0 ->
  case writeIntOffAddr# addr 0# n# s0 of
    s1 -> (# s1, () #)
-- Generic buffers
-- A view into raw memory: _ptr is the address of the first byte and len is
-- the number of bytes that belong to the view. Nothing here owns or frees
-- the memory.
data Buffer = Buffer {_ptr :: Addr#, len :: Int}
-- | Drop the first @n@ bytes of a buffer: the start address moves forward by
-- @n@ and the length shrinks by @n@. No bounds check is performed.
plus :: Buffer -> Int -> Buffer
plus (Buffer base remaining) n@(I# n#) =
  Buffer (plusAddr# base n#) (remaining - n)
-- | Fill every byte of the buffer with the given value.
memset :: Buffer -> Word8 -> IO ()
memset (Buffer base (I# size#)) (W8# byte#) = IO \s0 ->
  case setAddrRange# base size# (word2Int# (word8ToWord# byte#)) s0 of
    s1 -> (# s1, () #)
-- | Memory-map a file read-only (whole file) and run the continuation on a
-- Buffer covering the mapping.
withFile :: FilePath -> (Buffer -> IO a) -> IO a
withFile path k =
  mmapWithFilePtr path ReadOnly Nothing \(Ptr base, size) ->
    k (Buffer base size)
{-# inline withFile #-}
-- | Read the byte at index @i@ from the start of the buffer. The stored
-- length is ignored; no bounds check is performed.
indexW8 :: Buffer -> Int -> Word8
indexW8 (Buffer base _) (I# i#) = W8# (indexWord8OffAddr# base i#)
-- | Read the @i@-th 32-bit word from the start of the buffer (the index
-- counts 4-byte elements). The stored length is ignored; no bounds check.
indexW32 :: Buffer -> Int -> Word32
indexW32 (Buffer base _) (I# i#) = W32# (indexWord32OffAddr# base i#)
-- | Read the @i@-th machine word from the start of the buffer (the index
-- counts word-sized elements). The stored length is ignored; no bounds check.
indexW :: Buffer -> Int -> Word
indexW (Buffer base _) (I# i#) = W# (indexWordOffAddr# base i#)
-- | Read the @i@-th machine-word Int from the start of the buffer (the index
-- counts word-sized elements). The stored length is ignored; no bounds check.
indexI :: Buffer -> Int -> Int
indexI (Buffer base _) (I# i#) = I# (indexIntOffAddr# base i#)
-- | Read the first byte of a buffer (no bounds check).
getW8 :: Buffer -> Word8
getW8 b = indexW8 b 0
-- | Read the first four bytes of a buffer as a 32-bit word (no bounds check).
getW32 :: Buffer -> Word32
getW32 b = indexW32 b 0
-- | Read the first machine word of a buffer (no bounds check).
getW :: Buffer -> Word
getW b = indexW b 0
-- Buffers are equal when they have the same length and the same bytes. The
-- contents are compared in 8-byte steps while at least 8 bytes remain, then
-- in one 4-byte step if possible, then byte by byte.
instance Eq Buffer where
  Buffer lhs n == Buffer rhs m = n == m && loop lhs rhs n
    where
      -- wrap an address for the getW* readers; they never look at the length
      at q = Buffer q n

      loop a b left
        | left >= 8 =
            getW (at a) == getW (at b)
              && loop (plusAddr# a 8#) (plusAddr# b 8#) (left - 8)
        | left >= 4 =
            getW32 (at a) == getW32 (at b)
              && loop (plusAddr# a 4#) (plusAddr# b 4#) (left - 4)
        | left == 0 = True
        | otherwise =
            getW8 (at a) == getW8 (at b)
              && loop (plusAddr# a 1#) (plusAddr# b 1#) (left - 1)
  {-# inline (==) #-}
-- | Multiply two words to a double-width product and xor its high and low
-- halves together.
foldedMul :: Word -> Word -> Word
foldedMul (W# x) (W# y) =
  case timesWord2# x y of
    (# hi, lo #) -> W# (hi `xor#` lo)
-- Initial accumulator value for hashBuffer and hashShortBuffer.
salt :: Word
salt = 3032525626373534813
-- | Mix two words into one hash value: xor them, then apply foldedMul with a
-- fixed 64-bit multiplier.
combine :: Word -> Word -> Word
combine x y = foldedMul (x `xor` y) multiplier
  where
    multiplier = 11400714819323198549
-- | Hash the bytes of a buffer, starting from salt. The input is consumed in
-- 8-byte steps while at least 8 bytes remain, then in one 4-byte step if
-- possible, then byte by byte.
hashBuffer :: Buffer -> Word
hashBuffer buf0 = loop buf0 salt
  where
    loop buf h
      | left >= 8 = loop (plus buf 8) (combine (getW buf) h)
      | left >= 4 = loop (plus buf 4) (combine (fromIntegral (getW32 buf)) h)
      | left == 0 = h
      | otherwise = loop (plus buf 1) (combine (fromIntegral (getW8 buf)) h)
      where
        left = len buf
-- | Emit the bytes of a buffer, one word8 at a time, as a Builder.
buildBuffer :: Buffer -> BB.Builder
buildBuffer b
  | len b == 0 = mempty
  | otherwise  = BB.word8 (getW8 b) <> buildBuffer (plus b 1)
-- printBuffer :: Buffer -> IO ()
-- printBuffer = BB.hPutBuilder stdout . buildBuffer
-- Shows the raw bytes of the buffer, one Char per byte.
instance Show Buffer where
  show = LC8.unpack . BB.toLazyByteString . buildBuffer
-- Orders buffers by comparing their shown (byte-per-Char) strings.
instance Ord Buffer where
  compare lhs rhs = show lhs `compare` show rhs
-- Short buffer
-- Unboxed buffer containing at most 23 bytes. The first field is the length,
-- the rest is the payload. The 24-th byte in the payload is always zeroed out.
-- The payload words are stored last-to-first: the fourth field holds bytes
-- 0..7, the third bytes 8..15 and the second bytes 16..22 (see packBuffer and
-- buildShortBuffer). Bytes past the length are zero.
data ShortBuffer = ShortBuffer# Int Int Int Int
-- Equality looks only at the three payload words; the length field is not
-- compared. The three word tests are combined without branching.
instance Eq ShortBuffer where
  ShortBuffer# _ a0 b0 c0 == ShortBuffer# _ a1 b1 c1 = allEqual == 1
    where
      allEqual = eqI a0 a1 .&. eqI b0 b1 .&. eqI c0 c1
  {-# inline (==) #-}
-- | Hash a short buffer from its three payload words and salt; the length
-- field does not take part.
hashShortBuffer :: ShortBuffer -> Word
hashShortBuffer (ShortBuffer# _ a b c) = combine left right
  where
    left  = combine salt (fi a)
    right = combine (fi b) (fi c)
-- | Emit the bytes of a short buffer: the three payload words are serialised
-- little-endian (fourth field first) and cut down to the stored length.
buildShortBuffer :: ShortBuffer -> BB.Builder
buildShortBuffer (ShortBuffer# n hi mid lo) =
  BB.lazyByteString (LC8.take (fi n) payload)
  where
    payload =
      BB.toLazyByteString
        (BB.int64LE (fi lo) <> BB.int64LE (fi mid) <> BB.int64LE (fi hi))
-- Shows the raw bytes of the short buffer, one Char per byte.
instance Show ShortBuffer where
  show sb = LC8.unpack (BB.toLazyByteString (buildShortBuffer sb))
-- Orders short buffers by their payload words, starting with the word that
-- holds bytes 0..7. Each word is byte-swapped first so that an unsigned word
-- comparison ranks the byte at the lowest address highest.
instance Ord ShortBuffer where
  compare (ShortBuffer# _ a0 b0 c0) (ShortBuffer# _ a1 b1 c1) =
    compare (swapped c0) (swapped c1)
      <> compare (swapped b0) (swapped b1)
      <> compare (swapped a0) (swapped a1)
    where
      swapped (I# w) = W# (byteSwap# (int2Word# w))
-- Unboxed sum of short and standard buffers.
-- Three words. The low byte of the first word decides the variant (see
-- unpackSLB#): a value of at most 23 means a short buffer whose length is that
-- byte and whose payload is the remaining 23 bytes; otherwise the first word
-- is the length of a standard buffer and the second word its address.
-- A first word of 0 is treated as "empty" by isEmptySLB.
data SLBuffer = SLB# Int Int Int
-- | True when the first word is 0, which is what a zero-filled table slot
-- contains.
isEmptySLB :: SLBuffer -> Bool
isEmptySLB (SLB# tag _ _) = tag == 0
-- | Decode an SLBuffer into either a short buffer or a standard buffer.
--
-- The low byte of the first word selects the variant: at most 23 means short
-- (length in that byte, the other seven bytes are payload), anything else
-- means standard (first word is the full length, second word the address).
-- NOTE(review): a standard buffer whose length has a low byte of at most 23
-- (for example 256) would be decoded as short; presumably such lengths never
-- occur in the input - confirm.
unpackSLB# :: SLBuffer -> (# ShortBuffer | Buffer #)
unpackSLB# (SLB# a b c)
  | n <= 23   = (# ShortBuffer# n (isrl a 8) b c | #)
  | otherwise = (# | Buffer (int2Addr b) a #)
  where
    n = a .&. 255
-- Bidirectional pattern for the short variant of SLBuffer. Matching goes
-- through unpackSLB#; building shifts the first payload word up by one byte
-- and stores the length in the freed low byte.
pattern ShortBuffer :: ShortBuffer -> SLBuffer
pattern ShortBuffer short <- (unpackSLB# -> (# short | #))
  where
    ShortBuffer (ShortBuffer# n a b c) = SLB# (sl a 8 .|. n) b c
-- Bidirectional pattern for the standard variant of SLBuffer. Matching goes
-- through unpackSLB#; building stores the length in the first word, the
-- address in the second and 0 in the third.
pattern LongBuffer :: Buffer -> SLBuffer
pattern LongBuffer long <- (unpackSLB# -> (# | long #))
  where
    LongBuffer (Buffer addr n) = SLB# n (addr2Int addr) 0
{-# complete ShortBuffer, LongBuffer #-}
-- Two SLBuffers are equal only when they are the same variant and the
-- payloads compare equal under that variant's own equality.
instance Eq SLBuffer where
  ShortBuffer s == other = case other of
    ShortBuffer s' -> s == s'
    LongBuffer _   -> False
  LongBuffer l == other = case other of
    LongBuffer l'  -> l == l'
    ShortBuffer _  -> False
  {-# inline (==) #-}
-- | Try to pack a Buffer into a short one.
--
-- Buffers of up to 23 bytes are copied into the payload words, with the bytes
-- past the length masked to zero; longer buffers are kept by reference.
-- Whole machine words are read from the start of the buffer, so up to 7 bytes
-- past its end may be touched.
-- NOTE(review): for a length of 0 the mask is computed with a shift by 64,
-- which isrl does not define; callers presumably never pass an empty
-- buffer - confirm.
packBuffer :: Buffer -> SLBuffer
packBuffer b
  | n <= 8    = ShortBuffer (ShortBuffer# n 0 0 (word 0 .&. keep n))
  | n <= 16   = ShortBuffer (ShortBuffer# n 0 (word 1 .&. keep (n - 8)) (word 0))
  | n <= 23   = ShortBuffer (ShortBuffer# n (word 2 .&. keep (n - 16)) (word 1) (word 0))
  | otherwise = LongBuffer b
  where
    n = len b
    -- i-th machine word of the buffer
    word = indexI b
    -- mask that keeps the k lowest bytes of a word
    keep k = isrl (-1) (64 - sl k 3)
-- | Hash an SLBuffer with the hash function of its variant.
hashSLB :: SLBuffer -> Word
hashSLB key = case key of
  ShortBuffer s -> hashShortBuffer s
  LongBuffer l  -> hashBuffer l
-- | Emit the bytes of an SLBuffer with the builder of its variant.
buildSLB :: SLBuffer -> BB.Builder
buildSLB key = case key of
  ShortBuffer s -> buildShortBuffer s
  LongBuffer l  -> buildBuffer l
-- Shows the raw bytes of the buffer, one Char per byte.
instance Show SLBuffer where
  show slb = LC8.unpack (BB.toLazyByteString (buildSLB slb))
-- Two short buffers are compared word-wise; every other combination falls
-- back to comparing the shown strings.
instance Ord SLBuffer where
  compare lhs rhs
    | ShortBuffer s <- lhs, ShortBuffer s' <- rhs = compare s s'
    | otherwise = compare (show lhs) (show rhs)
-- Branchless scanning for bytes in words.
#define SCAN_MASK(hex) 0x/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex
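-- Given a hexadecimal byte, generate the (Word -> Int) function which returns the
-- index of the rightmost (least significant) occurrence of the byte in the word,
-- or returns 8 if the byte does not occur. For a word loaded from memory on a
-- little-endian machine, that is the occurrence at the lowest address.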
#define BYTE_INDEX(hex) (\(x :: Word) -> case xor x SCAN_MASK(hex) of \
x -> case (x - 0x0101010101010101) .&. complement x .&. 0x8080808080808080 of \
x -> countTrailingZeros x `sr` 3)
-- Hash table of measurements
-- | Running aggregate of the measurements seen for one key. Temperatures are
-- stored as integers in tenths of a degree (parse builds them that way and
-- displayEntries divides by 10).
data Val = Val {
    _min   :: Int  -- ^ smallest measurement so far
  , _max   :: Int  -- ^ largest measurement so far
  , _cnt   :: Int  -- ^ number of measurements
  , _total :: Int  -- ^ sum of all measurements
  }
-- | One hash table slot: a key together with its aggregate. readEntry and
-- writeEntry store it as seven consecutive machine words (three for the key,
-- four for the value).
data Entry = Entry {
    _key :: {-# unpack #-} SLBuffer
  , _val :: {-# unpack #-} Val
  }
-- size of entry in bytes (includes padding to 64 bytes!)
-- An entry occupies seven 8-byte words (see readEntry); the eighth word of
-- each slot is never read or written.
entrySize :: Int
entrySize = 8 * 8
-- Bit mask that reduces a hash value to a slot index; relies on the slot
-- count being a power of two.
tableMask :: Int
tableMask = TABLE_SIZE - 1
-- Total size of one hash table in bytes; also the byte offset one past the
-- last slot.
tableBytes :: Int
tableBytes = TABLE_SIZE * entrySize
-- A hash table is referred to by the raw address of its first slot.
type Table = Addr#
-- | Allocate one zero-filled hash table per input buffer and run the
-- continuation on the (buffer, table) pairs.
--
-- Each table is allocated with allocaBytesAligned, aligned to entrySize, so
-- the tables are only valid while the continuation runs. The pairs are handed
-- over in reverse order of the input list.
initTables :: [Buffer] -> ([(Buffer, Ptr Word8)] -> IO a) -> IO a
initTables bs f = allocate bs []
  where
    allocate [] acc = f acc
    allocate (chunk : rest) acc =
      allocaBytesAligned tableBytes entrySize \tbl@(Ptr raw) -> do
        -- zero words mark empty slots
        memset (Buffer raw tableBytes) 0
        allocate rest ((chunk, tbl) : acc)
-- | Read the entry stored at a *byte* offset into the table: three words of
-- key followed by four words of value, read in address order.
readEntry :: Table -> Int -> IO Entry
readEntry tbl off = case plusAddr tbl off of
  slot -> do
    let word k = readI (plusAddr slot k)
    key <- SLB# <$> word 0 <*> word 8 <*> word 16
    val <- Val <$> word 24 <*> word 32 <*> word 40 <*> word 48
    pure (Entry key val)
-- | Store an entry at a *byte* offset into the table: three words of key
-- followed by four words of value, written in address order.
writeEntry :: Table -> Int -> Entry -> IO ()
writeEntry tbl off (Entry (SLB# k0 k1 k2) (Val lo hi n total)) =
  case plusAddr tbl off of
    slot -> do
      let put k v = writeI (plusAddr slot k) v
      put 0 k0
      put 8 k1
      put 16 k2
      put 24 lo
      put 32 hi
      put 40 n
      put 48 total
-- | Aggregate for a single measurement: it is the minimum, the maximum and
-- the total, with a count of one.
newVal :: Int -> Val
newVal temp = Val {_min = temp, _max = temp, _cnt = 1, _total = temp}
-- | Merge another aggregate into an entry, keeping the entry's key: minima
-- and maxima are combined, counts and totals are added.
updateEntry :: Entry -> Val -> Entry
updateEntry (Entry k (Val lo hi n total)) (Val lo' hi' n' total') =
  Entry k merged
  where
    merged = Val
      { _min   = min' lo lo'
      , _max   = max' hi hi'
      , _cnt   = n + n'
      , _total = total + total'
      }
-- | Run an action on every non-empty entry of a table, in slot order.
forTable :: Table -> (Entry -> IO ()) -> IO ()
forTable t f = walk 0
  where
    walk ix
      | ix == tableBytes = pure ()
      | otherwise = do
          e@(Entry k _) <- readEntry t ix
          unless (isEmptySLB k) (f e)
          walk (ix + entrySize)
{-# inline forTable #-}
-- | Insert an entry into the table, or merge its value into the entry already
-- stored under the same key.
--
-- Probing is linear: it starts at the slot selected by the key hash, moves
-- one slot at a time and wraps around at the end of the table. There is no
-- check for a full table, so the search does not terminate if every slot is
-- occupied by a different key.
updateTable :: Table -> Entry -> IO ()
updateTable tbl e@(Entry key val) = probe home
  where
    -- byte offset of the slot the hash points at
    home = (fi (hashSLB key) .&. tableMask) * entrySize

    probe ix
      | ix == tableBytes = probe 0
      | otherwise = do
          olde@(Entry oldkey _) <- readEntry tbl ix
          case () of
            _ | isEmptySLB oldkey -> writeEntry tbl ix e
              | key == oldkey     -> writeEntry tbl ix (updateEntry olde val)
              | otherwise         -> probe (ix + entrySize)
-- Parse every record of the form "name;temperature" followed by one newline
-- byte, and fold each measurement into the table. Temperatures are read as
-- integer tenths of a degree and are expected to have one of the shapes
-- d.d, dd.d, -d.d or -dd.d; no other shape is detected or rejected.
-- The semicolon scan reads whole machine words, so it can look at bytes
-- beyond the current record.
parse :: Table -> Buffer -> IO ()
-- nothing left to parse
parse _ b | len b == 0 = do
  pure ()
parse tbl b = do

  -- scan for semicolon: byte offset of the first semicolon at or after b
  let findSemi :: Int -> Buffer -> Int
      findSemi i b = case BYTE_INDEX(3B) (getW b) of
        8 -> findSemi (i + 8) (plus b 8)
        i' -> i + i'

  let keylen = findSemi 0 b
  let key = packBuffer $ b {len = keylen}
  -- skip the key and the semicolon; b now starts at the temperature
  b <- pure $ plus b (keylen + 1)

  -- ASCII digit to its numeric value
  let digit :: Word8 -> Int
      digit x = fi x - 48

  -- record one measurement, then carry on with the rest of the input
  let join :: Buffer -> Int -> IO ()
      join b temp = do
        updateTable tbl (Entry key (newVal temp))
        parse tbl b

  case getW8 b of
    -- '-'
    45 -> do
      let d1 = getW8 (plus b 1)
      case getW8 (plus b 2) of
        -- '.' so the next must be digit
        46 -> do
          let d2 = getW8 (plus b 3)
          -- 5 bytes consumed: sign, digit, dot, digit, newline
          join (plus b 5) ((-10)*(digit d1) - digit d2)
        -- digit, so the next must be '.' and then digit
        d2 -> do
          let d3 = getW8 (plus b 4)
          -- 6 bytes consumed: sign, two digits, dot, digit, newline
          join (plus b 6) ((-100)*(digit d1) - 10*(digit d2) - digit d3)
    -- a digit
    d1 -> case getW8 (plus b 1) of
      -- '.', so the next must be digit
      46 -> do
        let d2 = getW8 (plus b 2)
        -- 4 bytes consumed: digit, dot, digit, newline
        join (plus b 4) (10*digit d1 + digit d2)
      -- another digit, so the next must be '.', and then digit
      d2 -> do
        let d3 = getW8 (plus b 3)
        -- 5 bytes consumed: two digits, dot, digit, newline
        join (plus b 5) (100*digit d1 + 10*digit d2 + digit d3)
-- | Split the input into chunks, one per thread, cutting only after a
-- newline byte so that no record straddles two chunks.
--
-- The nominal chunk size is the input length divided by @num_threads@. Each
-- chunk but the last is extended from that size up to and including the next
-- newline; the last chunk is whatever remains once the remainder is no longer
-- than the nominal size. The newline scan reads whole machine words, so it
-- can look at bytes beyond the newline it finds.
splitBuffer :: Int -> Buffer -> [Buffer]
splitBuffer num_threads b = go b
  where
    chunkSize = div (len b) num_threads

    go rest
      -- what is left fits into one chunk: it becomes the final chunk
      | len rest <= chunkSize = [rest]
      | otherwise =
          let -- byte offset of the first newline at or after p
              findNewl i p = case BYTE_INDEX(0A) (getW p) of
                8  -> findNewl (i + 8) (plus p 8)
                i' -> i + i'
              toNewl     = findNewl 0 (plus rest chunkSize)
              -- include the newline itself in this chunk
              chunkSize' = chunkSize + toNewl + 1
              others     = go (plus rest chunkSize')
          in Buffer (_ptr rest) chunkSize' : others
-- | Collect every non-empty entry of a table into a list, in slot order.
tableToList :: Table -> IO [Entry]
tableToList tbl = collect 0
  where
    collect ix
      | ix == tableBytes = pure []
      | otherwise = do
          e@(Entry k _) <- readEntry tbl ix
          rest <- collect (ix + entrySize)
          pure (if isEmptySLB k then rest else e : rest)
-- | Render the entries as @{key=min/mean/max, key=min/mean/max, ...}@.
--
-- The stored values are integer tenths, so minimum and maximum are divided by
-- 10 and the mean is the total divided by ten times the count; each number is
-- printed with one decimal place. Entries are emitted in the order given.
displayEntries :: [Entry] -> BB.Builder
displayEntries es =
  BB.char8 '{' <> mconcat (intersperse (BB.string8 ", ") (map goEntry es)) <> BB.char8 '}'
  where
    -- function application that forces both sides first (the module is Strict)
    f $$! x = f x
    infixl 8 $$!

    goEntry (Entry key (Val mi ma cn to)) =
      buildSLB key <>
      -- printf produces a String; it has to be turned into a Builder
      BB.string8
        (printf "=%.1f/%.1f/%.1f" $$!
          (fi mi / 10 :: Double) $$!
          (fi to / (fi cn * 10) :: Double) $$!
          (fi ma / 10 :: Double))
-- | Map the input file, parse one chunk per capability into its own table,
-- merge all tables into the first one returned, then print the entries sorted
-- by key, followed by a newline.
main :: IO ()
main = withFile "data/measurements500M.txt" \input -> do
  caps <- getNumCapabilities
  initTables (splitBuffer caps input) \chunks -> do
    -- every worker returns the table it filled
    Ptr merged : others <-
      mapConcurrently (\(chunk, Ptr t) -> Ptr t <$ parse t chunk) chunks
    -- fold the remaining tables into the first one
    forM_ others \(Ptr other) ->
      forTable other (updateTable merged)
    entries <- tableToList merged
    let sorted = sortBy (\x y -> compare (_key x) (_key y)) entries
    BB.hPutBuilder stdout (displayEntries sorted)
    putChar '\n'
