Skip to content

Instantly share code, notes, and snippets.

@d4em0n
Last active August 31, 2020 07:45
Show Gist options
  • Save d4em0n/22c8e50c504e59d853fe92eb4812b156 to your computer and use it in GitHub Desktop.
Save d4em0n/22c8e50c504e59d853fe92eb4812b156 to your computer and use it in GitHub Desktop.
Poly1305 key recovery from a pair of 32-byte messages and their tags
#!/usr/bin/env python3
import gmpy2
import binascii
# One-time key used by the demo at the bottom of the script. Inside
# poly_mac, bytes 0-15 are clamped into the multiplier r and bytes 16-31
# become the additive value s.
# NOTE(review): this looks like the sample key from RFC 8439 section 2.5.2;
# confirm before comparing against published test-vector tags.
key = binascii.unhexlify(
    "85:d6:be:78:57:55:6d:33:7f:44:52:fe:42:d5:06:a8:01:03:80:8a:fb:0d:b2:fd:4a:bf:f6:af:41:49:f5:1b".replace(":", "")
)
def clamp(r):
    """Return r with the bits Poly1305 forbids in its multiplier cleared.

    The mask zeroes the top four bits of little-endian bytes 3, 7, 11 and 15
    and the bottom two bits of bytes 4, 8 and 12, and keeps only the low
    128 bits.
    """
    r_mask = 0x0FFFFFFC0FFFFFFC0FFFFFFC0FFFFFFF
    return r_mask & r
def poly_mac(msg, key):
    """Compute the Poly1305 tag of msg under a 32-byte one-time key.

    Follows RFC 8439, section 2.5.1: the first 16 key bytes (little-endian,
    clamped) form the multiplier r and the last 16 bytes form s.

    Args:
        msg: message bytes, processed in 16-byte blocks; the last block may
            be shorter.
        key: 32-byte one-time key (r || s).

    Returns:
        The tag as an int in [0, 2**128). Encode it with
        ``.to_bytes(16, 'little')`` to get the 16-byte form.
    """
    r = clamp(int.from_bytes(key[:16], 'little'))
    s = int.from_bytes(key[16:], 'little')
    p = (1 << 130) - 5
    mod = 1 << 128
    acc = 0
    for i in range(0, len(msg), 16):
        chunk = msg[i:i + 16]
        # Poly1305 sets the bit just above the chunk's last byte (an appended
        # 0x01 byte), so short or zero-padded blocks stay distinct.
        pad_bit = 1 << (8 * len(chunk))
        n = pad_bit | int.from_bytes(chunk, 'little')
        # Reduce every round so acc stays below p instead of growing with
        # the message length.
        acc = (acc + n) * r % p
    # s is added to the fully reduced accumulator and only the low 128 bits
    # are kept. Reducing mod p again after adding s (as the previous version
    # did) gives a tag that is off by 5 whenever acc + s >= p.
    return (acc + s) % mod
def legendre(a, p):
    '''Return the Legendre symbol (a | p) using Euler's criterion.

    Evaluates a^((p-1)/2) mod p and maps p-1 to -1, so the result is -1 for
    a non-residue, 1 for a non-zero residue and 0 when p divides a.
    '''
    euler = pow(a, (p - 1) // 2, p)
    if euler == p - 1:
        return -1
    return euler
def tonelli_shanks(n, p):
    '''Find r such that r^2 = n (mod p) for an odd prime p.

    The other square root is p - r.

    Returns:
        A root r in [0, p); 0 when n = 0 (mod p); or False when n is a
        quadratic non-residue mod p. Because 0 and False are both falsy,
        callers must test for failure with ``is False``.

    Raises:
        ValueError: if the iteration cannot converge, which only happens when
            p is not prime.
    '''
    n %= p
    if n == 0:
        # Handle this up front: otherwise t stays 0 below, the search for i
        # runs off the end, and pow(c, 2**-1, p) raises TypeError.
        return 0
    if legendre(n, p) == -1:
        return False
    # Write p - 1 = q * 2**s with q odd.
    s = 0
    q = p - 1
    while q & 1 == 0:
        s += 1
        q >>= 1
    if s == 1:
        # p = 3 (mod 4): closed-form root.
        return pow(n, (p + 1) // 4, p)
    # Find any quadratic non-residue z (1 is always a residue).
    z = 2
    while legendre(z, p) != -1:
        z += 1
    m = s
    c = pow(z, q, p)
    t = pow(n, q, p)
    r = pow(n, (q + 1) // 2, p)
    while t != 1:
        # Least i >= 1 with t^(2^i) = 1, found by repeated squaring.
        i = 0
        t2 = t
        while t2 != 1:
            t2 = t2 * t2 % p
            i += 1
            if i == m:
                raise ValueError("no square root found; is p prime?")
        b = pow(c, 1 << (m - i - 1), p)
        r = r * b % p
        c = b * b % p
        t = t * c % p
        m = i
    return r
# Poly1305 prime modulus p = 2^130 - 5; the key-recovery code works mod m.
m = (1 << 130) - 5
# Tags are the low 128 bits of the accumulator, so every tag is below mod.
mod = 1 << 128
def solve_quadmod(a, b, c, n):
    '''Solve a*x^2 + b*x + c = 0 (mod n), where n is an odd prime.

    Returns:
        A pair (x1, x2) of roots mod n (the two entries are equal for a double
        root, or when a = 0 (mod n) reduces the equation to a linear one), or
        False when there is no solution or when a = b = 0 (mod n), in which
        case every x or no x is a solution.
    '''
    a %= n
    b %= n
    c %= n
    if a == 0:
        # Linear case: the quadratic formula would divide by 2*a = 0.
        if b == 0:
            return False
        x = gmpy2.divm(-c, b, n)
        return x, x
    disc = tonelli_shanks((b * b - 4 * a * c) % n, n)
    # A square root of 0 is a valid result (double root). Only the explicit
    # False sentinel means the discriminant is a non-residue.
    if disc is False:
        return False
    x1 = gmpy2.divm(-b + disc, 2 * a, n)
    x2 = gmpy2.divm(-b - disc, 2 * a, n)
    return x1, x2
def solve(msg1, tag1, msg2, tag2):
    '''Recover a Poly1305 key from two 32-byte messages and their tags.

    For a two-block message with blocks a, b (each read with the 0x01 byte
    appended), the tag is the low 128 bits of ((a*r^2 + b*r) mod p) + s.
    That sum is below 5 * 2^128, so it equals tag + k * 2^128 for some
    k in 0..4. Subtracting the congruences of the two messages eliminates s
    and leaves a quadratic in r mod p, which is solved for every (k1, k2)
    guess. Each candidate key is confirmed by recomputing both tags.

    Args:
        msg1, msg2: messages assumed to be exactly 32 bytes (two full blocks).
            Only bytes 0-31 enter the equations; other lengths are not
            modelled.
        tag1, tag2: the corresponding 16-byte little-endian tags.

    Returns:
        The 32-byte key (r || s) that reproduces both tags, or None if no
        candidate does.
    '''
    # Blocks of each message as integers, with Poly1305's 0x01 byte appended.
    a = int.from_bytes(msg1[:16] + b'\x01', 'little')
    b = int.from_bytes(msg1[16:32] + b'\x01', 'little')
    d = int.from_bytes(msg2[:16] + b'\x01', 'little')
    e = int.from_bytes(msg2[16:32] + b'\x01', 'little')
    tag1 = int.from_bytes(tag1, 'little')
    tag2 = int.from_bytes(tag2, 'little')
    # Candidate pre-truncation values. k = 0 and k = 4 are both reachable, so
    # the full range must be tried.
    cs = [tag1 + k * mod for k in range(5)]
    fs = [tag2 + k * mod for k in range(5)]
    for c in cs:
        for f in fs:
            # a r^2 + b r + s = c and d r^2 + e r + s = f (mod p), so
            # (d-a) r^2 + (e-b) r + (c-f) = 0 (mod p).
            roots = solve_quadmod(d - a, e - b, c - f, m)
            if not roots:
                continue
            for r in map(int, roots):
                # The real multiplier is clamped, so reject any other root.
                if clamp(r) != r:
                    continue
                s = (c - a * r * r - b * r) % m
                if s >= mod:
                    continue
                candidate = r.to_bytes(16, 'little') + s.to_bytes(16, 'little')
                if (poly_mac(msg1, candidate) == tag1
                        and poly_mac(msg2, candidate) == tag2):
                    return candidate
    return None
if __name__ == "__main__":
    # Demo: MAC two known 32-byte messages under `key`, recover the key from
    # the (message, tag) pairs alone, then use it to reproduce the tag of a
    # third message.
    msg1 = b"B" * 32
    tag1 = poly_mac(msg1, key).to_bytes(16, 'little')
    msg2 = b"A" * 32
    tag2 = poly_mac(msg2, key).to_bytes(16, 'little')
    msg3 = b"C" * 32
    tag3 = poly_mac(msg3, key)
    key_recovered = solve(msg2, tag2, msg1, tag1)
    # Use explicit checks rather than assert, which python -O strips; this
    # also avoids a confusing TypeError when solve() returns None.
    if key_recovered is None:
        raise SystemExit("key recovery failed: no candidate matched both tags")
    if poly_mac(msg3, key_recovered) != tag3:
        raise SystemExit("recovered key does not reproduce the tag of msg3")
    print("recovered key:", key_recovered.hex())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment