Skip to content

Instantly share code, notes, and snippets.

@noughtmare
Last active March 14, 2024 11:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noughtmare/ab64aa24e5a106142164c91ea1ac5af7 to your computer and use it in GitHub Desktop.
Save noughtmare/ab64aa24e5a106142164c91ea1ac5af7 to your computer and use it in GitHub Desktop.
1BRC (One Billion Row Challenge) solution based on @Bodigrim's solution, but using a linearly probed hash table stored in a primitive array.
#!/usr/bin/env cabal
{- cabal:
build-depends: base >= 4.19, bytestring, primitive >= 0.9, mmap
default-language: GHC2021
ghc-options: -Wall -O2 -fllvm
-}
{-# LANGUAGE ExtendedLiterals #-}
{-# LANGUAGE MagicHash, UnboxedTuples, UnliftedDatatypes #-}
import Foreign
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import GHC.Exts
import GHC.Word
import System.Environment (getArgs)
import Text.Printf

import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Char8 qualified as C8
import Data.ByteString.Internal ( ByteString(BS), accursedUnutterablePerformIO)
import Data.ByteString.Unsafe qualified as B
import Data.Primitive
import qualified Data.Primitive.PrimArray as A
import System.IO.MMap
-- Kovacshash
-- | Seed for 'goHash': every station hash starts from this value.
salt :: Word
salt = 3_032_525_626_373_534_813
-- | Unwrap a boxed 'Word' to its primitive 'Word#'.
unW# :: Word -> Word#
unW# w = case w of
  W# raw -> raw
-- | Multiply two words into a full 128-bit product (high and low halves),
-- then fold it back to 64 bits by XOR-ing the two halves together.
foldedMul :: Word# -> Word# -> Word#
foldedMul a b = case timesWord2# a b of
  (# high, low #) -> high `xor#` low
-- | Mix one input word into the running hash: XOR it with the accumulator,
-- then scramble the result with 'foldedMul' against a fixed odd constant.
combine :: Word# -> Word# -> Word#
combine w acc = foldedMul mixed 11400714819323198549##
  where
    mixed = w `xor#` acc
-- | Hash @len@ bytes starting at @addr@ into the accumulator @h@.
--
-- Input is consumed as 8-byte words while at least 8 bytes remain, then as
-- at most one 4-byte word, then byte by byte.  Each chunk is folded into the
-- accumulator with 'combine'.
goHash :: Addr# -> Int -> Word# -> Word#
goHash addr len h
  | len == 0 = h
  | len >= 8 =
      goHash (addr `plusAddr#` 8#) (len - 8) (indexWordOffAddr# addr 0# `combine` h)
  | len >= 4 =
      goHash (addr `plusAddr#` 4#) (len - 4) (word32ToWord# (indexWord32OffAddr# addr 0#) `combine` h)
  | otherwise =
      goHash (addr `plusAddr#` 1#) (len - 1) (word8ToWord# (indexWord8OffAddr# addr 0#) `combine` h)
-- Unsafe ByteString
-- | A byte string as a bare pointer plus a length.  No 'ForeignPtr' keeps
-- the memory alive, so whoever creates one must keep the backing buffer
-- alive for as long as the value is used.
data UnsafeByteString = MkUBS
  {-# UNPACK #-} !(Ptr Word8) -- start of the bytes
  {-# UNPACK #-} !Int -- number of bytes
-- | Content equality: equal lengths and identical bytes (see 'eqUBS').
instance Eq UnsafeByteString where
  a == b = eqUBS a b
-- | Compare @n@ bytes at two addresses.  Returns @1#@ if every byte matches
-- and @0#@ otherwise.  Compares 8 bytes at a time, then 4, then single
-- bytes, and stops at the first chunk that differs.
goEqUBS :: Addr# -> Addr# -> Int -> Int#
goEqUBS a b n
  | n == 0 = 1#
  | n >= 8 = advance 8# (eqWord# (indexWordOffAddr# a 0#) (indexWordOffAddr# b 0#))
  | n >= 4 = advance 4# (eqWord32# (indexWord32OffAddr# a 0#) (indexWord32OffAddr# b 0#))
  | otherwise = advance 1# (eqWord8# (indexWord8OffAddr# a 0#) (indexWord8OffAddr# b 0#))
  where
    -- Skip past a chunk of @k@ bytes if it compared equal, otherwise fail.
    advance :: Int# -> Int# -> Int#
    advance k 1# = goEqUBS (plusAddr# a k) (plusAddr# b k) (n - I# k)
    advance _ _ = 0#
-- | Two 'UnsafeByteString's are equal when their lengths match and their
-- bytes match.  The length is checked first, so no memory is read when the
-- lengths differ.
eqUBS :: UnsafeByteString -> UnsafeByteString -> Bool
eqUBS (MkUBS (Ptr p) l) (MkUBS (Ptr q) m) =
  l == m && isTrue# (goEqUBS p q l)
-- | View a 'ByteString' as an 'UnsafeByteString' without copying.
--
-- The 'ForeignPtr' is discarded, so the caller must keep the original
-- buffer alive (for example with 'withForeignPtr') while the result is used.
fromByteStringUBS :: ByteString -> UnsafeByteString
fromByteStringUBS bs = case bs of
  BS fptr len -> MkUBS (unsafeForeignPtrToPtr fptr) len
-- | Wrap an 'UnsafeByteString' back into a 'ByteString' without copying.
--
-- The new 'ForeignPtr' has no finalizer, so the result is valid only while
-- the underlying buffer is still alive.
toByteStringUBS :: UnsafeByteString -> ByteString
toByteStringUBS (MkUBS ptr len) = BS fptr len
  where
    fptr = accursedUnutterablePerformIO (newForeignPtr_ ptr)
-- | Hash the bytes of an 'UnsafeByteString' with 'goHash', seeded by 'salt'.
-- The result may be negative; 'insert' masks it down to a slot index.
hashUBS :: UnsafeByteString -> Int
hashUBS (MkUBS (Ptr addr) len) =
  case goHash addr len (unW# salt) of
    h -> I# (word2Int# h)
-- | The zero-length string.  Its presence marks an unoccupied hash-table slot.
emptyUBS :: UnsafeByteString
emptyUBS = MkUBS nullPtr 0

-- | True exactly when the length is zero.  The pointer is not inspected.
isEmptyUBS :: UnsafeByteString -> Bool
isEmptyUBS (MkUBS _ len) = len == 0
-- Rest
-- | Station names are unowned views into the input buffer.
type Station = UnsafeByteString

-- | Turn the name part of an input line into a 'Station' without copying.
mkStation :: ByteString -> Station
mkStation name = fromByteStringUBS name
-- | One parsed input line.
data Entry = Entry
  { _station :: {-# UNPACK #-} !Station
    -- ^ station name
  , _temperature :: {-# UNPACK #-} !Int
    -- ^ temperature in tenths of a degree
  }
-- | Parse one line of the form @<station>;<temperature>@.  The temperature
-- has an optional leading @-@, one or two integer digits, and exactly one
-- fractional digit, for example:
--
-- > Bayawan;-21.1
-- > Andranomenatsa;-1.2
-- > Benton Harbor;36.2
-- > Taulahā;0.6
--
-- The temperature is returned in tenths of a degree.  Only fixed offsets
-- from the end of the line are inspected; the input is not validated.
parseLine :: ByteString -> Entry
parseLine xs
  | x4 == 59 = Entry (station 4) (short - 528) -- "...;d.d"   (59 = ';')
  | x4 == 45 = Entry (station 5) (528 - short) -- "...;-d.d"  (45 = '-')
  | x5 == 59 = Entry (station 5) (long - 5328) -- "...;dd.d"
  | otherwise = Entry (station 6) (5328 - long) -- "...;-dd.d"
  where
    len = B.length xs
    byteFromEnd k = B.unsafeIndex xs (len - k)
    digit k = fromIntegral (byteFromEnd k) :: Int
    x4 = byteFromEnd 4
    x5 = byteFromEnd 5
    -- Raw ASCII codes.  The offsets 528 = 11 * ord '0' and
    -- 5328 = 111 * ord '0' remove the '0' bias from all digits at once.
    short = digit 3 * 10 + digit 1
    long = digit 4 * 100 + short
    station k = mkStation (B.unsafeTake (len - k) xs)
-- | Running statistics for one station, in tenths of a degree.
data Quartet = Quartet
  { _min :: !Int
    -- ^ smallest temperature seen
  , _total :: !Int
    -- ^ sum of all temperatures seen
  , _cnt :: !Int
    -- ^ number of measurements
  , _max :: !Int
    -- ^ largest temperature seen
  }
  deriving (Eq)
-- | Statistics for a station that has exactly one measurement, @x@.
mkQuartet :: Int -> Quartet
mkQuartet x = Quartet {_min = x, _total = x, _cnt = 1, _max = x}
-- | Add one measurement @x@ to existing statistics.  This is the same as
-- merging with the singleton statistics of @x@.
updateQuartet :: Int -> Quartet -> Quartet
updateQuartet x q = q <> mkQuartet x
-- | Merge the statistics of two sets of measurements: min and max of the
-- extremes, and sums of the totals and counts.
instance Semigroup Quartet where
  q1 <> q2 =
    Quartet
      { _min = min (_min q1) (_min q2)
      , _total = _total q1 + _total q2
      , _cnt = _cnt q1 + _cnt q2
      , _max = max (_max q1) (_max q2)
      }
-- | One hash-table slot: a station name and its statistics.  The 'Prim'
-- instance stores it as six machine words.
data Row = Row
  {-# UNPACK #-} !Station
  {-# UNPACK #-} !Quartet
-- | Flat, unboxed storage for 'Row', so that a whole table can live in a
-- single 'MutablePrimArray'.
--
-- Each 'Row' occupies six consecutive 8-byte words, in this order:
--
-- > word 0: station pointer   word 1: station length
-- > word 2: _min              word 3: _total
-- > word 4: _cnt              word 5: _max
--
-- Element @i@ therefore starts at word index @6 * i@, and field @k@ is
-- accessed at word index @6 * i + k@ through the field type's own 'Prim'
-- instance.
-- NOTE(review): the hard-coded 8-byte word size assumes 'Ptr' and 'Int' are
-- both 64 bits wide; confirm before building for a 32-bit target.
instance Prim Row where
  -- Six fields of 8 bytes each.
  sizeOf# _ = 6# *# 8#
  alignment# _ = 8#
  -- Read the six fields of element i from an immutable byte array.
  indexByteArray# ba i = Row (MkUBS p l) (Quartet a b c d) where
    p = indexByteArray# ba (6# *# i)
    l = indexByteArray# ba (6# *# i +# 1#)
    a = indexByteArray# ba (6# *# i +# 2#)
    b = indexByteArray# ba (6# *# i +# 3#)
    c = indexByteArray# ba (6# *# i +# 4#)
    d = indexByteArray# ba (6# *# i +# 5#)
  {-# INLINE indexByteArray# #-}
  -- Read the six fields in layout order, threading the state token s0 -> s6.
  readByteArray# mba i s0 = (# s6 , Row (MkUBS p l) (Quartet a b c d) #) where
    !(# s1 , p #) = readByteArray# mba (6# *# i) s0
    !(# s2 , l #) = readByteArray# mba (6# *# i +# 1#) s1
    !(# s3 , a #) = readByteArray# mba (6# *# i +# 2#) s2
    !(# s4 , b #) = readByteArray# mba (6# *# i +# 3#) s3
    !(# s5 , c #) = readByteArray# mba (6# *# i +# 4#) s4
    !(# s6 , d #) = readByteArray# mba (6# *# i +# 5#) s5
  {-# INLINE readByteArray# #-}
  -- Write the six fields in layout order, threading the state token s0 -> s6.
  writeByteArray# mba i (Row (MkUBS p l) (Quartet a b c d)) s0 = s6 where
    s1 = writeByteArray# mba (6# *# i) p s0
    s2 = writeByteArray# mba (6# *# i +# 1#) l s1
    s3 = writeByteArray# mba (6# *# i +# 2#) a s2
    s4 = writeByteArray# mba (6# *# i +# 3#) b s3
    s5 = writeByteArray# mba (6# *# i +# 4#) c s4
    s6 = writeByteArray# mba (6# *# i +# 5#) d s5
  {-# INLINE writeByteArray# #-}
  -- The *OffAddr# methods below use the same layout, addressed through a
  -- raw Addr# instead of a byte array.
  indexOffAddr# addr i = Row (MkUBS p l) (Quartet a b c d) where
    p = indexOffAddr# addr (6# *# i)
    l = indexOffAddr# addr (6# *# i +# 1#)
    a = indexOffAddr# addr (6# *# i +# 2#)
    b = indexOffAddr# addr (6# *# i +# 3#)
    c = indexOffAddr# addr (6# *# i +# 4#)
    d = indexOffAddr# addr (6# *# i +# 5#)
  {-# INLINE indexOffAddr# #-}
  readOffAddr# addr i s0 = (# s6 , Row (MkUBS p l) (Quartet a b c d) #) where
    !(# s1 , p #) = readOffAddr# addr (6# *# i) s0
    !(# s2 , l #) = readOffAddr# addr (6# *# i +# 1#) s1
    !(# s3 , a #) = readOffAddr# addr (6# *# i +# 2#) s2
    !(# s4 , b #) = readOffAddr# addr (6# *# i +# 3#) s3
    !(# s5 , c #) = readOffAddr# addr (6# *# i +# 4#) s4
    !(# s6 , d #) = readOffAddr# addr (6# *# i +# 5#) s5
  {-# INLINE readOffAddr# #-}
  writeOffAddr# addr i (Row (MkUBS p l) (Quartet a b c d)) s0 = s6 where
    s1 = writeOffAddr# addr (6# *# i) p s0
    s2 = writeOffAddr# addr (6# *# i +# 1#) l s1
    s3 = writeOffAddr# addr (6# *# i +# 2#) a s2
    s4 = writeOffAddr# addr (6# *# i +# 3#) b s3
    s5 = writeOffAddr# addr (6# *# i +# 4#) c s4
    s6 = writeOffAddr# addr (6# *# i +# 5#) d s5
  {-# INLINE writeOffAddr# #-}
-- | Open-addressing hash table holding one 'Row' per slot.  A slot whose
-- station is empty (see 'isEmptyUBS') is unoccupied.
data Table = MkTable
  !(A.MutablePrimArray RealWorld Row)
-- | Record one measurement, using linear probing.
--
-- The first slot tried is the station hash masked with @size - 1@.  The mask
-- is derived from the array's actual size rather than hard-coded, so indices
-- always stay in bounds.  The table size is expected to be a power of two,
-- as 'newTable' allocates.
--
-- Probing moves to the next slot, wrapping around, until it finds either an
-- empty slot (claimed for this station) or this station's own slot (whose
-- statistics are updated).  If every slot has been probed without success,
-- the table is full and an 'IOError' is thrown instead of looping forever.
insert :: Table -> Entry -> IO ()
insert (MkTable arr) (Entry name t) = do
  sz <- A.getSizeofMutablePrimArray arr
  let mask = sz - 1
      probe :: Int -> Int -> IO ()
      probe tries i
        | tries >= sz = ioError (userError "insert: hash table is full")
        | otherwise = do
            Row s q <- A.readPrimArray arr i
            if isEmptyUBS s
              then A.writePrimArray arr i (Row name (mkQuartet t))
              else
                if s == name
                  then A.writePrimArray arr i (Row name (updateQuartet t q))
                  else probe (tries + 1) ((i + 1) .&. mask)
  probe 0 (hashUBS name .&. mask)
-- | Allocate a table with room for at least @l@ stations, every slot empty.
--
-- The capacity is rounded up to a power of two, with a minimum of 1, so that
-- a hash can be reduced to a slot index with a bit mask.  A power-of-two
-- @l@ (such as the @0x10000@ used by 'main') is kept as is.  Each slot is
-- filled with 'emptyUBS' as its station, which marks it as unoccupied.
newTable :: Int -> IO Table
newTable l = do
  let cap
        | l <= 1 = 1
        | otherwise = bit (finiteBitSize l - countLeadingZeros (l - 1))
  arr <- A.newPrimArray cap
  A.setPrimArray arr 0 cap (Row emptyUBS (Quartet 0 0 0 0))
  pure (MkTable arr)
-- | All occupied slots of the table, in slot order.
tableToList :: Table -> IO [Row]
tableToList (MkTable arr) = do
  n <- A.getSizeofMutablePrimArray arr
  rows <- mapM (A.readPrimArray arr) [0 .. n - 1]
  pure [row | row@(Row station _) <- rows, not (isEmptyUBS station)]
-- | Split a buffer on newline bytes into lines, without the newlines.  A
-- trailing newline does not produce an empty final line.  The list is
-- produced with 'build' so it can take part in foldr/build fusion.
mylines :: ByteString -> [ByteString]
mylines input = build (\cons nil ->
  let split rest
        | B.null rest = nil
        | otherwise =
            maybe
              (rest `cons` nil)
              (\i -> B.take i rest `cons` split (B.drop (i + 1) rest))
              (B.elemIndex 10 rest)
   in split input)
{-# INLINE mylines #-}
-- | Parse every line of the buffer and add its measurement to the table.
parse :: Table -> ByteString -> IO ()
parse tab xs = mapM_ record (mylines xs)
  where
    record line = insert tab (parseLine line)
-- | Render the per-station statistics as @{Name=min/mean/max, ...}@.
--
-- Sorting: rows come out of the hash table in slot order, so they are
-- sorted here by the raw bytes of the station name.  The 1BRC output format
-- requires alphabetical order.
--
-- Formatting: temperatures are integer tenths of a degree and are printed
-- with exactly one fractional digit using integer arithmetic.  No binary
-- floating-point rounding is involved, and @-0.0@ is never printed.
--
-- Mean: rounded half up (towards +infinity), the rounding convention of
-- Java's @Math.round@ used by the 1BRC reference implementation.
--
-- Joining: the pieces are joined with a single 'B.concat' / 'B.intercalate'
-- rather than a right-nested chain of '<>', which recopied the accumulated
-- tail once per station.
aggregate :: [Row] -> ByteString
aggregate m =
  B.concat [C8.pack "{", B.intercalate (C8.pack ", ") (map render sorted), C8.pack "}"]
  where
    sorted :: [(ByteString, Quartet)]
    sorted = sortWith fst [(toByteStringUBS ss, q) | Row ss q <- m]

    render :: (ByteString, Quartet) -> ByteString
    render (name, Quartet lo total cnt hi) =
      name
        <> C8.pack
          ( concat
              ["=", showTenths lo, "/", showTenths (meanTenths total cnt), "/", showTenths hi]
          )

    -- floor (total / cnt + 1/2), computed exactly with integers.
    -- Assumes cnt >= 1, which holds for every occupied slot: cnt starts at
    -- 1 and only grows.
    meanTenths :: Int -> Int -> Int
    meanTenths total cnt = (2 * total + cnt) `div` (2 * cnt)

    -- Format a value given in tenths, e.g. -123 -> "-12.3", 5 -> "0.5".
    showTenths :: Int -> String
    showTenths t = sign ++ show whole ++ "." ++ show frac
      where
        sign = if t < 0 then "-" else ""
        (whole, frac) = abs t `quotRem` 10
-- | Entry point: aggregate a measurements file and print the summary line.
--
-- The input path is the first command-line argument, defaulting to
-- @data/measurements.txt@ when none is given.  The file is memory-mapped,
-- and the mapping is kept alive with 'withForeignPtr' for the whole
-- computation, because station names stored in the table point directly
-- into it.
main :: IO ()
main = do
  args <- getArgs
  let path = case args of
        p : _ -> p
        [] -> "data/measurements.txt"
  cnt@(BS fp _) <- mmapFileByteString path Nothing
  withForeignPtr fp $ \_ -> do
    tab <- newTable 0x10000
    parse tab cnt
    l <- tableToList tab
    C8.putStrLn $ aggregate l
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment