Skip to content

Instantly share code, notes, and snippets.

@DavidBuchanan314
Last active July 1, 2024 00:15
Show Gist options
  • Save DavidBuchanan314/8a6851567ee12ec8c87b2e76f1510275 to your computer and use it in GitHub Desktop.
Save DavidBuchanan314/8a6851567ee12ec8c87b2e76f1510275 to your computer and use it in GitHub Desktop.
Python + pyca/cryptography implementation of https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md
import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.kdf.kbkdf import (
CounterLocation, KBKDFCMAC, Mode
)
# https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md
class XAES256GCM:
    """XAES-256-GCM: AES-256-GCM extended to 192-bit (24-byte) nonces.

    Implements https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md on top
    of pyca/cryptography. For every message, a 256-bit AES-GCM key is derived
    from the long-term key and the first 12 bytes of the 24-byte nonce, using a
    counter-mode KBKDF with AES-256-CMAC. The last 12 bytes of the nonce are
    then used as the ordinary AES-GCM nonce. The long nonce is what makes
    randomly generated nonces (see ``generate_random_nonce``) usable here.
    """

    # sizes in bytes
    KEY_SIZE = 32
    NONCE_SIZE = 24

    def __init__(self, key: bytes):
        """Store the long-term key.

        Raises:
            ValueError: if ``key`` is not exactly KEY_SIZE (32) bytes long.
        """
        if len(key) != self.KEY_SIZE:
            raise ValueError("Invalid key length (expected 32 bytes, 256 bits)")
        self._key = key

    @classmethod
    def generate_key(cls) -> bytes:
        """Return a new random 256-bit key (via ``AESGCM.generate_key``)."""
        return AESGCM.generate_key(bit_length=cls.KEY_SIZE * 8)

    @classmethod
    def generate_random_nonce(cls) -> bytes:
        """Return a random NONCE_SIZE (24-byte) nonce from ``os.urandom``."""
        # nb: AESGCM doesn't have an equivalent method because random nonces would be unsafe there!
        return os.urandom(cls.NONCE_SIZE)

    def _derive_key(self, nonce: bytes) -> bytes:
        """Derive the 32-byte per-message AES-GCM key for a 12-byte nonce prefix.

        The KDF input is the counter (2 bytes, placed before the fixed data)
        followed by b"X" || 0x00 || nonce. For a 12-byte prefix this is
        exactly one 16-byte AES block.
        """
        kdf = KBKDFCMAC(
            algorithm=algorithms.AES256,
            mode=Mode.CounterMode,
            length=32,  # output length: one AES-256 key
            rlen=2,  # counter length in bytes (16 bits)
            llen=None,
            location=CounterLocation.BeforeFixed,
            label=None,
            context=None,
            # label "X", a 0x00 separator, then the nonce prefix as context
            fixed=b"X\x00" + nonce,
        )
        return kdf.derive(self._key)

    def _aead_for_nonce(self, nonce: bytes) -> "tuple[AESGCM, bytes]":
        """Validate a 24-byte nonce and return (AES-GCM instance, 12-byte GCM nonce).

        Shared by ``encrypt`` and ``decrypt`` so both apply identical
        validation and key derivation.

        Raises:
            ValueError: if ``nonce`` is not exactly NONCE_SIZE (24) bytes long.
        """
        if len(nonce) != self.NONCE_SIZE:
            raise ValueError("Invalid nonce length (expected 24 bytes, 192 bits)")
        # The first half of the nonce selects the derived key;
        # the second half is passed to AES-GCM as its nonce.
        return AESGCM(self._derive_key(nonce[:12])), nonce[12:]

    def encrypt(self, nonce: bytes, data: bytes, associated_data: bytes) -> bytes:
        """Encrypt and authenticate ``data``, also authenticating ``associated_data``.

        Returns:
            The AES-GCM output: ciphertext with the 16-byte tag appended.

        Raises:
            ValueError: if ``nonce`` is not exactly 24 bytes long.
        """
        aead, gcm_nonce = self._aead_for_nonce(nonce)
        return aead.encrypt(gcm_nonce, data, associated_data)

    def decrypt(self, nonce: bytes, data: bytes, associated_data: bytes) -> bytes:
        """Verify and decrypt ``data`` produced by ``encrypt`` with the same nonce/AAD.

        Returns:
            The recovered plaintext.

        Raises:
            ValueError: if ``nonce`` is not exactly 24 bytes long.
            cryptography.exceptions.InvalidTag: (from AESGCM) if authentication fails.
        """
        aead, gcm_nonce = self._aead_for_nonce(nonce)
        return aead.decrypt(gcm_nonce, data, associated_data)
if __name__ == "__main__":
    # Self-test: a random round trip, tamper and nonce-length rejection, the
    # published C2SP test vectors, and the C2SP accumulated randomized test.
    # Failures raise AssertionError explicitly rather than via `assert`, so the
    # checks still run under `python -O`.
    from hashlib import shake_128

    from cryptography.exceptions import InvalidTag

    def check(condition: bool, message: str) -> None:
        """Raise AssertionError(message) unless ``condition`` is true."""
        if not condition:
            raise AssertionError(message)

    # basic test
    k = XAES256GCM.generate_key()
    n = XAES256GCM.generate_random_nonce()
    c = XAES256GCM(k)
    ct = c.encrypt(n, b"hello", b"world")
    check(c.decrypt(n, ct, b"world") == b"hello", "basic round trip failed")

    # Any modification of the ciphertext, the tag (last bytes) or the AAD
    # must be rejected.
    flipped_first = bytes([ct[0] ^ 1]) + ct[1:]
    flipped_last = ct[:-1] + bytes([ct[-1] ^ 1])
    tamper_cases = (
        (flipped_first, b"world"),
        (flipped_last, b"world"),
        (ct, b"World"),
    )
    for bad_ct, bad_aad in tamper_cases:
        try:
            c.decrypt(n, bad_ct, bad_aad)
        except InvalidTag:
            pass
        else:
            raise AssertionError("tampered ciphertext/AAD was accepted")

    # Nonces that are not exactly 24 bytes must be rejected.
    for bad_nonce in (n[:-1], n + b"\x00"):
        try:
            c.encrypt(bad_nonce, b"hello", b"world")
        except ValueError:
            pass
        else:
            raise AssertionError("wrong-length nonce was accepted")

    # test vectors from https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md
    # each entry: (key, nonce, plaintext, aad, expected ciphertext||tag as hex)
    test_vectors = (
        (
            b"\x01" * 32,
            b"ABCDEFGHIJKLMNOPQRSTUVWX",
            b"XAES-256-GCM",
            b"",
            "ce546ef63c9cc60765923609b33a9a1974e96e52daf2fcf7075e2271",
        ),
        (
            b"\x03" * 32,
            b"ABCDEFGHIJKLMNOPQRSTUVWX",
            b"XAES-256-GCM",
            b"c2sp.org/XAES-256-GCM",
            "986ec1832593df5443a179437fd083bf3fdb41abd740a21f71eb769d",
        ),
    )
    for vec_key, vec_nonce, vec_pt, vec_aad, expected_hex in test_vectors:
        c = XAES256GCM(vec_key)
        ct = c.encrypt(vec_nonce, vec_pt, vec_aad)
        check(ct == bytes.fromhex(expected_hex), f"test vector mismatch: got {ct.hex()}")
        check(c.decrypt(vec_nonce, ct, vec_aad) == vec_pt, "test vector round trip failed")

    # randomized test setup
    # https://github.com/DavidBuchanan314/ml-kem-stuff/blob/b44c93e048d768f9b96bd82e24ef8d845944a777/shakestream.py#L2
    class ShakeStream:
        """Sequential reader over an extendable-output stream.

        Relies on ``digestfn(length)`` returning the first ``length`` bytes of
        the same stream for every length, as SHAKE's ``digest`` does. That is
        why regenerating a longer buffer keeps every byte already handed out.
        """

        def __init__(self, digestfn) -> None:
            # digestfn is anything we can call repeatedly with different lengths
            self.digest = digestfn
            self.buf = self.digest(32)  # arbitrary starting length
            self.offset = 0

        def read(self, n: int) -> bytes:
            """Return the next ``n`` bytes of the stream."""
            # double the buffer size until we have enough
            while self.offset + n > len(self.buf):
                self.buf = self.digest(len(self.buf) * 2)
            res = self.buf[self.offset:self.offset + n]
            self.offset += n
            return res

    rng = ShakeStream(shake_128(b"").digest)
    accumulator = shake_128()
    for _ in range(10_000):  # The 1M testcase is painfully slow in Python, but I did run it once and it passed
        key = rng.read(XAES256GCM.KEY_SIZE)
        nonce = rng.read(XAES256GCM.NONCE_SIZE)
        # the length byte is read first, then that many bytes
        plaintext = rng.read(rng.read(1)[0])
        aad = rng.read(rng.read(1)[0])
        cipher = XAES256GCM(key)
        ciphertext = cipher.encrypt(nonce, plaintext, aad)
        accumulator.update(ciphertext)
        check(cipher.decrypt(nonce, ciphertext, aad) == plaintext, "randomized round trip failed")
    check(
        accumulator.digest(32)
        == bytes.fromhex("e6b9edf2df6cec60c8cbd864e2211b597fb69a529160cd040d56c0c210081939"),
        "accumulated randomized test digest mismatch",
    )
    print("all XAES-256-GCM self-tests passed")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment