Last active March 12, 2021 09:03
Cryptopals Set 8, problem 61 solution

Cryptopals 61 | Duplicate-Signature Key Selection in ECDSA (and RSA)

This is a solution to problem 61 of The Cryptopals Crypto Challenges. The Set 8 challenge descriptions have never been officially published; however, a description of problem 61 is available online (the original link was not preserved in this copy).

In short: given a (message, signature) pair produced under some digital-signature scheme (DSA, ECDSA, RSA, etc.), construct valid scheme parameters and a signer public key under which that same signature successfully verifies for that same message.

I've implemented this attack for DSA, ECDSA, and RSA; each is a separate script below.

#!/usr/bin/env python3.8
from os import urandom
from math import gcd
from random import getrandbits, randrange
from hashlib import sha256
from collections import namedtuple
from gmpy2 import is_prime, next_prime
def H(x):
    """Hash the bytes `x` with SHA-256 and return the digest as a big-endian int."""
    return int.from_bytes(sha256(x).digest(), byteorder='big')
def random_prime(bits):
    """Return a random prime with exactly `bits` bits.

    The two top bits of the starting point are forced to 1, which guarantees
    that the product of two such primes has exactly 2 * bits bits.  Because
    next_prime can step past 2**bits, candidates that overflow are redrawn.

    Raises:
        ValueError: if bits < 3 (no prime of that size has both top bits set
            and is reachable via next_prime).
    """
    if bits < 3:
        raise ValueError('bits must be at least 3')
    top_bits = 0b11 << (bits - 2)
    while True:
        candidate = int(next_prime(top_bits | getrandbits(bits)))
        if candidate.bit_length() == bits:
            return candidate
# DSA value types; field names follow the notation used by the DSA class.
DSAParams = namedtuple('DSAParams', 'q p g')      # subgroup order q, modulus p, generator g
DSASignature = namedtuple('DSASignature', 'r s')
DSAPublicKey = namedtuple('DSAPublicKey', ['y'])  # y = g**x mod p
DSAPrivateKey = namedtuple('DSAPrivateKey', ['x'])
class DSA:
    """Textbook DSA over the order-q subgroup of Z_p^*.

    All methods are static; call them as ``DSA.sign(...)`` etc.
    """

    @staticmethod
    def params(q_bits, p_bits):
        """Generate domain parameters DSAParams(q, p, g).

        q is a random q_bits-bit prime, p = 2*t*q + 1 is a p_bits-bit prime,
        and g generates the subgroup of order q.
        """
        q = random_prime(q_bits)
        while True:
            t = getrandbits(p_bits - q_bits - 1)
            p = 2 * t * q + 1
            if is_prime(p) and p.bit_length() == p_bits:
                break
        while True:
            h = randrange(2, p - 1)
            # (p - 1) // q == 2 * t, so g**q == h**(p - 1) == 1 always holds;
            # g != 1 is what actually guarantees g has order exactly q.
            g = pow(h, 2 * t, p)
            if g != 1:
                return DSAParams(q, p, g)

    @staticmethod
    def keypair(params):
        """Return (public, private) with private x in [1, q) and y = g**x mod p."""
        x = randrange(1, params.q)
        y = pow(params.g, x, params.p)
        return DSAPublicKey(y), DSAPrivateKey(x)

    @staticmethod
    def sign(params, private, message):
        """Sign `message`, drawing fresh nonces until r and s are both non-zero."""
        q = params.q
        h = H(message)
        while True:
            k = randrange(1, q)
            r = pow(params.g, k, params.p) % q
            if r == 0:
                continue
            s = pow(k, -1, q) * (h + private.x * r) % q
            if s != 0:
                return DSASignature(r, s)

    @staticmethod
    def verify(params, public, message, signature):
        """Return True iff `signature` is valid for `message` under params/public."""
        q = params.q
        if not 0 < signature.r < q or not 0 < signature.s < q:
            return False
        w = pow(signature.s, -1, q)
        u1 = H(message) * w % q
        u2 = signature.r * w % q
        v = pow(params.g, u1, params.p) * pow(public.y, u2, params.p) % params.p % q
        return v == signature.r
def DSA_duplicate(params, public, message, signature):
    """Build fake DSA parameters and public key under which `signature` verifies.

    The order q and modulus p are kept; only the generator and public key are
    forged.  With a = u1 / z, b = u2 / z (z = gcd(u1, u2)) and
    base = g**a * y**b, a random fake_x is chosen such that
    t = (u1 + u2*fake_x) / z is invertible mod p - 1.  Then fake_g = base**(1/t)
    and fake_y = fake_g**fake_x, so that fake_g**u1 * fake_y**u2 == base**z ==
    g**u1 * y**u2, which is exactly what the original verification computed.

    Returns:
        (DSAParams, DSAPublicKey) for the forged key.
    """
    q, p = params.q, params.p
    s_inv = pow(signature.s, -1, q)
    a = H(message) * s_inv % q
    b = signature.r * s_inv % q
    common = gcd(a, b)
    base = pow(params.g, a // common, p) * pow(public.y, b // common, p) % p
    group_order = p - 1
    while True:
        fake_x = randrange(1, q)
        exponent = (a + b * fake_x) // common
        if gcd(exponent, group_order) != 1:
            continue  # need exponent invertible mod p - 1 to take its root
        fake_g = pow(base, pow(exponent, -1, group_order), p)
        return DSAParams(q, p, fake_g), DSAPublicKey(pow(fake_g, fake_x, p))
def main():
    """Run repeated DSA duplicate-signature trials and print how many succeed."""
    q_bits, p_bits = 160, 1024
    message_length = 1000
    tries = 100
    passed = 0
    for i in range(tries):
        print(f'DSA TEST {i}')
        message = urandom(message_length)
        params = DSA.params(q_bits, p_bits)
        public, private = DSA.keypair(params)
        signature = DSA.sign(params, private, message)
        honest_ok = DSA.verify(params, public, message, signature)
        if not honest_ok:
            print('sign+verify error!')
        fake_params, fake_public = DSA_duplicate(params, public, message, signature)
        forged_ok = DSA.verify(fake_params, fake_public, message, signature)
        passed += 1 if forged_ok else 0
    print(f'PASSED {passed} / {tries} tests')
if __name__ == '__main__':
    # Run the DSA demo when executed as a script.
    main()
#!/usr/bin/env python3.8
from os import urandom
from random import randrange
from hashlib import sha256
from collections import namedtuple
from fastecdsa.curve import P256
def H(x):
    """Hash the bytes `x` with SHA-256 and return the digest as a big-endian int."""
    digest = sha256(x).digest()
    return int.from_bytes(digest, byteorder='big')
# ECDSA value types; field names follow the notation used by the ECDSA class.
ECDSAParams = namedtuple('ECDSAParams', 'q P')        # group order q, base point P
ECDSASignature = namedtuple('ECDSASignature', 'r s')
ECDSAPublicKey = namedtuple('ECDSAPublicKey', ['Q'])  # Q = x * P
ECDSAPrivateKey = namedtuple('ECDSAPrivateKey', ['x'])
class ECDSA:
    """Textbook ECDSA over a curve object exposing order ``q`` and base point ``G``.

    All methods are static; call them as ``ECDSA.sign(...)`` etc.
    """

    @staticmethod
    def params(curve):
        """Return ECDSAParams(q, P) from the curve's order and base point."""
        return ECDSAParams(curve.q, curve.G)

    @staticmethod
    def keypair(params):
        """Return (public, private) with private x in [1, q) and Q = x * P."""
        x = randrange(1, params.q)
        Q = x * params.P
        return ECDSAPublicKey(Q), ECDSAPrivateKey(x)

    @staticmethod
    def sign(params, private, message):
        """Sign `message`, drawing fresh nonces until r and s are both non-zero."""
        q = params.q
        h = H(message)
        while True:
            k = randrange(1, q)
            # ECDSA's r is the x-coordinate of k*P reduced modulo the group order;
            # without the reduction an r >= q would be rejected by verify().
            r = (k * params.P).x % q
            if r == 0:
                continue
            s = pow(k, -1, q) * (h + private.x * r) % q
            if s != 0:
                return ECDSASignature(r, s)

    @staticmethod
    def verify(params, public, message, signature):
        """Return True iff `signature` is valid for `message` under params/public."""
        q = params.q
        if not 0 < signature.r < q or not 0 < signature.s < q:
            return False
        w = pow(signature.s, -1, q)
        u1 = H(message) * w % q
        u2 = signature.r * w % q
        # Reduce mod q to match how r is produced in sign().
        v = (u1 * params.P + u2 * public.Q).x % q
        return v == signature.r
def ECDSA_duplicate(params, public, message, signature):
    """Build a fake base point and public key under which `signature` verifies.

    Let R = u1*P + u2*Q be the point the genuine verification recomputes.
    For a random fake_x with t = u1 + u2*fake_x invertible mod q, set
    fake_P = t^-1 * R and fake_Q = fake_x * fake_P.  Verification under the
    fake key computes u1*fake_P + u2*fake_Q = t * fake_P = R, so it
    reproduces the same r.

    Returns:
        (ECDSAParams, ECDSAPublicKey) for the forged key.
    """
    q = params.q
    w = pow(signature.s, -1, q)
    u1 = H(message) * w % q
    u2 = signature.r * w % q
    R = u1 * params.P + u2 * public.Q
    while True:
        fake_x = randrange(1, q)
        t = (u1 + u2 * fake_x) % q
        if t != 0:  # t == 0 has no inverse mod q; draw another fake_x
            break
    fake_P = pow(t, -1, q) * R
    fake_Q = fake_x * fake_P
    return ECDSAParams(q, fake_P), ECDSAPublicKey(fake_Q)
def main():
    """Run repeated ECDSA duplicate-signature trials on P256 and print the pass count."""
    curve = P256
    message_length = 1000
    tries = 100
    passed = 0
    for i in range(tries):
        print(f'ECDSA TEST {i}')
        message = urandom(message_length)
        params = ECDSA.params(curve)
        public, private = ECDSA.keypair(params)
        signature = ECDSA.sign(params, private, message)
        honest_ok = ECDSA.verify(params, public, message, signature)
        if not honest_ok:
            print('sign+verify error!')
        fake_params, fake_public = ECDSA_duplicate(params, public, message, signature)
        forged_ok = ECDSA.verify(fake_params, fake_public, message, signature)
        passed += 1 if forged_ok else 0
    print(f'PASSED {passed} / {tries} tests')
if __name__ == '__main__':
    # Run the ECDSA demo when executed as a script.
    main()
#!/usr/bin/env python3.8
from os import urandom
from random import getrandbits, shuffle
from hashlib import sha256
from collections import namedtuple
from gmpy2 import is_prime, next_prime, isqrt
def H(x):
    """Return SHA-256(x) interpreted as a big-endian unsigned integer."""
    return int.from_bytes(sha256(x).digest(), byteorder='big')
def random_prime(bits):
    """Return a random prime with exactly `bits` bits.

    The two top bits of the starting point are forced to 1, which guarantees
    that the product of two such primes has exactly 2 * bits bits.  Because
    next_prime can step past 2**bits, candidates that overflow are redrawn.

    Raises:
        ValueError: if bits < 3 (no prime of that size has both top bits set
            and is reachable via next_prime).
    """
    if bits < 3:
        raise ValueError('bits must be at least 3')
    top_bits = 0b11 << (bits - 2)
    while True:
        candidate = int(next_prime(top_bits | getrandbits(bits)))
        if candidate.bit_length() == bits:
            return candidate
# RSA value types.
RSASignature = namedtuple('RSASignature', ['s'])      # s = H(message)**d mod N
RSAPublicKey = namedtuple('RSAPublicKey', 'N e')      # modulus N, public exponent e
RSAPrivateKey = namedtuple('RSAPrivateKey', 'N d')    # modulus N, private exponent d
class RSA:
    """Textbook (unpadded) RSA signatures over H(message).

    All methods are static; call them as ``RSA.sign(...)`` etc.
    """

    @staticmethod
    def keypair(e, bits):
        """Generate (public, private) keys with exponent `e` and two `bits`-bit primes.

        Primes are redrawn until they differ and `e` is invertible modulo
        (p - 1) * (q - 1); otherwise pow(e, -1, phi) would raise ValueError.
        """
        while True:
            p = random_prime(bits)
            q = random_prime(bits)
            if p == q:
                continue
            try:
                d = pow(e, -1, (p - 1) * (q - 1))
            except ValueError:
                continue  # gcd(e, phi) != 1 for these primes; draw again
            N = p * q
            return RSAPublicKey(N, e), RSAPrivateKey(N, d)

    @staticmethod
    def sign(private, message):
        """Return RSASignature(H(message)**d mod N)."""
        return RSASignature(pow(H(message), private.d, private.N))

    @staticmethod
    def verify(public, message, signature):
        """Return True iff s**e mod N equals H(message)."""
        return H(message) == pow(signature.s, public.e, public.N)
def cache_continuously(func):
    """Turn `func` into a factory of infinite generators over its distinct outputs.

    Calling the wrapped function with ``*args`` returns a generator that
    first replays every output already produced for the same ``args`` by
    earlier calls, then repeatedly calls ``func(*args)`` and yields each
    result not seen before for those ``args``.  Results are remembered per
    ``args`` (so args and outputs must be hashable) for as long as the
    wrapper exists.

    ``func`` must keep producing new values; otherwise the generator spins.
    """
    import functools

    cache = {}

    @functools.wraps(func)
    def inner_func(*args):
        seen = cache.setdefault(args, set())
        # Iterate a snapshot: another live generator for the same args may add
        # to `seen` while this one is suspended.
        yield from list(seen)
        while True:
            output = func(*args)
            if output not in seen:
                seen.add(output)
                yield output

    return inner_func
@cache_continuously
def gen_smooth(bits, B):
    """Find a prime p whose p - 1 has a known, squarefree, smooth factorization.

    p - 1 = 2 * (product of distinct random primes from random_prime(B)),
    and p - 1 has at least `bits` bits.  The cache_continuously decorator
    turns this into an infinite iterator of distinct (p, factors) results,
    which is how RSA_duplicate consumes it via next().

    Returns:
        (p, factors): factors is a tuple of distinct primes, starting with 2,
        whose product is exactly p - 1 (as pohlig_hellman's CRT requires).
    """
    primes_count = 1000
    # A set drops duplicate draws so every factor of p - 1 is distinct.
    primes = list({random_prime(B) for _ in range(primes_count)})
    while True:
        shuffle(primes)
        q = 2
        factors = [2]
        for prime in primes:
            q *= prime
            factors.append(prime)
            if q.bit_length() >= bits:
                break
        p = q + 1
        if is_prime(p):
            return p, tuple(factors)
def bsgs(h, g, q, p):
    """Solve g**x == h (mod p) by baby-step giant-step.

    Args:
        h: target element mod p.
        g: base, invertible mod p.
        q: bound on the order of g; exponents 0 .. root**2 - 1 are searched,
            where root = isqrt(q) + 1.
        p: modulus.

    Returns:
        An exponent x with pow(g, x, p) == h % p.

    Raises:
        ValueError: if h is not reached within the search range (previously
            this returned None, which crashed later in the caller).
    """
    from math import isqrt  # plain-int result; no dependence on gmpy2 types

    root = isqrt(q) + 1
    baby_steps = {}
    value = 1
    for j in range(root):
        baby_steps[value] = j
        value = value * g % p
    giant = pow(g, -root, p)  # g**(-root) mod p
    value = h
    for i in range(root):
        if value in baby_steps:
            return baby_steps[value] + i * root
        value = value * giant % p
    raise ValueError('h is not a power of g modulo p within the search bound')
def pohlig_hellman(h, g, p, factors):
    """Compute x with g**x == h (mod p) using Pohlig-Hellman.

    `factors` must be distinct primes whose product is p - 1.  For each
    prime, the log is solved in the subgroup of that order with bsgs, and
    the partial results are combined by the Chinese Remainder Theorem.

    Returns:
        x reduced modulo p - 1.
    """
    group_order = p - 1
    total = 0
    for prime in factors:
        cofactor = group_order // prime
        partial = bsgs(pow(h, cofactor, p), pow(g, cofactor, p), prime, p)
        # CRT term: congruent to partial mod prime, to 0 mod every other factor.
        total += cofactor * partial * pow(cofactor, -1, prime)
    return total % group_order
def _generates_units(x, mod, factors):
    """Return True if x generates the multiplicative group mod the prime `mod`.

    `factors` must be the distinct prime factors of mod - 1.
    """
    if x % mod == 0:
        return False
    return all(pow(x, (mod - 1) // f, mod) != 1 for f in factors)


def RSA_duplicate(public, message, signature):
    """Forge an RSA public key under which `signature` verifies `message`.

    Picks primes p, q with smooth p - 1 and q - 1 (from gen_smooth) so that
    discrete logs are easy.  It then solves s**e == H(message) modulo p and
    modulo q with Pohlig-Hellman, and joins the two exponents with CRT.

    Args:
        public: the genuine RSAPublicKey; only its modulus size is used.
        message: the signed message bytes.
        signature: the genuine RSASignature over `message`.

    Returns:
        RSAPublicKey(p * q, e) such that s**e mod (p*q) == H(message).
    """
    # Size the fake primes at about half of N's bit length.
    bits = 1 + public.N.bit_length() // 2
    B = 24
    h = H(message)
    s = signature.s
    smooth_iter = gen_smooth(bits, B)
    while True:
        p, p_factors = next(smooth_iter)
        q, q_factors = next(smooth_iter)
        # The CRT step needs gcd(p - 1, q - 1) == 2: only the factor 2 shared.
        if len(set(p_factors) & set(q_factors)) > 1:
            continue
        # h and s must both generate Z_p^* and Z_q^*, so that e exists and is a
        # unit mod p-1 and q-1 (hence odd, making eq - ep even below).
        if not all(
            _generates_units(x, mod, factors)
            for mod, factors in ((p, p_factors), (q, q_factors))
            for x in (h, s)
        ):
            continue
        ep = pohlig_hellman(h, s, p, p_factors)  # e mod (p - 1)
        eq = pohlig_hellman(h, s, q, q_factors)  # e mod (q - 1)
        t = (eq - ep) // 2 * pow((p - 1) // 2, -1, (q - 1) // 2)
        e = (ep + t * (p - 1)) % ((p - 1) * (q - 1))
        return RSAPublicKey(p * q, e)
def main():
    """Run repeated RSA duplicate-signature trials and print how many succeed."""
    e, bits = 65537, 512
    message_length = 1000
    tries = 100
    passed = 0
    for i in range(tries):
        print(f'RSA TEST {i}')
        message = urandom(message_length)
        public, private = RSA.keypair(e, bits)
        signature = RSA.sign(private, message)
        honest_ok = RSA.verify(public, message, signature)
        if not honest_ok:
            print('sign+verify error!')
        fake_public = RSA_duplicate(public, message, signature)
        forged_ok = RSA.verify(fake_public, message, signature)
        passed += 1 if forged_ok else 0
    print(f'PASSED {passed} / {tries} tests')
if __name__ == '__main__':
    # Run the RSA demo when executed as a script.
    main()
