Last active: July 1, 2024 00:15
-
-
Save DavidBuchanan314/8a6851567ee12ec8c87b2e76f1510275 to your computer and use it in GitHub Desktop.
Python + pyca/cryptography implementation of https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from cryptography.hazmat.primitives.ciphers.aead import AESGCM | |
from cryptography.hazmat.primitives.ciphers import algorithms | |
from cryptography.hazmat.primitives.kdf.kbkdf import ( | |
CounterLocation, KBKDFCMAC, Mode | |
) | |
# https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md
class XAES256GCM:
    """XAES-256-GCM: AES-256-GCM extended to a 24-byte (192-bit) nonce.

    For every message, a 32-byte AES-256-GCM key is derived from the master
    key and the first 12 bytes of the nonce (KBKDF in counter mode with
    AES-256-CMAC). The last 12 bytes of the nonce are then used as the
    ordinary GCM nonce. Unlike plain AESGCM, this class offers
    ``generate_random_nonce()``.
    """

    # sizes in bytes
    KEY_SIZE = 32
    NONCE_SIZE = 24

    def __init__(self, key: bytes):
        """Store the 32-byte master key.

        Raises:
            ValueError: if ``key`` is not exactly KEY_SIZE bytes long.
        """
        if len(key) != self.KEY_SIZE:
            raise ValueError("Invalid key length (expected 32 bytes, 256 bits)")
        self._key = key

    @classmethod
    def generate_key(cls) -> bytes:
        """Return a new random KEY_SIZE-byte key (delegates to AESGCM)."""
        return AESGCM.generate_key(bit_length=cls.KEY_SIZE * 8)

    @classmethod
    def generate_random_nonce(cls) -> bytes:
        """Return a new random NONCE_SIZE-byte nonce from ``os.urandom``."""
        # nb: AESGCM doesn't have an equivalent method because random nonces
        # would be unsafe there!
        return os.urandom(cls.NONCE_SIZE)

    def _derive_key(self, nonce: bytes) -> bytes:
        """Derive the 32-byte per-message AES-256-GCM key.

        ``nonce`` is the first 12 bytes of the 24-byte XAES nonce. The KDF's
        fixed input is ``b"X"``, a 0x00 byte, then ``nonce``. It is preceded
        by a 2-byte counter, and no output-length field is encoded.

        Returns:
            The raw derived key bytes (not an AESGCM instance).
        """
        kdf = KBKDFCMAC(
            algorithm=algorithms.AES256,
            mode=Mode.CounterMode,
            length=32,  # output length
            rlen=2,  # counter length in bytes (16 bits)
            llen=None,  # no encoded output-length field
            location=CounterLocation.BeforeFixed,
            label=None,
            context=None,
            fixed=b"X\x00" + nonce,
        )
        return kdf.derive(self._key)

    def _split_nonce(self, nonce: bytes) -> "tuple[AESGCM, bytes]":
        """Validate a 24-byte nonce and return ``(AESGCM(derived key), gcm_nonce)``.

        Raises:
            ValueError: if ``nonce`` is not exactly NONCE_SIZE bytes long.
        """
        if len(nonce) != self.NONCE_SIZE:
            raise ValueError("Invalid nonce length (expected 24 bytes, 192 bits)")
        # first half selects the derived key, second half is the GCM nonce
        return AESGCM(self._derive_key(nonce[:12])), nonce[12:]

    def encrypt(self, nonce: bytes, data: bytes, associated_data: bytes) -> bytes:
        """Encrypt and authenticate ``data``; return ciphertext with GCM tag.

        Raises:
            ValueError: if ``nonce`` is not exactly NONCE_SIZE bytes long.
        """
        aead, gcm_nonce = self._split_nonce(nonce)
        return aead.encrypt(gcm_nonce, data, associated_data)

    def decrypt(self, nonce: bytes, data: bytes, associated_data: bytes) -> bytes:
        """Verify and decrypt ``data`` (ciphertext with GCM tag); return plaintext.

        Raises:
            ValueError: if ``nonce`` is not exactly NONCE_SIZE bytes long.
            Any exception raised by ``AESGCM.decrypt`` (e.g. on authentication
            failure) propagates unchanged.
        """
        aead, gcm_nonce = self._split_nonce(nonce)
        return aead.decrypt(gcm_nonce, data, associated_data)
if __name__ == "__main__":
    # Self-test: random round-trip, C2SP known-answer vectors, and the C2SP
    # accumulated randomized test (10,000 iterations).
    from hashlib import shake_128

    def _require(condition: bool, message: str) -> None:
        """Raise AssertionError when ``condition`` is false.

        Used instead of ``assert`` so the checks still run under ``python -O``,
        which strips assert statements.
        """
        if not condition:
            raise AssertionError(message)

    # basic test
    k = XAES256GCM.generate_key()
    n = XAES256GCM.generate_random_nonce()
    c = XAES256GCM(k)
    ct = c.encrypt(n, b"hello", b"world")
    _require(c.decrypt(n, ct, b"world") == b"hello", "basic round-trip failed")

    # test vectors from https://github.com/C2SP/C2SP/blob/main/XAES-256-GCM.md
    # each entry: (key, nonce, plaintext, aad, expected ciphertext || tag)
    known_answers = [
        (
            bytes.fromhex("0101010101010101010101010101010101010101010101010101010101010101"),
            b"ABCDEFGHIJKLMNOPQRSTUVWX",
            b"XAES-256-GCM",
            b"",
            bytes.fromhex("ce546ef63c9cc60765923609b33a9a1974e96e52daf2fcf7075e2271"),
        ),
        (
            bytes.fromhex("0303030303030303030303030303030303030303030303030303030303030303"),
            b"ABCDEFGHIJKLMNOPQRSTUVWX",
            b"XAES-256-GCM",
            b"c2sp.org/XAES-256-GCM",
            bytes.fromhex("986ec1832593df5443a179437fd083bf3fdb41abd740a21f71eb769d"),
        ),
    ]
    for index, (K, N, PT, AAD, expected) in enumerate(known_answers):
        c = XAES256GCM(K)
        ct = c.encrypt(N, PT, AAD)
        _require(ct == expected, f"known-answer vector {index}: ciphertext mismatch")
        _require(c.decrypt(N, ct, AAD) == PT, f"known-answer vector {index}: round-trip failed")

    # randomized test setup
    # https://github.com/DavidBuchanan314/ml-kem-stuff/blob/b44c93e048d768f9b96bd82e24ef8d845944a777/shakestream.py#L2
    class ShakeStream:
        """Sequential reader over the output of a variable-length digest function."""

        def __init__(self, digestfn) -> None:
            # digestfn is anything we can call repeatedly with different lengths
            self.digest = digestfn
            self.buf = self.digest(32)  # arbitrary starting length
            self.offset = 0

        def read(self, n: int) -> bytes:
            """Return the next ``n`` bytes of the stream."""
            # double the buffer size until we have enough
            while self.offset + n > len(self.buf):
                self.buf = self.digest(len(self.buf) * 2)
            res = self.buf[self.offset:self.offset + n]
            self.offset += n
            return res

    rng = ShakeStream(shake_128(b"").digest)
    accumulator = shake_128()
    # The 1M testcase is painfully slow in Python, but I did run it once and it passed
    for _ in range(10_000):
        # read order (key, nonce, pt length, pt, aad length, aad) must match
        # the reference test so the accumulated hash lines up
        key = rng.read(XAES256GCM.KEY_SIZE)
        nonce = rng.read(XAES256GCM.NONCE_SIZE)
        plaintext = rng.read(rng.read(1)[0])
        aad = rng.read(rng.read(1)[0])
        cipher = XAES256GCM(key)
        ciphertext = cipher.encrypt(nonce, plaintext, aad)
        accumulator.update(ciphertext)
        _require(
            cipher.decrypt(nonce, ciphertext, aad) == plaintext,
            "randomized round-trip failed",
        )
    _require(
        accumulator.digest(32)
        == bytes.fromhex("e6b9edf2df6cec60c8cbd864e2211b597fb69a529160cd040d56c0c210081939"),
        "accumulated ciphertext hash mismatch",
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.