Skip to content

Instantly share code, notes, and snippets.

@NathanHowell
Created May 20, 2012 07:57
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NathanHowell/2757253 to your computer and use it in GitHub Desktop.
Save NathanHowell/2757253 to your computer and use it in GitHub Desktop.
Protocol Buffers via GHC Generics
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverlappingInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Pb where
import Control.Applicative
import Control.Monad
import Data.Bits
import Data.ByteString
import Data.HashMap.Strict as HashMap
import Data.Hex
import Data.Int
import Data.Monoid
import Data.Serialize.Get
import Data.Serialize.Put
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.TypeLevel as Tl
import Data.Word
import Unsafe.Coerce
import GHC.Generics
type Tag = Word32
data Field where
VarintField :: Tag -> Word64 -> Field
Fixed64Field :: Tag -> Word64 -> Field
DelimitedField :: Tag -> ByteString -> Field
StartField :: Tag -> Field
EndField :: Tag -> Field
Fixed32Field :: Tag -> Word32 -> Field
deriving Show
getVarintPrefixedBS :: Get ByteString
getVarintPrefixedBS = do
length <- getVarInt
getBytes length
getField :: Get Field
getField = do
wireTag :: Word64 <- getVarInt
let tag = fromIntegral $ wireTag `shiftR` 3
case wireTag .&. 7 of
0 -> VarintField tag <$> getVarInt
1 -> Fixed64Field tag <$> getWord64le
2 -> DelimitedField tag <$> getVarintPrefixedBS
3 -> return $! StartField tag
4 -> return $! EndField tag
5 -> Fixed32Field tag <$> getWord32le
getVarInt :: (Integral a, Bits a) => Get a
getVarInt = go 0 0 where
go n !val = do
b <- getWord8
if testBit b 7
then go (n+7) (val .|. ((fromIntegral (b .&. 0x7F)) `shiftL` n))
else return $! val .|. ((fromIntegral b) `shiftL` n)
-- This can be used on any Integral type and is needed for signed types; unsigned can use putVarUInt below.
-- This has been changed to handle only up to 64 bit integral values (to match documentation).
{-# INLINE putVarSInt #-}
putVarSInt :: (Integral a, Bits a) => a -> Put
putVarSInt bIn =
case compare bIn 0 of
LT -> let -- upcast to 64 bit to match documentation of 10 bytes for all negative values
b = fromIntegral bIn
len = 10 -- (pred 10)*7 < 64 <= 10*7
last'Mask = 1 -- pred (1 `shiftL` 1)
go :: Int64 -> Int -> Put
go !i 1 = putWord8 (fromIntegral (i .&. last'Mask))
go !i n = putWord8 (fromIntegral (i .&. 0x7F) .|. 0x80) >> go (i `shiftR` 7) (pred n)
in go b len
EQ -> putWord8 0
GT -> putVarUInt bIn
-- This should be used on unsigned Integral types only (not checked)
{-# INLINE putVarUInt #-}
putVarUInt :: (Integral a, Bits a) => a -> Put
putVarUInt i
| i < 0x80 = putWord8 (fromIntegral i)
| otherwise = putWord8 (fromIntegral (i .&. 0x7F) .|. 0x80) >> putVarUInt (i `shiftR` 7)
fieldTag :: Field -> Tag
fieldTag f = case f of
VarintField t _ -> t
Fixed64Field t _ -> t
DelimitedField t _ -> t
StartField t -> t
EndField t -> t
Fixed32Field t _ -> t
decodeMessage :: (Decode (Rep a), Generic a) => Get a
decodeMessage = fmap to (decode =<< go HashMap.empty) where
go msg = do
mfield <- Just `fmap` getField <|> return Nothing
case mfield of
Just v -> go $! HashMap.insertWith (++) (fieldTag v) [v] msg
Nothing -> return msg
class Wire a where
decodeWire :: Field -> a
encodeWire :: Tag -> a -> Put
sizeWire :: a -> Int
deriving instance Wire a => Wire (First a)
deriving instance Wire a => Wire (Last a)
deriving instance Wire a => Wire (Product a)
deriving instance Wire a => Wire (Sum a)
instance Wire a => Wire (Maybe a) where
decodeWire = Just . decodeWire
encodeWire t (Just a) = encodeWire t a
encodeWire t Nothing = return ()
instance Wire Int32 where
decodeWire (VarintField _ val) = fromIntegral val
encodeWire t val = putVarSInt val
instance Wire Int64 where
decodeWire (VarintField _ val) = fromIntegral val
instance Wire Word32 where
decodeWire (VarintField _ val) = fromIntegral val
instance Wire Word64 where
decodeWire (VarintField _ val) = val
instance Wire (Signed Int32) where
decodeWire (VarintField _ val) = Signed (zzDecode32 (fromIntegral val))
instance Wire (Signed Int64) where
decodeWire (VarintField _ val) = Signed (zzDecode64 (fromIntegral val))
instance Wire (Fixed Int32) where
decodeWire (Fixed32Field _ val) = Fixed (fromIntegral val)
instance Wire (Fixed Int64) where
decodeWire (Fixed64Field _ val) = Fixed (fromIntegral val)
instance Wire (Fixed Word32) where
decodeWire (Fixed32Field _ val) = Fixed val
instance Wire (Fixed Word64) where
decodeWire (Fixed64Field _ val) = Fixed val
instance Wire Bool where
decodeWire (VarintField _ val) = val /= 0
instance Wire Float where
decodeWire (Fixed32Field _ val) = unsafeCoerce val
instance Wire Double where
decodeWire (Fixed64Field _ val) = unsafeCoerce val
instance Wire ByteString where
decodeWire (DelimitedField _ bs) = bs
instance Wire T.Text where
decodeWire (DelimitedField _ bs) = T.decodeUtf8 bs
instance (Decode (Rep a), Generic a) => Wire a where
decodeWire (DelimitedField _ bs) = val where
Right val = runGet decodeMessage bs
-- field rules
newtype Required n a = Required a
deriving instance Bits a => Bits (Required n a)
deriving instance Bounded a => Bounded (Required n a)
deriving instance Enum a => Enum (Required n a)
deriving instance Eq a => Eq (Required n a)
deriving instance Floating a => Floating (Required n a)
deriving instance Fractional a => Fractional (Required n a)
deriving instance Integral a => Integral (Required n a)
deriving instance Monoid a => Monoid (Required n a)
deriving instance Num a => Num (Required n a)
deriving instance Ord a => Ord (Required n a)
deriving instance Real a => Real (Required n a)
deriving instance RealFloat a => RealFloat (Required n a)
deriving instance RealFrac a => RealFrac (Required n a)
deriving instance Show a => Show (Required n a)
newtype Optional n a = Optional a
deriving instance Bits a => Bits (Optional n a)
deriving instance Bounded a => Bounded (Optional n a)
deriving instance Enum a => Enum (Optional n a)
deriving instance Eq a => Eq (Optional n a)
deriving instance Floating a => Floating (Optional n a)
deriving instance Fractional a => Fractional (Optional n a)
deriving instance Integral a => Integral (Optional n a)
deriving instance Monoid a => Monoid (Optional n a)
deriving instance Num a => Num (Optional n a)
deriving instance Ord a => Ord (Optional n a)
deriving instance Real a => Real (Optional n a)
deriving instance RealFloat a => RealFloat (Optional n a)
deriving instance RealFrac a => RealFrac (Optional n a)
deriving instance Show a => Show (Optional n a)
newtype Packed n a = Packed a
deriving instance Bits a => Bits (Packed n a)
deriving instance Bounded a => Bounded (Packed n a)
deriving instance Enum a => Enum (Packed n a)
deriving instance Eq a => Eq (Packed n a)
deriving instance Floating a => Floating (Packed n a)
deriving instance Fractional a => Fractional (Packed n a)
deriving instance Integral a => Integral (Packed n a)
deriving instance Monoid a => Monoid (Packed n a)
deriving instance Num a => Num (Packed n a)
deriving instance Ord a => Ord (Packed n a)
deriving instance Real a => Real (Packed n a)
deriving instance RealFloat a => RealFloat (Packed n a)
deriving instance RealFrac a => RealFrac (Packed n a)
deriving instance Show a => Show (Packed n a)
-- Integer encoding annotations
newtype Signed a = Signed a
deriving instance Bits a => Bits (Signed a)
deriving instance Bounded a => Bounded (Signed a)
deriving instance Enum a => Enum (Signed a)
deriving instance Eq a => Eq (Signed a)
deriving instance Floating a => Floating (Signed a)
deriving instance Fractional a => Fractional (Signed a)
deriving instance Integral a => Integral (Signed a)
deriving instance Monoid a => Monoid (Signed a)
deriving instance Num a => Num (Signed a)
deriving instance Ord a => Ord (Signed a)
deriving instance Real a => Real (Signed a)
deriving instance RealFloat a => RealFloat (Signed a)
deriving instance RealFrac a => RealFrac (Signed a)
deriving instance Show a => Show (Signed a)
newtype Fixed a = Fixed a
deriving instance Bits a => Bits (Fixed a)
deriving instance Bounded a => Bounded (Fixed a)
deriving instance Enum a => Enum (Fixed a)
deriving instance Eq a => Eq (Fixed a)
deriving instance Floating a => Floating (Fixed a)
deriving instance Fractional a => Fractional (Fixed a)
deriving instance Integral a => Integral (Fixed a)
deriving instance Monoid a => Monoid (Fixed a)
deriving instance Num a => Num (Fixed a)
deriving instance Ord a => Ord (Fixed a)
deriving instance Real a => Real (Fixed a)
deriving instance RealFloat a => RealFloat (Fixed a)
deriving instance RealFrac a => RealFrac (Fixed a)
deriving instance Show a => Show (Fixed a)
class Decode (f :: * -> *) where
decode :: (Applicative m, Monad m) => HashMap Tag [Field] -> m (f a)
instance Decode a => Decode (M1 i c a) where
decode = fmap M1 . decode
instance (Wire a, Monoid a, Tl.Nat n) => Decode (K1 i (Optional n a)) where
decode msg =
pure $! case HashMap.lookup (fromIntegral (Tl.toInt (undefined :: n))) msg of
Just val -> K1 . Optional . mconcat $ fmap decodeWire val
Nothing -> K1 mempty
instance (Wire a, Monoid a, Tl.Nat n) => Decode (K1 i (Required n a)) where
decode msg =
case HashMap.lookup (fromIntegral (Tl.toInt (undefined :: n))) msg of
Just val -> pure . K1 . Required . mconcat $ fmap decodeWire val
Nothing -> fail "required field not found"
{-
instance (Wire a, Monoid a, Tl.Nat n) => Decode (K1 i (Packed n a)) where
decode = error "packed fields are not implemented"
-}
instance (Decode a, Decode b) => Decode (a :*: b) where
decode msg = liftA2 (:*:) (decode msg) (decode msg)
-- Taken from google's code, but I had to explcitly add fromIntegral in the right places:
zzEncode32 :: Int32 -> Word32
zzEncode32 x = fromIntegral ((x `shiftL` 1) `xor` (x `shiftR` 31))
zzEncode64 :: Int64 -> Word64
zzEncode64 x = fromIntegral ((x `shiftL` 1) `xor` (x `shiftR` 63))
zzDecode32 :: Word32 -> Int32
zzDecode32 w = (fromIntegral (w `shiftR` 1)) `xor` (negate (fromIntegral (w .&. 1)))
zzDecode64 :: Word64 -> Int64
zzDecode64 w = (fromIntegral (w `shiftR` 1)) `xor` (negate (fromIntegral (w .&. 1)))
data TestRec = TestRec
{ field1 :: Required Tl.D1 (Last Int64)
, field2 :: Optional Tl.D2 (Last T.Text)
, field3 :: Optional Tl.D3 (Last Int32)
-- , field3 :: Optional Tl.D3 (Signed Int)
-- , field4 :: Optional Tl.D4 Double
} deriving (Generic, Show)
main :: IO ()
main = do
let val1, val2, val3 :: Either String TestRec
val1 = runGet decodeMessage =<< unhex "089601120774657374696e67"
val2 = runGet decodeMessage =<< unhex "089601189701"
val3 = runGet decodeMessage =<< unhex "089601"
print val1
print val2
print val3
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment