Skip to content

Instantly share code, notes, and snippets.

@AndrasKovacs
Last active April 30, 2024 18:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save AndrasKovacs/e156ae66b8c28b1b84abe6b483ea20ec to your computer and use it in GitHub Desktop.
Save AndrasKovacs/e156ae66b8c28b1b84abe6b483ea20ec to your computer and use it in GitHub Desktop.
1brc
{-# language
BlockArguments
, CPP
, LambdaCase
, MagicHash
, PatternSynonyms
, Strict
, TypeApplications
, UnboxedTuples
, ViewPatterns
#-}
{-# options_ghc
-Wall
-Wno-missing-signatures
-Wno-name-shadowing
#-}
{- cabal:
build-depends: base >= 4.19, bytestring, mmap, async
default-language: GHC2021
ghc-options: -Wall -O2 -fllvm -rtsopts -threaded -split-sections
-}
-- more debugging:
-- ghc -O2 -fllvm -rtsopts -threaded -split-sections -ddump-simpl -dsuppress-all
-- -dno-suppress-type-signatures -ddump-to-file -fforce-recomp
-- CONFIGURATION
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
-- display output
#define DISPLAY_OUTPUT
-- should be power of 2, minimum 16384
-- #define TABLE_SIZE 131072
-- #define TABLE_SIZE 65536
#define TABLE_SIZE 32768
-- #define TABLE_SIZE 16384
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
import Control.Concurrent
import Control.Monad
import Data.Bits
import Foreign.Marshal.Alloc
import GHC.Exts
import GHC.IO
import GHC.Word
import System.IO.MMap
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as LC8
#ifdef DISPLAY_OUTPUT
import Data.List
import Text.Printf
import System.IO hiding (withFile)
#endif
-- Random common functions
--------------------------------------------------------------------------------
-- Run f on every element of xs in parallel: one forked thread per element,
-- with thread i pinned to capability i via forkOn. Results are returned in
-- input order.
--
-- The list length must equal the number of capabilities; otherwise error is
-- called with "wrong number of capabilities".
--
-- Every worker hands back an IO action instead of a bare value: a pure result
-- on success, or a rethrow of the caught exception on failure. A failing
-- worker therefore still fills its MVar, and the caller sees the original
-- exception instead of blocking on an MVar that is never written.
mapConcurrently :: (a -> IO b) -> [a] -> IO [b]
mapConcurrently f xs = do
  caps <- getNumCapabilities
  unless (caps == length xs) $ error "wrong number of capabilities"
  vs <- forM (zip [0..] xs) \(i, x) -> do
    v <- newEmptyMVar
    v <$ forkOn i do
      -- NOTE(review): presumably lets the spawning thread fork the remaining
      -- workers before this one starts its work; confirm before removing.
      yield
      outcome <-
        (do y <- f x
            pure (pure y))
          `catchAny` \e -> pure (throwIO e)
      putMVar v outcome
  -- running the stored action either yields the result or rethrows
  forM vs \v -> join (takeMVar v)
-- Shorthand for fromIntegral.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# inline fi #-}

-- Left shift without a range check on the shift amount.
sl :: Bits a => a -> Int -> a
sl x n = x `unsafeShiftL` n

-- Right shift without a range check on the shift amount (arithmetic on Int).
sr :: Bits a => a -> Int -> a
sr x n = x `unsafeShiftR` n

-- Logical (zero-filling) right shift on Int.
isrl :: Int -> Int -> Int
isrl (I# v) (I# n) = I# (v `uncheckedIShiftRL#` n)

-- Branchless maximum: the sign mask of (a - b) is all ones exactly when
-- a < b, in which case the difference is subtracted from a, giving b.
max' :: Int -> Int -> Int
max' a b = a - (diff .&. signMask)
  where
    diff     = a - b
    signMask = sr diff 63

-- Branchless minimum, the mirror image of the maximum above.
min' :: Int -> Int -> Int
min' a b = b + (diff .&. signMask)
  where
    diff     = a - b
    signMask = sr diff 63
-- Advance a raw address by a byte offset.
plusAddr :: Addr# -> Int -> Addr#
plusAddr base (I# off) = base `plusAddr#` off

-- Reinterpret an Int as a raw address.
int2Addr :: Int -> Addr#
int2Addr n = case n of
  I# n# -> int2Addr# n#

-- Reinterpret a raw address as an Int.
addr2Int :: Addr# -> Int
addr2Int addr = I# (addr2Int# addr)

-- Equality test returning 1 or 0 as an Int instead of a Bool.
eqI :: Int -> Int -> Int
eqI (I# l) (I# r) = I# (l ==# r)

-- Read the machine-word Int stored at the given address.
readI :: Addr# -> IO Int
readI addr = IO \st0 ->
  case readIntOffAddr# addr 0# st0 of
    (# st1, n #) -> (# st1, I# n #)

-- Store a machine-word Int at the given address.
writeI :: Addr# -> Int -> IO ()
writeI addr (I# n) = IO \st0 ->
  case writeIntOffAddr# addr 0# n st0 of
    st1 -> (# st1, () #)
-- Generic buffers
--------------------------------------------------------------------------------
-- A view of raw memory: a start address plus a length in bytes.
data Buffer = Buffer
  { _ptr :: Addr#
  , len  :: Int
  }

-- Drop the first n bytes: the address moves forward and the length shrinks
-- by the same amount. No bounds check is made.
plus :: Buffer -> Int -> Buffer
plus (Buffer base size) n@(I# n#) = Buffer (plusAddr# base n#) (size - n)

-- Fill every byte of the buffer with the given value.
memset :: Buffer -> Word8 -> IO ()
memset (Buffer base (I# size)) (W8# byte) = IO \st0 ->
  case setAddrRange# base size (word2Int# (word8ToWord# byte)) st0 of
    st1 -> (# st1, () #)
-- Map the file at path with mmapWithFilePtr (ReadOnly mode, no explicit
-- range) and run the continuation on a Buffer built from the pointer and
-- size that the mapping reports.
withFile :: FilePath -> (Buffer -> IO a) -> IO a
withFile path k = mmapWithFilePtr path ReadOnly Nothing onMapping
  where
    onMapping (Ptr base, size) = k (Buffer base size)
{-# inline withFile #-}
-- Unchecked reads relative to the start of a buffer. The index is passed
-- straight to the matching index...OffAddr# primop, so it counts elements of
-- the type being read, not bytes. The buffer length is ignored.
indexW8 :: Buffer -> Int -> Word8
indexW8 (Buffer base _) (I# i) = W8# (indexWord8OffAddr# base i)

indexW32 :: Buffer -> Int -> Word32
indexW32 (Buffer base _) (I# i) = W32# (indexWord32OffAddr# base i)

indexW :: Buffer -> Int -> Word
indexW (Buffer base _) (I# i) = W# (indexWordOffAddr# base i)

indexI :: Buffer -> Int -> Int
indexI (Buffer base _) (I# i) = I# (indexIntOffAddr# base i)

-- Read the first byte / 32-bit word / machine word of a buffer.
getW8 :: Buffer -> Word8
getW8 b = indexW8 b 0

getW32 :: Buffer -> Word32
getW32 b = indexW32 b 0

getW :: Buffer -> Word
getW b = indexW b 0
-- Content equality: lengths must match, then the bytes are compared in
-- 8-byte steps, then 4-byte steps, then byte by byte.
instance Eq Buffer where
  Buffer lhs n == Buffer rhs m
    | n /= m    = False
    | otherwise = same lhs rhs n
    where
      -- the length stored here is irrelevant, only the address is read
      view q = Buffer q n
      same a b k
        | k >= 8    = getW (view a) == getW (view b)
                        && same (plusAddr# a 8#) (plusAddr# b 8#) (k - 8)
        | k >= 4    = getW32 (view a) == getW32 (view b)
                        && same (plusAddr# a 4#) (plusAddr# b 4#) (k - 4)
        | k == 0    = True
        | otherwise = getW8 (view a) == getW8 (view b)
                        && same (plusAddr# a 1#) (plusAddr# b 1#) (k - 1)
  {-# inline (==) #-}
-- Full-width multiply of two words, folding the 128-bit product back into a
-- single word by xoring its high and low halves.
foldedMul :: Word -> Word -> Word
foldedMul (W# a) (W# b) =
  case timesWord2# a b of
    (# high, low #) -> W# (high `xor#` low)

-- Initial accumulator for all hashes.
salt :: Word
salt = 3032525626373534813

-- Mix one word into an accumulator: xor, then folded multiply by a fixed
-- odd constant.
combine :: Word -> Word -> Word
combine x y = (x `xor` y) `foldedMul` 11400714819323198549
-- Hash the bytes of a buffer, starting from the salt and consuming 8 bytes
-- at a time, then 4, then single bytes.
hashBuffer :: Buffer -> Word
hashBuffer buf = step salt buf
  where
    step acc rest
      | n >= 8    = step (combine (getW rest) acc) (plus rest 8)
      | n >= 4    = step (combine (fi (getW32 rest)) acc) (plus rest 4)
      | n == 0    = acc
      | otherwise = step (combine (fi (getW8 rest)) acc) (plus rest 1)
      where
        n = len rest
-- Emit the bytes of a buffer, one at a time, into a Builder.
buildBuffer :: Buffer -> BB.Builder
buildBuffer b =
  if len b == 0
    then mempty
    else BB.word8 (getW8 b) <> buildBuffer (plus b 1)

-- printBuffer :: Buffer -> IO ()
-- printBuffer = BB.hPutBuilder stdout . buildBuffer

-- Renders the raw bytes, one Char per byte.
instance Show Buffer where
  show = LC8.unpack . BB.toLazyByteString . buildBuffer

-- Ordering goes through the rendered strings.
instance Ord Buffer where
  compare l r = show l `compare` show r
-- Short buffer
--------------------------------------------------------------------------------
-- Unboxed buffer holding at most 23 bytes. The first field is the length;
-- the remaining three fields are the payload words, stored last word first
-- (packBuffer puts bytes 0-7 in the final field). Unused payload bytes,
-- including the 24th, are always zero.
data ShortBuffer = ShortBuffer# Int Int Int Int

-- Branchless equality on the three payload words. The length field is not
-- compared.
instance Eq ShortBuffer where
  ShortBuffer# _ x2 x1 x0 == ShortBuffer# _ y2 y1 y0 =
    1 == (eqI x2 y2 .&. eqI x1 y1 .&. eqI x0 y0)
  {-# inline (==) #-}
-- Hash the three payload words of a short buffer; the length field is not
-- part of the hash.
hashShortBuffer :: ShortBuffer -> Word
hashShortBuffer (ShortBuffer# _ w2 w1 w0) = combine headHash tailHash
  where
    headHash = combine salt (fi w2)
    tailHash = combine (fi w1) (fi w0)
-- Serialise the payload words little-endian, first bytes first, and keep
-- only as many bytes as the length field says.
buildShortBuffer :: ShortBuffer -> BB.Builder
buildShortBuffer (ShortBuffer# n w2 w1 w0) =
  BB.lazyByteString . LC8.take (fi n) . BB.toLazyByteString $
    foldMap (BB.int64LE . fi) [w0, w1, w2]

-- Renders the raw bytes, one Char per byte.
instance Show ShortBuffer where
  show sb = LC8.unpack (BB.toLazyByteString (buildShortBuffer sb))
-- Compare payload words from the first bytes onward. Each word is
-- byte-swapped and compared as an unsigned Word. The length field is not
-- consulted.
instance Ord ShortBuffer where
  compare (ShortBuffer# _ x2 x1 x0) (ShortBuffer# _ y2 y1 y0) =
    mconcat [cmp x0 y0, cmp x1 y1, cmp x2 y2]
    where
      cmp p q = compare (swapped p) (swapped q)
      swapped (I# v) = W# (byteSwap# (int2Word# v))
-- Unboxed sum of short and standard buffers.
--------------------------------------------------------------------------------
-- Three-word encoding of either a ShortBuffer or a (long) Buffer. The low
-- byte of the first word is a tag:
--
--   tag <= 23 : short buffer. The tag is the length, the upper 56 bits of the
--               first word hold the third payload word, and the other two
--               words hold the remaining payload.
--   tag == 255: long buffer. The upper 56 bits of the first word hold the
--               length and the second word holds the address.
--
-- Long buffers carry an explicit tag instead of storing the bare length in
-- the first word. With the bare length, any long buffer whose length had a
-- low byte of 23 or less (256, 257, ...) was decoded as a short buffer.
data SLBuffer = SLB# Int Int Int

-- An all-zero first word marks an unused table slot. A long buffer never
-- produces it, because its tag byte is nonzero.
isEmptySLB :: SLBuffer -> Bool
isEmptySLB (SLB# a _ _) = a == 0

-- Decode into an unboxed sum according to the tag byte.
unpackSLB# :: SLBuffer -> (# ShortBuffer | Buffer #)
unpackSLB# (SLB# a b c) =
  let tag  = a .&. 255
      rest = isrl a 8 in
  if tag <= 23 then (# ShortBuffer# tag rest b c | #)
               else (# | Buffer (int2Addr b) rest #)

pattern ShortBuffer :: ShortBuffer -> SLBuffer
pattern ShortBuffer buf <- (unpackSLB# -> (# buf | #)) where
  ShortBuffer (ShortBuffer# len a b c) = SLB# (sl a 8 .|. len) b c

pattern LongBuffer :: Buffer -> SLBuffer
pattern LongBuffer buf <- (unpackSLB# -> (# | buf #)) where
  LongBuffer (Buffer p l) = SLB# (sl l 8 .|. 255) (addr2Int p) 0

{-# complete ShortBuffer, LongBuffer #-}
-- Two packed buffers are equal only when they are of the same kind and
-- their contents are equal; a short buffer never equals a long one.
instance Eq SLBuffer where
  lhs == rhs = case lhs of
    ShortBuffer a | ShortBuffer b <- rhs -> a == b
    LongBuffer a  | LongBuffer b  <- rhs -> a == b
    _                                    -> False
  {-# inline (==) #-}
-- Try to pack a Buffer into a short one. Buffers of up to 23 bytes are
-- copied into payload words, with the bytes beyond the length masked to
-- zero; longer buffers are kept by reference. Whole 8-byte words are read,
-- so the read can extend past the end of the buffer.
packBuffer :: Buffer -> SLBuffer
packBuffer b
  | n <= 8    = short 0 0 (word 0 .&. keep n)
  | n <= 16   = short 0 (word 1 .&. keep (n - 8)) (word 0)
  | n <= 23   = short (word 2 .&. keep (n - 16)) (word 1) (word 0)
  | otherwise = LongBuffer b
  where
    n    = len b
    word = indexI b
    -- mask with the lowest k bytes set
    keep k = isrl (-1) (64 - sl k 3)
    short w2 w1 w0 = ShortBuffer (ShortBuffer# n w2 w1 w0)
-- Hash a packed buffer with the hash function of its variant.
hashSLB :: SLBuffer -> Word
hashSLB = \case
  ShortBuffer s -> hashShortBuffer s
  LongBuffer l  -> hashBuffer l

-- Emit the bytes of a packed buffer with the builder of its variant.
buildSLB :: SLBuffer -> BB.Builder
buildSLB = \case
  ShortBuffer s -> buildShortBuffer s
  LongBuffer l  -> buildBuffer l
-- Renders the raw bytes, one Char per byte.
instance Show SLBuffer where
  show key = LC8.unpack (BB.toLazyByteString (buildSLB key))

-- Two short buffers are compared directly on their payload; every other
-- combination falls back to comparing the rendered strings.
instance Ord SLBuffer where
  compare x y = case (x, y) of
    (ShortBuffer a, ShortBuffer b) -> compare a b
    _                              -> compare (show x) (show y)
-- Branchless scanning for bytes in words.
--------------------------------------------------------------------------------
-- Expands to a 16-digit hexadecimal literal consisting of the given two-digit
-- byte repeated eight times. The empty comment markers between the pieces are
-- there so that the preprocessor pastes the pieces together.
#define SCAN_MASK(hex) 0x/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex/**/hex
-- Given a hexadecimal byte, generate the (Word -> Int) function which returns
-- the index of the lowest-order byte of the word that equals the given byte
-- (on a little-endian machine: the first such byte in memory), or 8 if the
-- byte does not occur.
--
-- How it works: the xor turns every matching byte into zero. The
-- subtract / complement / mask expression then sets the high bit of the
-- lowest zero byte. Only the lowest set bit is relied upon; it is located
-- with countTrailingZeros and its bit index is divided by 8. When no bit is
-- set, countTrailingZeros gives 64, hence the result 8.
#define BYTE_INDEX(hex) (\(x :: Word) -> case xor x SCAN_MASK(hex) of \
x -> case (x - 0x0101010101010101) .&. complement x .&. 0x8080808080808080 of \
x -> countTrailingZeros x `sr` 3)
-- Hash table of measurements
--------------------------------------------------------------------------------
-- Aggregate for one key: smallest and largest value seen, number of
-- measurements, and their sum. Values are the integers produced by parse,
-- i.e. the temperature with its decimal point dropped.
data Val = Val
  { _min   :: Int
  , _max   :: Int
  , _cnt   :: Int
  , _total :: Int
  }

-- One hash table slot: the packed key together with its aggregate.
data Entry = Entry
  { _key :: {-# unpack #-} SLBuffer
  , _val :: {-# unpack #-} Val
  }
-- size of entry in bytes (includes padding to 64 bytes!)
entrySize :: Int
entrySize = 64

-- Bit mask that reduces a hash to a slot index; relies on the configured
-- slot count being a power of two.
tableMask :: Int
tableMask = TABLE_SIZE - 1

-- Size of a whole table in bytes.
tableBytes :: Int
tableBytes = entrySize * TABLE_SIZE

-- A table is referred to by the raw address of its first slot.
type Table = Addr#
-- Allocate one zero-filled table per input buffer and pass the list of
-- (buffer, table) pairs to the continuation.
--
-- The tables come from nested allocaBytesAligned calls, so they are only
-- valid while the continuation runs. Pairs are consed onto an accumulator,
-- so the continuation receives them in reverse input order.
--
-- The allocation size is the shared tableBytes constant, the same value the
-- table loops use as their end bound, so the two cannot drift apart.
initTables :: [Buffer] -> ([(Buffer, Ptr Word8)] -> IO a) -> IO a
initTables bs f = go bs []
  where
    go [] acc = f acc
    go (b:rest) acc =
      -- aligned to the slot size
      allocaBytesAligned tableBytes entrySize \p@(Ptr p') -> do
        -- an all-zero slot is what isEmptySLB recognises as unused
        memset (Buffer p' tableBytes) 0
        go rest ((b, p) : acc)
-- read entry from a *byte* offset
--
-- Seven consecutive words are read: three key words followed by the four
-- aggregate fields.
readEntry :: Table -> Int -> IO Entry
readEntry tbl off = do
  k0    <- word 0
  k1    <- word 8
  k2    <- word 16
  lo    <- word 24
  hi    <- word 32
  count <- word 40
  total <- word 48
  pure (Entry (SLB# k0 k1 k2) (Val lo hi count total))
  where
    slot   = plusAddr tbl off
    word i = readI (plusAddr slot i)
-- write entry to a *byte* offset
--
-- Seven consecutive words are written: three key words followed by the four
-- aggregate fields. The eighth word of the slot is not touched.
writeEntry :: Table -> Int -> Entry -> IO ()
writeEntry tbl off (Entry (SLB# k0 k1 k2) (Val lo hi count total)) = do
  put 0  k0
  put 8  k1
  put 16 k2
  put 24 lo
  put 32 hi
  put 40 count
  put 48 total
  where
    slot    = plusAddr tbl off
    put i v = writeI (plusAddr slot i) v
-- Aggregate for a single measurement.
newVal :: Int -> Val
newVal temp = Val { _min = temp, _max = temp, _cnt = 1, _total = temp }

-- Merge an aggregate into an entry, keeping the key of the entry.
updateEntry :: Entry -> Val -> Entry
updateEntry (Entry k old) new = Entry k merged
  where
    merged = Val
      { _min   = min' (_min old) (_min new)
      , _max   = max' (_max old) (_max new)
      , _cnt   = _cnt old + _cnt new
      , _total = _total old + _total new
      }
-- Run an action on every occupied slot of a table, in slot order.
forTable :: Table -> (Entry -> IO ()) -> IO ()
forTable t f = loop 0
  where
    loop off
      | off == tableBytes = pure ()
      | otherwise = do
          e <- readEntry t off
          unless (isEmptySLB (_key e)) (f e)
          loop (off + entrySize)
{-# inline forTable #-}
-- Insert an entry, or merge its aggregate into the slot that already holds
-- the same key. Probing starts at the slot selected by the key hash and
-- moves forward one slot at a time, wrapping to the start at the end of the
-- table. There is no exit for a table that is full and lacks the key.
updateTable :: Table -> Entry -> IO ()
updateTable tbl new@(Entry key val) = probe home
  where
    home = (fi (hashSLB key) .&. tableMask) * entrySize
    probe off
      | off == tableBytes = probe 0
      | otherwise = do
          cur <- readEntry tbl off
          let slotKey = _key cur
          case () of
            _ | isEmptySLB slotKey -> writeEntry tbl off new
              | key == slotKey     -> writeEntry tbl off (updateEntry cur val)
              | otherwise          -> probe (off + entrySize)
-- Parse every "<key>;<temperature>" record of the buffer and fold it into
-- the table.
--
-- The temperature must have the shape [-]d.d or [-]dd.d and is followed by
-- exactly one terminator byte, which is skipped without being inspected. The
-- value is recorded as an integer with the decimal point dropped, so "-12.3"
-- becomes -123. The input is not validated.
--
-- Parsing stops as soon as no bytes remain. The check is `<= 0` rather than
-- `== 0` because the terminator of the last record is skipped
-- unconditionally: when the input lacks a trailing newline the remaining
-- length ends up at -1, and an equality test would let the loop run on past
-- the end of the buffer.
parse :: Table -> Buffer -> IO ()
parse _ b | len b <= 0 = do
  pure ()
parse tbl b = do
  -- scan for semicolon, one 8-byte word at a time
  let findSemi :: Int -> Buffer -> Int
      findSemi i b = case BYTE_INDEX(3B) (getW b) of
        8  -> findSemi (i + 8) (plus b 8)
        i' -> i + i'
  let keylen = findSemi 0 b
  let key = packBuffer $ b {len = keylen}
  -- skip the key and the semicolon
  b <- pure $ plus b (keylen + 1)
  let digit :: Word8 -> Int
      digit x = fi x - 48
  -- record one measurement, then continue with the rest of the buffer
  let join :: Buffer -> Int -> IO ()
      join b temp = do
        updateTable tbl (Entry key (newVal temp))
        parse tbl b
  case getW8 b of
    -- '-'
    45 -> do
      let d1 = getW8 (plus b 1)
      case getW8 (plus b 2) of
        -- '.' so the next must be digit
        46 -> do
          let d2 = getW8 (plus b 3)
          join (plus b 5) ((-10)*(digit d1) - digit d2)
        -- digit, so the next must be '.' and then digit
        d2 -> do
          let d3 = getW8 (plus b 4)
          join (plus b 6) ((-100)*(digit d1) - 10*(digit d2) - digit d3)
    -- a digit
    d1 -> case getW8 (plus b 1) of
      -- '.', so the next must be digit
      46 -> do
        let d2 = getW8 (plus b 2)
        join (plus b 4) (10*digit d1 + digit d2)
      -- another digit, so the next must be '.', and then digit
      d2 -> do
        let d3 = getW8 (plus b 3)
        join (plus b 5) (100*digit d1 + 10*digit d2 + digit d3)
-- Split the file into one buffer per thread
--------------------------------------------------------------------------------
-- Split a buffer into exactly num_threads consecutive chunks, each ending
-- just after a newline (except possibly the last one).
--
-- Every cut is made at the first newline found at or after the nominal chunk
-- size (total length divided by num_threads), so no line is split between
-- two chunks.
--
-- Compared with a plain cut-until-exhausted loop, three things are
-- guaranteed:
--   * the result always has num_threads elements; when the input runs out
--     early (small inputs) the list is padded with empty buffers;
--   * the newline scan stops at the end of the buffer instead of running on
--     when no newline follows;
--   * a cut never extends past the end of the buffer, so no chunk has a
--     negative length.
splitBuffer :: Int -> Buffer -> [Buffer]
splitBuffer num_threads b = go num_threads b
  where
    chunkSize = div (len b) num_threads

    -- offset of the first newline, scanning 8 bytes at a time; gives up once
    -- the buffer is exhausted
    findNewl :: Int -> Buffer -> Int
    findNewl i b
      | len b <= 0 = i
      | otherwise  = case BYTE_INDEX(0A) (getW b) of
          8  -> findNewl (i + 8) (plus b 8)
          i' -> i + i'

    -- n is the number of chunks still to be produced
    go :: Int -> Buffer -> [Buffer]
    go n b
      | len b <= chunkSize =
          -- the remainder is the last real chunk; pad with empty ones
          b : replicate (n - 1) (Buffer (_ptr b) 0)
      | otherwise =
          let keylen     = findNewl 0 (plus b chunkSize)
              chunkSize' = min (len b) (chunkSize + keylen + 1)
          in Buffer (_ptr b) chunkSize' : go (n - 1) (plus b chunkSize')
#ifdef DISPLAY_OUTPUT
-- Collect the occupied slots of a table into a list, in slot order.
tableToList :: Table -> IO [Entry]
tableToList tbl = collect 0
  where
    collect off
      | off == tableBytes = pure []
      | otherwise = do
          e    <- readEntry tbl off
          rest <- collect (off + entrySize)
          pure (if isEmptySLB (_key e) then rest else e : rest)
-- Render entries as "{key=min/mean/max, ...}". Each number is the stored
-- integer divided by 10 and printed with one decimal place; the mean is the
-- total divided by the count.
displayEntries :: [Entry] -> BB.Builder
displayEntries es =
  BB.char8 '{' <> mconcat (intersperse (BB.string8 ", ") (map render es)) <> BB.char8 '}'
  where
    render (Entry key (Val mi ma cn to)) =
      buildSLB key <> BB.string8 (printf "=%.1f/%.1f/%.1f" lo mean hi)
      where
        lo, mean, hi :: Double
        lo   = fi mi / 10
        mean = fi to / (fi cn * 10)
        hi   = fi ma / 10
#endif
-- Map the input file, split it into one chunk per capability, parse the
-- chunks concurrently into separate tables, then merge every other table
-- into the first one of the result list. When output is enabled, the merged
-- entries are sorted by key and printed on one line.
main :: IO ()
main =
  withFile "data/measurements500M.txt" \input -> do
    threads <- getNumCapabilities
    initTables (splitBuffer threads input) \jobs -> do
      Ptr tbl : others <- mapConcurrently (\(chunk, Ptr t) -> Ptr t <$ parse t chunk) jobs
      forM_ others \(Ptr other) ->
        forTable other (updateTable tbl)
#ifdef DISPLAY_OUTPUT
      es <- sortOn _key <$> tableToList tbl
      BB.hPutBuilder stdout (displayEntries es)
      putChar '\n'
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment