-
-
Save dminuoso/148ba87e77894d5fcacd6c164976aaee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | An encoder that writes into a fixed-size, pre-allocated buffer
-- /backwards/: each put places its bytes directly in front of whatever was
-- written before it, so the first value put ends up last in the output.
--
-- Writing back to front makes length-prefixed layouts such as RADIUS
-- attributes easy: the payload can be written first and its length
-- prefixed afterwards.
--
-- The wrapped function takes the current 'Buffer' and returns the advanced
-- 'Buffer' together with a result.
newtype Put a = Put
  { unPut :: Buffer -> IO (Buffer, a)
  }
-- | Cursor state threaded through a 'Put'. Writes move 'bCur' towards
-- 'bEnd', i.e. towards lower addresses.
data Buffer = Buffer
  { bEnd :: {-# UNPACK #-} !(Ptr Word8)
    -- ^ Lower bound of the buffer (its first byte; we write in reverse, so
    -- this is where writing has to stop). The cursor must stay strictly
    -- above it.
  , bCur :: {-# UNPACK #-} !(Ptr Word8)
    -- ^ Current write position. Everything between here and the initial
    -- cursor position has already been written.
  }
instance Functor Put where
  -- Run the wrapped encoder and transform only its result; the buffer it
  -- returns is passed along untouched.
  {-# INLINE fmap #-}
  fmap f (Put run) = Put $ \start ->
    run start >>= \(end, x) -> pure (end, f x)
instance Applicative Put where
  -- Produce a value without touching the buffer.
  {-# INLINE pure #-}
  pure x = Put (\b -> pure (b, x))

  -- Run the function encoder first, then the argument encoder on the buffer
  -- it left behind, and apply the function to the argument.
  {-# INLINE (<*>) #-}
  Put runF <*> Put runA = Put $ \b0 ->
    runF b0 >>= \(b1, g) ->
      runA b1 >>= \(b2, x) ->
        pure (b2, g x)
instance Monad Put where
  {-# INLINE return #-}
  return = pure

  -- Run the first encoder, feed its result to the continuation, and run the
  -- resulting encoder on the buffer the first one left behind.
  {-# INLINE (>>=) #-}
  m >>= k = Put $ \b0 ->
    unPut m b0 >>= \(b1, a) ->
      unPut (k a) b1 >>= \(b2, r) ->
        pure (b2, r)
-- | Run a 'Put' encoder into a buffer of the given size, yielding 'Nothing'
-- instead of throwing when an encoder raises 'BufferExceeded'.
--
-- Only 'BufferExceeded' is caught; any other exception still propagates.
runPutMaybe :: Int -> Put a -> Maybe BS.ByteString
runPutMaybe size enc = unsafeDupablePerformIO (catch attempt onExceeded)
  where
    attempt = fmap Just (runPutIO size enc)

    onExceeded :: BufferExceeded -> IO (Maybe BS.ByteString)
    onExceeded _ = pure Nothing
-- | Run a 'Put' encoder into a buffer of the given size.
--
-- Any exception raised while encoding (e.g. 'BufferExceeded') is rethrown
-- when the resulting 'BS.ByteString' is forced; use 'runPutMaybe' to get
-- 'Nothing' on overflow instead.
runPut :: Int -> Put a -> BS.ByteString
runPut size = unsafeDupablePerformIO . runPutIO size
-- | Run a 'Put' encoder against a freshly allocated buffer with room for @l@
-- bytes of output, and return exactly the bytes that were written.
--
-- Exceptions thrown by the encoder propagate; 'consume' and 'ensure' throw
-- 'BufferExceeded' when output does not fit. Calls 'error' if @l@ is
-- negative.
runPutIO :: Int -> Put a -> IO BS.ByteString
runPutIO l (Put f)
  | l < 0 = error "runPutIO: negative length"
  | otherwise = do
      -- 'consume' and 'ensure' treat 'bEnd' as an exclusive bound: the byte
      -- at 'bEnd' is never written. Allocate one extra sentinel byte at the
      -- front so the l usable bytes [ptr+1, ptr+l] all lie inside the
      -- allocation. (Previously only l bytes were allocated, so writing a
      -- full buffer touched ptr+l, one byte past the end.)
      fptr <- mallocPlainForeignPtrBytes (l + 1)
      let ptr = unsafeForeignPtrToPtr fptr
          -- One past the last allocated byte; writing proceeds backwards
          -- from here.
          first = ptr `plusPtr` (l + 1)
          start = Buffer { bCur = first, bEnd = ptr }
      (end, _) <- f start
      -- fptr is referenced here, after f has run, so the allocation stays
      -- alive while f writes through the raw pointer.
      let offset = bCur end `minusPtr` ptr
          len = first `minusPtr` bCur end
      pure (BS.PS fptr offset len)
-- | @withLengthOf body k@ runs @body@, then runs @k n@, where @n@ is the
-- number of bytes @body@ wrote.
--
-- Because writing proceeds backwards, the output of @k n@ lands /in front
-- of/ @body@'s output, which makes this the natural way to emit a length
-- prefix. The result of @body@ is discarded.
withLengthOf :: Put a -> (Int -> Put b) -> Put b
withLengthOf (Put body) k = Put $ \start ->
  body start >>= \(afterBody, _) ->
    let written = bCur start `minusPtr` bCur afterBody
     in unPut (k written) afterBody
-- | Write a value with a bytestring 'BS.FixedPrim', reserving
-- @'BS.size' fp@ bytes in front of the cursor and letting 'BS.runF' fill
-- them.
--
-- No class constraint is needed on @a@: the value is only handed to the
-- primitive.
fixed :: BS.FixedPrim a -> a -> Put ()
fixed fp a = withPtr (BS.size fp) (BS.runF fp a)
-- | Reserve @n@ bytes in front of the cursor and hand @f@ a pointer to the
-- first of them.
--
-- @f@ writes forwards from that pointer, as bytestring's fixed-prim writers
-- do, while the 'Put' cursor itself moves backwards via 'consume'. The
-- result of @f@ is discarded.
withPtr :: Int -> (Ptr Word8 -> IO a) -> Put ()
withPtr n f = consume n $ \(Buffer _ cur) -> f cur
-- | @ensure label l enc@ runs @enc@ and checks that it wrote exactly @l@
-- bytes while keeping the cursor strictly above 'bEnd'.
--
-- Throws:
--
-- * @'MissingInput' label@ if @enc@ wrote fewer than @l@ bytes.
-- * @'BufferExceeded' label@ if @enc@ wrote more than @l@ bytes, or if the
--   cursor reached 'bEnd'.
ensure :: T.Text -> Int -> Put a -> Put a
ensure label l (Put p) = Put $ \buf -> do
  (buf', r) <- p buf
  let expected = bCur buf `subPtr` l
      actual = bCur buf'
      outcome
        | actual > expected = throwIO (MissingInput label)
        | actual < expected = throwIO (BufferExceeded label)
        | actual <= bEnd buf = throwIO (BufferExceeded label)
        | otherwise = pure (buf', r)
  outcome
-- | @consume len f@ moves the cursor @len@ bytes backwards and runs @f@ on the
-- updated buffer, whose 'bCur' now points at the first of the reserved bytes.
--
-- Throws 'BufferExceeded' (with an empty label) if the new cursor would not
-- stay strictly above 'bEnd'. Calls 'error' if @len@ is negative.
consume :: Int -> (Buffer -> IO a) -> Put ()
consume len f
  | len < 0 = error "consume: negative length"
  | otherwise = Put $ \buf -> do
      let moved = buf { bCur = bCur buf `subPtr` len }
      if bCur moved > bEnd buf
        then f moved >> pure (moved, ())
        else throwIO (BufferExceeded "")
-- | Write a 'Word64' in big-endian byte order (8 bytes).
putWord64BE :: Word64 -> Put ()
putWord64BE = fixed BS.word64BE
{-# INLINE putWord64BE #-}

-- | Write a 'Word32' in big-endian byte order (4 bytes).
putWord32BE :: Word32 -> Put ()
putWord32BE = fixed BS.word32BE
{-# INLINE putWord32BE #-}

-- | Write a 'Word16' in big-endian byte order (2 bytes).
putWord16BE :: Word16 -> Put ()
putWord16BE = fixed BS.word16BE
{-# INLINE putWord16BE #-}

-- | Write a single 'Word8'.
putWord8 :: Word8 -> Put ()
putWord8 = fixed BS.word8
{-# INLINE putWord8 #-}

-- | Write an 'Int64' in big-endian byte order (8 bytes).
putInt64BE :: Int64 -> Put ()
putInt64BE = fixed BS.int64BE
{-# INLINE putInt64BE #-}

-- | Write an 'Int32' in big-endian byte order (4 bytes).
putInt32BE :: Int32 -> Put ()
putInt32BE = fixed BS.int32BE
{-# INLINE putInt32BE #-}

-- | Write an 'Int16' in big-endian byte order (2 bytes).
putInt16BE :: Int16 -> Put ()
putInt16BE = fixed BS.int16BE
{-# INLINE putInt16BE #-}

-- | Write a single 'Int8'.
putInt8 :: Int8 -> Put ()
putInt8 = fixed BS.int8
{-# INLINE putInt8 #-}

-- | Write a 'Float' in big-endian byte order.
putFloatBE :: Float -> Put ()
putFloatBE = fixed BS.floatBE
{-# INLINE putFloatBE #-}

-- | Write a 'Double' in big-endian byte order.
putDoubleBE :: Double -> Put ()
putDoubleBE = fixed BS.doubleBE
{-# INLINE putDoubleBE #-}
-- | Copy the bytes of a 'BS.ByteString' verbatim in front of the cursor.
putByteString :: BS.ByteString -> Put ()
putByteString bs = withPtr len copyInto
  where
    len = BS.length bs
    -- Copy the whole input into the freshly reserved region starting at dst.
    copyInto dst =
      BS.unsafeUseAsCString bs $ \src -> copyBytes dst (castPtr src) len
{-# INLINE putByteString #-}
-- | Move a pointer @n@ bytes towards lower addresses (the inverse of
-- 'plusPtr').
subPtr :: Ptr a -> Int -> Ptr a
subPtr p n = plusPtr p (negate n)
{-# INLINE subPtr #-}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment