Skip to content

Instantly share code, notes, and snippets.

@Profpatsch
Last active July 27, 2021 22:59
Embed
What would you like to do?
Dumb BitString wrapper around ByteString that allows slicing it at bit granularity
{-# LANGUAGE TypeApplications, ExplicitForAll, ScopedTypeVariables, BinaryLiterals, NumericUnderscores, TupleSections, ExistentialQuantification, KindSignatures, DataKinds, MultiWayIf, TypeFamilies, ConstraintKinds, TypeOperators, DerivingStrategies, GeneralizedNewtypeDeriving, InstanceSigs, MultiParamTypeClasses, FlexibleInstances #-}
module Main where
import qualified Data.Bits as Bits
import Data.Word
import qualified Data.List as List
import qualified Data.Text as Text
import qualified Data.ByteString as Bytes
import Data.ByteString (ByteString)
import Data.Text (Text)
import Debug.Trace
import Data.Bifunctor (first)
import GHC.TypeLits
import Data.Proxy
import Data.Function ((&))
import qualified Data.Ord as Ord
import GHC.Stack
-- | An unsigned word that is meant to hold at most @bits@ bits.
--
-- The payload is existentially quantified: any carrier type that is
-- 'Bits.FiniteBits' and 'Integral' can back it, so the carrier has to be at
-- least as wide as the value that is stored in it.
-- The constructor itself performs no check; 'wordN' verifies the real bit
-- length on construction.
data WordN (bits :: Nat) = forall w. (Bits.FiniteBits w, Integral w) => WordN w
-- | Shows the stored value as a plain number by widening it to 'Word' first.
-- The @nat@ index is not part of the rendered text.
instance forall nat. (KnownNat nat, Bigger Word (WordN nat)) => Show (WordN (nat :: Nat)) where
  show = show . toIntegral @(WordN nat) @Word
-- | Size of a word: how many bits a type can hold as an unsigned value.
-- Used at the type level (via 'Bigger') to compare the capacity of two types.
type family BitSize word :: Nat
type instance BitSize Word8 = 8
type instance BitSize Word16 = 16
type instance BitSize Word32 = 32
type instance BitSize Word64 = 64
-- NOTE(review): this hard-codes 64 bits for 'Word', which presumably only
-- holds on 64-bit platforms — confirm before building elsewhere.
type instance BitSize Word = 64
-- This is a conservative estimate used when converting from a word into an
-- 'Int': the code treats 'Int' as 64 bits signed, so only 63 bits are
-- available for an unsigned value.
type instance BitSize Int = 63
-- A 'WordN' can hold exactly as many bits as its type-level index says.
type instance BitSize (WordN (bits :: Nat)) = bits
-- | Constraint: @bigger@ can hold at least as many bits as @small@,
-- judged by their 'BitSize'.
type Bigger bigger small = BitSize small <= BitSize bigger
-- | 'Bigger', plus the requirement that both types are 'Integral',
-- so that values can be converted between them with 'fromIntegral'.
type BiggerIntegral bigger small =
  (Bigger bigger small, Integral small, Integral bigger)
-- | Widen a value into a type that, according to 'BitSize', has room for
-- every bit of the source type. The conversion itself is 'fromIntegral'.
intoBiggerWordLE :: BiggerIntegral bigger small => small -> bigger
intoBiggerWordLE small = fromIntegral small
-- | Types @n@ whose values can be converted into the integral type @i@.
class ToIntegral n i where
  -- | Convert the value into @i@.
  toIntegral :: Integral i => n -> i
-- | Wrapper around an integral value; all numeric classes are taken over
-- unchanged from the wrapped type via newtype deriving.
newtype AnIntegral i = AnIntegral i
  deriving newtype (Eq, Ord, Num, Real, Enum, Integral)
-- | An 'AnIntegral' converts back to the type it wraps.
instance Integral i => ToIntegral (AnIntegral i) i where
  toIntegral wrapped = fromIntegral wrapped
-- | A 'WordN' converts into any numeric type that is 'Bigger' than it,
-- i.e. that has room for all @nat@ bits.
instance forall i nat. (Num i, Bigger i (WordN nat)) => ToIntegral (WordN nat) i where
  toIntegral :: WordN nat -> i
  toIntegral wn =
    case wn of
      -- the carrier type is existential, so it is converted right here
      WordN carrier -> fromIntegral carrier
-- | Narrow a value from the bigger type into the smaller one using
-- 'fromIntegral'. Nothing checks that the value actually fits into @small@,
-- hence the "Unsafe" in the name.
fromBiggerWordLEUnsafe :: BiggerIntegral bigger small => bigger -> small
fromBiggerWordLEUnsafe bigger = fromIntegral bigger
-- | Construct a 'WordN' from the given word, checking at runtime that the
-- value really fits into @bits@ bits.
--
-- The number of bits in use is the carrier's 'Bits.finiteBitSize' minus its
-- leading zero bits. If that exceeds @bits@, this calls 'error'.
wordN :: forall bits w a.
  ( HasCallStack
  , Bigger w (WordN bits)
  , KnownNat bits, Bits.FiniteBits w, Integral w, Show w)
  => w -> WordN bits
wordN w
  | usedBits <= limit = WordN w
  | otherwise =
      error $ "wordN: " <> show w <> " has more bits than are allowed in a WordN of maximal " <> show limit <> " bits"
  where
    limit, usedBits :: Int
    -- maximum number of bits, read off the type-level index
    limit = fromInteger (natVal (Proxy @bits))
    -- position of the highest set bit, counted from 1 (0 for the value zero)
    usedBits = Bits.finiteBitSize w - Bits.countLeadingZeros w
-- | The type-level bit count of a 'WordN', as a runtime value.
-- The argument only fixes the type; it is never inspected.
wordNBits :: forall bits. KnownNat bits => WordN bits -> Integer
wordNBits _ = natVal (Proxy @bits)
-- | Extract the value of a 'WordN' into an integral type that is 'Bigger'
-- than the 'WordN'. This is 'toIntegral' with the type variables in an order
-- that is convenient for type applications.
fromWordN :: forall bits w. (Integral w, KnownNat bits, Bigger w (WordN bits)) => WordN bits -> w
fromWordN = toIntegral
-- | A count of bits. Stored in a 'Word8'.
type BitLength = Word8
-- | A collection of bits stored in a carrier @a@, together with the number
-- of bits that are in use. The bit count is expected to be no larger than
-- what fits into @a@.
-- The bits are always aligned little-endian (at the small end) in the @a@.
data Bits a = Bits BitLength a
  deriving stock Show
-- | Whether the bit collection holds zero bits. Only the recorded length is
-- inspected, never the carrier word.
bitsNull :: Bits a -> Bool
bitsNull (Bits 0 _) = True
bitsNull _ = False
-- | A bit collection with no bits in it: length zero, all-zero carrier.
--
-- The explicit signature keeps this polymorphic in the carrier type. Without
-- it, the monomorphism restriction ties the binding to the single carrier
-- type that its use sites in this module happen to need.
emptyBits :: Bits.Bits a => Bits a
emptyBits = Bits 0 Bits.zeroBits
-- | A 'ByteString' that can be consumed bit by bit: a buffer of bits that
-- were already read out of the bytes but not handed out yet, followed by the
-- bytes that have not been touched.
data BitString = BitString (Bits Word64) ByteString
  deriving stock Show
-- | Debug rendering of a 'BitString': the buffered bits (via 'printBits'),
-- the separator @<>@, then every remaining byte (via 'printByte').
printBitString :: BitString -> String
printBitString (BitString bits bs) =
  concat
    [ printBits bits
    , "<>"
    , concatMap printByte (Bytes.unpack bs)
    ]
-- | Wrap a 'ByteString' so it can be read bitwise; the bit buffer starts
-- out empty.
bitString :: ByteString -> BitString
bitString = BitString emptyBits
-- | Split the whole 'BitString' into consecutive @nat@-bit words by calling
-- 'takeNBits' until it returns 'Nothing'.
takeAllNBits :: (KnownNat nat, nat <= 64) => BitString -> [WordN nat]
takeAllNBits stream =
  case takeNBits stream of
    Nothing -> []
    Just (word, rest) -> word : takeAllNBits rest
-- | Like 'takeAllNBits', but each word is paired with the 'BitString' that
-- was left over right after that word was taken.
takeAllNBitsIntermediate :: (KnownNat nat, nat <= 64) => BitString -> [(WordN nat, BitString)]
takeAllNBitsIntermediate stream =
  case takeNBits stream of
    Nothing -> []
    Just step@(_, rest) -> step : takeAllNBitsIntermediate rest
-- | Adapt an unfold step so that every produced element also carries the
-- state that follows it: the result element becomes @(element, nextState)@,
-- and the next state is passed on unchanged.
intermediate :: (a -> Maybe (b, a)) -> a -> Maybe ((b, a), a)
intermediate step seed = fmap keepState (step seed)
  where
    keepState (out, next) = ((out, next), next)
-- | Take the next @nat@ bits off the front of a 'BitString'.
--
-- * If the bit buffer already holds at least @nat@ bits, the word is cut
--   out of the buffer and the bytes are left alone.
-- * Otherwise, if bytes are left, up to 8 of them are read with
--   'readW64LE'; the missing bits are taken from their small end and
--   pushed on top of the buffered bits. Whatever was read but not used
--   becomes the new buffer.
-- * Otherwise, if only buffered bits are left, they are returned as a final
--   word (which then has fewer than @nat@ meaningful bits).
-- * If nothing is left at all, the result is 'Nothing'.
--
-- NOTE(review): for @nat = 0@ the first case always applies and hands back
-- the 'BitString' unchanged, so unfolding with it would not terminate.
takeNBits :: forall nat. (HasCallStack, KnownNat nat, nat <= 64) => BitString -> Maybe (WordN nat, BitString)
takeNBits (BitString buffered@(Bits bufferedLen _) bytes)
  -- the buffer alone can satisfy the request
  | missing == 0 = fromBuffer
  -- the buffer is too short, but there are bytes to refill from
  | not (Bytes.null bytes) =
      let (freshLen, freshWord, bytes') = readW64LE bytes
          (Bits _ topUp, leftover) = splitByteLE missing (Bits freshLen freshWord)
      in Just
           -- the buffered bits stay at the small end; the fresh bits go on
           -- top of them, which is exactly what is needed to fill the wordN
           ( wordN (pushBitsLE topUp buffered)
           , BitString leftover bytes' )
  -- no bytes left and nothing buffered either: end of input
  | bitsNull buffered = Nothing
  -- no bytes left: hand out whatever is still in the buffer
  | otherwise = fromBuffer
  where
    wanted, missing :: BitLength
    -- number of bits requested, read off the type-level index
    wanted = fromInteger (natVal (Proxy @nat))
    -- how many bits the buffer is short of the request (0 if it suffices)
    missing = wanted `subMin0` bufferedLen

    -- Cut the word out of the buffer, making sure that no more than @nat@
    -- bits (and no more than are buffered) go into the WordN.
    fromBuffer :: Maybe (WordN nat, BitString)
    fromBuffer =
      let (Bits _ word, leftover) = splitByteLE (min wanted bufferedLen) buffered
      in Just (wordN word, BitString leftover bytes)
-- | Read the beginning of a 'ByteString' into a 'Word64', little end first:
-- the first byte ends up in the lowest 8 bits of the result.
--
-- Returns the number of bits that were read (8 per byte, at most 64), the
-- assembled word, and the bytes that were not consumed. If fewer than 8
-- bytes are available, all of them are read.
--
-- Only the prefix of up to 8 bytes is inspected, and the remainder is the
-- slice returned by 'Bytes.splitAt', so the whole input is no longer
-- unpacked and repacked on every call.
readW64LE :: ByteString -> (BitLength, Word64, ByteString)
readW64LE bs = (readBits, word, rest)
  where
    (chunk, rest) = Bytes.splitAt 8 bs

    readBits :: BitLength
    readBits = 8 * fromIntegral (Bytes.length chunk)

    -- Folding from the right pushes the last byte first, so each earlier
    -- byte is pushed below it and the first byte lands in the lowest bits.
    word :: Word64
    word = Bytes.foldr' (\w8 acc -> pushBitsLE acc (Bits 8 w8)) Bits.zeroBits chunk
-- | Subtraction on bit lengths that is clamped at zero: when @b@ is at
-- least as big as @a@ the result is 0 instead of wrapping around.
subMin0 :: BitLength -> BitLength -> BitLength
subMin0 a b
  | a > b = a - b
  | otherwise = 0
-- | Push bits into the word from the "short end".
--
-- The word is shifted left by the length of the 'Bits' and the bits are
-- then OR-ed into the freed positions.
-- If the word doesn't have space for the bits, its big end bits are pushed
-- out!
pushBitsLE :: BiggerIntegral word bits => (Bits.Bits word, Bits.Bits bits) => word -> Bits bits -> word
pushBitsLE word (Bits count bits) =
  let
    -- make space at the small end for @count@ bits
    madeRoom = Bits.shiftL word (intoBiggerWordLE count)
    -- the new bits, widened to the word type; they are expected to occupy
    -- only the lowest @count@ positions
    widened = intoBiggerWordLE bits
  in madeRoom Bits..|. widened
-- | Combines 'takeBitsLE' and 'dropBitsLE': the first component holds the
-- @at@ bits from the small end, the second everything that remains after
-- dropping them.
splitByteLE :: (Num word, Bits.Bits word) => BitLength -> Bits word -> (Bits word, Bits word)
splitByteLE at whole@(Bits _ payload) = (lowPart, remainder)
  where
    lowPart = takeBitsLE at payload
    remainder = dropBitsLE at whole
-- | Take @no@ bits from the small end of the word.
--
-- The recorded length of the result is always @no@, no matter how many
-- meaningful bits the word contained.
takeBitsLE :: (Num word, Bits.Bits word) => BitLength -> word -> Bits word
takeBitsLE no w8 = Bits no (lowMask Bits..&. w8)
  where
    -- …0111… bitmask with the lowest @no@ bits set
    lowMask = 2 ^ no - 1
-- | Drop @no@ bits from the "small end" of the word; the remaining bits
-- move down to take their place.
-- If more bits are dropped than are in the 'Bits', the recorded length of
-- the result is 0.
dropBitsLE :: Bits.Bits word => BitLength -> Bits word -> Bits word
dropBitsLE no (Bits len word) = Bits remaining shifted
  where
    -- amount of remaining bits, clamped at zero
    remaining = len `subMin0` no
    -- shift the remaining bits to the right
    shifted = Bits.shiftR word (intoBiggerWordLE no)
-- | A list of words, as produced by 'parseWordList'; words are looked up
-- in it by position.
newtype WordList = WordList [Text]
  deriving stock Show
-- | Parse the EFF large word list at
-- https://www.eff.org/files/2016/07/18/eff_large_wordlist.txt
-- (expected to contain at least 2^12 words).
--
-- Every line contributes one word: the text after the last tab character,
-- or the whole line if it contains no tab.
parseWordList :: Text -> WordList
parseWordList list = WordList [lastField row | row <- Text.lines list]
  where
    -- 'Text.split' never returns an empty list, so 'List.last' is safe here
    lastField row = List.last (Text.split (== '\t') row)
-- | Look up the word at the position given by a 12-bit word.
--
-- The lookup is checked: if the position is negative or the word list is
-- too short, this fails with an error that names the offending index,
-- rather than the generic error of the partial list index operator.
word12ToTextWord :: WordList -> WordN 12 -> Text
word12ToTextWord (WordList l) wn
  | idx >= 0, (word : _) <- drop idx l = word
  | otherwise =
      error ("word12ToTextWord: index " <> show idx <> " is outside of the word list")
  where
    idx :: Int
    idx = toIntegral wn
-- | Encode bytes as words from the word list: the bytes are cut into
-- 12-bit words with 'takeAllNBits', and each one is looked up with
-- 'word12ToTextWord'.
bytesToTextWord :: WordList -> ByteString -> [Text]
bytesToTextWord wl bs =
  [ word12ToTextWord wl chunk
  | chunk <- takeAllNBits (bitString bs)
  ]
-- | Render the lowest bits of a value as a string of @0@ and @1@
-- characters, preceded by the given prefix.
--
-- Bit 0 is printed first, so the small end of the value is on the left.
-- The first argument is the number of bits to print.
printBitsGeneric :: Bits.FiniteBits a => BitLength -> String -> a -> String
printBitsGeneric len prefix b = prefix <> map bitChar indices
  where
    indices :: [Int]
    indices = take (fromIntegral len) [0 ..]

    bitChar :: Int -> Char
    bitChar i = if Bits.testBit b i then '1' else '0'
-- | Render all 8 bits of a byte with 'printBitsGeneric', prefixed by @b@.
printByte :: Word8 -> String
printByte byte = printBitsGeneric 8 "b" byte
-- | Render a 'WordN' with 'printBitsGeneric': exactly @nat@ bits are
-- printed, and the prefix is the bit count followed by a colon.
printWordN :: forall nat. (KnownNat nat, Bigger Word64 (WordN nat)) => WordN nat -> String
printWordN wn = printBitsGeneric (fromIntegral width) (show width <> ":") asWord64
  where
    -- bit count taken from the type-level index
    width :: Integer
    width = wordNBits wn

    -- the value, widened so that it can be inspected bit by bit
    asWord64 :: Word64
    asWord64 = fromWordN @nat @Word64 wn
-- | Render a 'Bits' value with 'printBitsGeneric': as many bits as its
-- recorded length says, prefixed by @bits:@.
printBits :: Bits.FiniteBits a => Bits a -> String
printBits bits =
  case bits of
    Bits len word -> printBitsGeneric len "bits:" word
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment