Last active
July 6, 2017 12:10
-
-
Save snoyberg/20243aae347b38ad09daaf8b129e2efb to your computer and use it in GitHub Desktop.
Streaming public key encryption/decryption
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env stack | |
-- stack --resolver lts-8.12 script | |
{-# OPTIONS_GHC -Wall -Werror #-} | |
import Control.Exception.Safe (MonadThrow, assert, throwM) | |
import Control.Monad.Trans.Class (lift) | |
import qualified Crypto.Cipher.ChaCha as ChaCha | |
import qualified Crypto.Cipher.ChaChaPoly1305 as Cha | |
import qualified Crypto.ECC as ECC | |
import qualified Crypto.Error as CE | |
import Crypto.Hash (SHA512 (..), hashWith) | |
import qualified Crypto.MAC.Poly1305 as Poly1305 | |
import Crypto.PubKey.ECIES (deriveDecrypt, deriveEncrypt) | |
import Crypto.Random (MonadRandom, getRandomBytes) | |
import qualified Data.ByteArray as BA | |
import Data.ByteString (ByteString) | |
import qualified Data.ByteString as B | |
import qualified Data.ByteString.Lazy as BL | |
import Data.Conduit (ConduitM, await, leftover, | |
runConduit, yield, (.|)) | |
import qualified Data.Conduit.Binary as CB | |
import qualified Data.Conduit.List as CL | |
import Data.Proxy (Proxy (..)) | |
import Test.Hspec (hspec, shouldBe) | |
import Test.Hspec.QuickCheck (prop) | |
-- | Lift a 'CE.CryptoFailable' result into any 'MonadThrow': a successful
-- value is returned as-is, a failure is rethrown with 'throwM'.
cf :: MonadThrow m => CE.CryptoFailable a -> m a
cf result =
    case result of
        CE.CryptoFailed err -> throwM err
        CE.CryptoPassed val -> return val
-- | Streaming ChaCha20-Poly1305 encryption.
--
-- The output stream is laid out as: the nonce exactly as supplied, then one
-- encrypted chunk per input chunk, then the authentication tag once the
-- input is exhausted.
--
-- An invalid nonce or key is reported via 'throwM' (see 'cf').
encryptChaChaPoly1305
    :: MonadThrow m
    => ByteString -- ^ nonce (12 random bytes)
    -> ByteString -- ^ symmetric key (32 bytes)
    -> ConduitM ByteString ByteString m ()
encryptChaChaPoly1305 nonceBS key = do
    nonce <- cf (Cha.nonce12 nonceBS)
    initial <- cf (Cha.initialize key nonce)
    -- The nonce travels in the clear at the head of the stream so that the
    -- decrypting side can read it back.
    yield nonceBS
    -- No associated data is supplied; close the AAD phase immediately.
    go (Cha.finalizeAAD initial)
  where
    -- Encrypt chunks until upstream is done, then emit the tag.
    go st = await >>= maybe (emitTag st) (encryptChunk st)

    emitTag st = yield (BA.convert (Cha.finalize st))

    encryptChunk st chunk = do
        let (cipher, st') = Cha.encrypt chunk st
        yield cipher
        go st'
-- | Streaming ChaCha20-Poly1305 decryption, the inverse of
-- 'encryptChaChaPoly1305'.
--
-- Expected input layout: 12 bytes of nonce, the ciphertext, and finally a
-- 16 byte authentication tag.
--
-- All failures are reported through 'throwM':
--
-- * a 'CE.CryptoError' when the nonce or key is rejected, or when the
--   trailing bytes cannot be parsed as an authentication tag (for example
--   because the stream was truncated);
-- * a 'userError' when the tag does not match the decrypted data.
--
-- NOTE: plaintext chunks are yielded downstream /before/ the tag has been
-- verified, so consumers must not act on the output until this conduit has
-- completed without throwing.
decryptChaChaPoly1305
    :: MonadThrow m
    => ByteString -- ^ symmetric key (32 bytes)
    -> ConduitM ByteString ByteString m ()
decryptChaChaPoly1305 key = do
    nonceBS <- CB.take 12
    nonce <- cf $ Cha.nonce12 $ BL.toStrict nonceBS
    state0 <- cf $ Cha.initialize key nonce
    let loop state1 = do
            ebs <- awaitExcept16 id
            case ebs of
                Left final -> do
                    -- Everything left over at end of input must be the tag.
                    expected <- cf $ Poly1305.authTag final
                    if Cha.finalize state1 == expected
                        then return ()
                        else throwM (userError "Auth didn't match in ChaCha")
                Right bs -> do
                    let (bs', state2) = Cha.decrypt bs state1
                    yield bs'
                    loop state2
    -- No associated data is used; close the AAD phase immediately.
    loop $ Cha.finalizeAAD state0
  where
    -- Await input while always holding back the last 16 bytes seen, since
    -- they may turn out to be the authentication tag.
    --
    -- Returns @Right chunk@ for data known not to be part of the tag, or
    -- @Left rest@ with whatever was buffered once upstream is exhausted.
    -- @front@ prepends the bytes buffered so far to the next chunk.
    awaitExcept16 front = do
        mbs <- await
        case mbs of
            Nothing -> return $ Left $ front B.empty
            Just bs -> do
                let bs' = front bs
                if B.length bs' > 16
                    then do
                        let (x, y) = B.splitAt (B.length bs' - 16) bs'
                        -- Push the held-back 16 bytes back upstream so the
                        -- next call sees them first.
                        assert (B.length y == 16) leftover y
                        return $ Right x
                    else awaitExcept16 (B.append bs')
-- | Deterministically expand a shared secret into a @(nonce, key)@ pair.
--
-- The first 40 bytes of the SHA-512 digest of the secret seed a ChaCha
-- generator; the first 12 generated bytes become the nonce and the next 32
-- become the symmetric key.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared = (nonce, key)
  where
    seed = B.take 40 (BA.convert (hashWith SHA512 shared))
    gen0 = ChaCha.initializeSimple seed
    (nonce, gen1) = ChaCha.generateSimple gen0 12
    (key, _) = ChaCha.generateSimple gen1 32
-- | Type-level witness selecting the P-256 (secp256r1) curve; passed to the
-- curve-polymorphic 'ECC' and ECIES functions used in this file.
proxy :: Proxy ECC.Curve_P256R1
proxy = Proxy
-- | Streaming public-key encryption to the holder of the given curve point.
--
-- Output layout: the encoded point produced by 'deriveEncrypt', followed by
-- the complete output of 'encryptChaChaPoly1305' run with the nonce and key
-- that 'getNonceKey' derives from the shared secret.
encryptPublic
    :: (MonadThrow m, MonadRandom m)
    => ECC.Point ECC.Curve_P256R1
    -> ConduitM ByteString ByteString m ()
encryptPublic point = do
    (pubPoint, shared) <- lift (deriveEncrypt proxy point)
    -- The receiver needs this point to recompute the same shared secret.
    yield (ECC.encodePoint proxy pubPoint)
    uncurry encryptChaChaPoly1305 (getNonceKey shared)
-- | Streaming public-key decryption, the inverse of 'encryptPublic'.
--
-- Reads the encoded point from the head of the stream, derives the shared
-- secret with the given scalar and hands the rest of the stream to
-- 'decryptChaChaPoly1305'. A point that fails to decode is reported via
-- 'throwM' (see 'cf').
decryptPublic
    :: (MonadThrow m)
    => ECC.Scalar ECC.Curve_P256R1
    -> ConduitM ByteString ByteString m ()
decryptPublic scalar = do
    encoded <- BL.toStrict <$> CB.take encodedPointSize
    point <- cf (ECC.decodePoint proxy encoded)
    -- Only the key is needed here: the nonce derived from the secret is
    -- discarded because the decryption conduit reads it from the stream.
    let key = snd (getNonceKey (deriveDecrypt proxy point scalar))
    decryptChaChaPoly1305 key
  where
    -- magic value, known size of point
    encodedPointSize :: Int
    encodedPointSize = 65
-- | Round-trip property tests: for arbitrary chunked input, encrypting and
-- then decrypting must reproduce the same overall byte stream (chunk
-- boundaries are not compared, only the concatenated content).
main :: IO ()
main = hspec $ do
    -- Symmetric layer only, with a fresh random nonce and key per case.
    prop "encrypt/decrypt chacha works" $ \octets -> do
        let inputChunks = map B.pack octets
        nonce <- getRandomBytes 12
        key <- getRandomBytes 32
        outputChunks <-
            runConduit $
                CL.sourceList inputChunks
                    .| encryptChaChaPoly1305 nonce key
                    .| decryptChaChaPoly1305 key
                    .| CL.consume
        BL.fromChunks outputChunks `shouldBe` BL.fromChunks inputChunks

    -- Full public-key pipeline, with a fresh key pair per case.
    prop "encrypt/decrypt public works" $ \octets -> do
        let inputChunks = map B.pack octets
        ECC.KeyPair point scalar <- ECC.curveGenerateKeyPair proxy
        outputChunks <-
            runConduit $
                CL.sourceList inputChunks
                    .| encryptPublic point
                    .| decryptPublic scalar
                    .| CL.consume
        BL.fromChunks outputChunks `shouldBe` BL.fromChunks inputChunks
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment