Skip to content

Instantly share code, notes, and snippets.

@DavidBuchanan314
Created January 5, 2024 20:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DavidBuchanan314/9e300875341f7290cae378397fc72370 to your computer and use it in GitHub Desktop.
Save DavidBuchanan314/9e300875341f7290cae378397fc72370 to your computer and use it in GitHub Desktop.
"""
This pure-python ChaCha20 implementation reaches 32MiB/sec on my machine (M1 Pro).
By comparison, cryptography.io's implementation reaches about 1700MiB/s: far faster, of course, but only about 50x faster.
This code is a proof-of-concept and should not be used in a security context.
"""
# ChaCha20's fixed input constant. The 16 ASCII bytes are read as four
# little-endian 32-bit words, which chacha20_bytes places in state words 0-3.
CONST_MAGIC = b"expand 32-byte k"
CONST_WORDS = [
    int.from_bytes(CONST_MAGIC[offset:offset + 4], byteorder="little")
    for offset in range(0, len(CONST_MAGIC), 4)
]
def _wide_rotl(packed, shift, masks):
    """Rotate every 32-bit lane of ``packed`` left by ``shift`` bits.

    ``masks`` is a pair of per-lane bit masks, as built in ``chacha20_bytes``:
    - ``masks[0]`` keeps the bits that the left shift moved into positions
      ``shift..31`` of each lane.
    - ``masks[1]`` keeps the bits that the right shift brought down into
      positions ``0..shift-1``.

    With masks of that shape, lane bits above 31 and bits shifted across a
    lane boundary are discarded, so every lane of the result fits in 32 bits.
    """
    return ((packed << shift) & masks[0]) | ((packed >> (32 - shift)) & masks[1])


def chacha20_wide_qr(state, addmask, rotmasks, a, b, c, d):
    """Apply one ChaCha20 quarter round to state[a], state[b], state[c], state[d] in place.

    Each state entry is a packed int holding one 32-bit word per lane, so one
    call runs the quarter round for every lane simultaneously.

    ``rotmasks`` holds the mask pairs for the rotations by 16, 12, 8 and 7 bits,
    in that order.

    The additions are deliberately not masked. Only state[b] and state[d] pass
    through a rotation, which reduces each lane back to 32 bits. State[a] and
    state[c] instead accumulate carries in each lane's spare high bits. The low
    32 bits of every lane stay correct modulo 2**32, but the caller must leave
    enough headroom per lane that those carries never reach the next lane.

    ``addmask`` is accepted but not used: the masked-addition variant is
    disabled.

    Returns None; ``state`` is modified in place.
    """
    # Step 1: a += b; d ^= a; d <<<= 16
    state[a] += state[b]
    state[d] = _wide_rotl(state[d] ^ state[a], 16, rotmasks[0])
    # Step 2: c += d; b ^= c; b <<<= 12
    state[c] += state[d]
    state[b] = _wide_rotl(state[b] ^ state[c], 12, rotmasks[1])
    # Step 3: a += b; d ^= a; d <<<= 8
    state[a] += state[b]
    state[d] = _wide_rotl(state[d] ^ state[a], 8, rotmasks[2])
    # Step 4: c += d; b ^= c; b <<<= 7
    state[c] += state[d]
    state[b] = _wide_rotl(state[b] ^ state[c], 7, rotmasks[3])
def chacha20_bytes(ctr: int, nonce: bytes, key: bytes, msg: bytes) -> bytes:
    """XOR ``msg`` with the ChaCha20 keystream, computing every block in parallel.

    Instead of running the block function once per 64-byte block, each of the
    16 state words is one Python int holding a 40-bit lane per block. Lane k
    occupies bits ``40*k .. 40*k+39``. The low 32 bits of a lane are the word
    for block k; the top 8 bits are headroom for the unmasked additions
    performed by ``chacha20_wide_qr``. Twenty rounds of big-int arithmetic
    therefore produce the keystream for all blocks at once.

    State layout:
    - words 0-3: the "expand 32-byte k" constants
    - words 4-11: the key
    - word 12: the block counter (``ctr`` for the first block, +1 per block)
    - word 13: a constant 0
    - words 14-15: the nonce

    Because the keystream is XORed into ``msg``, the same call both encrypts
    and decrypts.

    Args:
        ctr: block counter of the first 64-byte block of ``msg``.
        nonce: 8-byte nonce.
        key: 32-byte key.
        msg: data to encrypt or decrypt (any length, including empty).

    Returns:
        ``msg`` XORed with the keystream; same length as ``msg``.

    Raises:
        ValueError: if the counter of the last block would exceed 2**32 - 1,
            or if ``key``/``nonce`` is not exactly 32/8 bytes long.
    """
    n_blocks = (len(msg) + 63) // 64
    if ctr + n_blocks > 2**32:
        raise ValueError("this impl only supports a 32-bit ctr, for now")
    # Without these checks a short key/nonce is read as zero words (slicing
    # past the end gives b"", and int.from_bytes(b"") == 0). A long one is
    # silently truncated. Either way the result is a keystream for a
    # different key.
    if len(key) != 32:
        raise ValueError(f"key must be 32 bytes, got {len(key)}")
    if len(nonce) != 8:
        raise ValueError(f"nonce must be 8 bytes, got {len(nonce)}")

    nonce_words = [int.from_bytes(nonce[i:i + 4], "little") for i in range(0, 8, 4)]
    key_words = [int.from_bytes(key[i:i + 4], "little") for i in range(0, 32, 4)]

    # `ones` has a 1 at the bottom of every 40-bit lane, so `ones * w`
    # broadcasts a 32-bit word w into every lane.
    # XXX: some of this could be cached if n_blocks gets reused.
    ones = int.from_bytes(b"\x01\x00\x00\x00\x00" * n_blocks, "little")
    # Passed through for chacha20_wide_qr's signature; that function does not
    # mask its additions.
    addmask = ones * ((1 << 32) - 1)
    # For each rotation amount r, a pair of per-lane masks:
    # (bits r..31, bits 0..r-1).
    rotmasks = [
        (ones * (((1 << (32 - r)) - 1) << r), ones * ((1 << r) - 1))
        for r in (16, 12, 8, 7)
    ]
    # Lane k of word 12 holds the counter ctr + k.
    ctrval = int.from_bytes(
        b"".join(i.to_bytes(5, "little") for i in range(ctr, ctr + n_blocks)),
        "little",
    )
    state = [
        *(ones * w for w in CONST_WORDS),  # words 0-3: constants
        *(ones * w for w in key_words),    # words 4-11: key
        ctrval,                            # word 12: per-lane block counter
        0,                                 # word 13: always 0
        *(ones * w for w in nonce_words),  # words 14-15: nonce
    ]
    orig = list(state)

    # Lane headroom argument:
    # - Words 0-3 and 8-11 (the a/c operands) are only ever added to, never
    #   masked.
    # - Their addends (words 4-7 and 12-15) are always below 2**32 per lane:
    #   they are either fresh inputs or rotation outputs.
    # - So each a/c lane is bounded by its initial value, plus 4 additions per
    #   double round over 10 double rounds, plus the feed-forward add.
    # - That totals 42 terms each below 2**32, i.e. less than 42 * 2**32
    #   < 2**38. This fits in 40 bits, so no carry crosses into the next lane.
    quarter_rounds = (
        (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15),  # columns
        (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14),  # diagonals
    )
    for _ in range(10):
        for a, b, c, d in quarter_rounds:
            chacha20_wide_qr(state, addmask, rotmasks, a, b, c, d)
    # Feed-forward: add the input state back in.
    state = [s + o for s, o in zip(state, orig)]

    # Unpack the lanes into the keystream:
    # - Byte j (0-3) of word i for block k is byte 5*k + j of that word's
    #   little-endian encoding.
    # - It belongs at byte 64*k + 4*i + j of the keystream.
    # - The 5th byte of each lane holds only overflow and is skipped.
    lane_bytes = 5 * n_blocks
    keystream = bytearray(64 * n_blocks)
    for i, word in enumerate(state):
        word_bytes = word.to_bytes(lane_bytes, "little")
        for j in range(4):
            keystream[4 * i + j::64] = word_bytes[j::5]

    # XOR the keystream into msg as a single big-int operation.
    return (
        int.from_bytes(keystream[:len(msg)], "little") ^ int.from_bytes(msg, "little")
    ).to_bytes(len(msg), "little")
if __name__ == "__main__":
    # Benchmark chacha20_bytes against cryptography.io's ChaCha20 on random
    # messages of increasing size, and check that both produce identical output.
    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
    import time
    import os

    def mib_per_sec(nbytes, seconds):
        """Return throughput in MiB/s; inf if the clock registered no elapsed time."""
        return nbytes / (1024 * 1024) / seconds if seconds > 0 else float("inf")

    for msglen in [64, 1024, 1024*32, 1024*512, 1024*1024*64]:
        print()
        print(f"msg size {msglen} ({msglen/1024}KiB) ({msglen/(1024*1024)}MiB)")
        msg = os.urandom(msglen)

        # Use perf_counter() rather than time.time(). perf_counter() is
        # monotonic and high-resolution; time.time() is not monotonic and may
        # be too coarse to time the small messages. A zero interval would make
        # the throughput division fail.
        start = time.perf_counter()
        res = chacha20_bytes(0, b"N"*8, b"K"*32, msg)
        duration = time.perf_counter() - start
        print(mib_per_sec(msglen, duration), "MiB/sec (pure python)")

        # The 16-byte value passed to cryptography is 8 zero bytes followed by
        # the same 8-byte nonce. This is intended to match ctr=0 above.
        # NOTE(review): relies on cryptography's counter||nonce layout.
        start = time.perf_counter()
        cc = Cipher(algorithms.ChaCha20(b"K"*32, b"\x00"*8 + b"N"*8), mode=None).encryptor()
        res2 = cc.update(msg)
        duration = time.perf_counter() - start
        print(mib_per_sec(msglen, duration), "MiB/sec (cryptography.io)")

        # Use an explicit check rather than `assert`, which `python -O` strips.
        if res2 != res:
            raise AssertionError(f"output mismatch vs cryptography.io at msg size {msglen}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment