Skip to content

Instantly share code, notes, and snippets.

@r98inver
Last active December 11, 2023 12:50
Show Gist options
  • Save r98inver/6a92aac1b3d07b0a1916655ea25951b0 to your computer and use it in GitHub Desktop.
Save r98inver/6a92aac1b3d07b0a1916655ea25951b0 to your computer and use it in GitHub Desktop.
0CTF 2023 – recovering the RSA exponent e via discrete logarithms over smooth primes
from sage.all import *
from pwn import remote, log
import re
import string
from hashlib import sha256
import itertools
import time
from Crypto.Util.number import *
from random import randrange
# Bit length of a single RSA prime.
P_BITS = 512
# Width, in bits, to which exponents are truncated (``% 2**E_BITS``) after
# being XORed with an LCG output. The 0.292 factor matches the Boneh–Durfee
# bound d < N**0.292 for an N of 2 * P_BITS bits; presumably this mirrors the
# challenge server's parameters (TODO confirm). Evaluates to 329.
E_BITS = 30 + int(2 * P_BITS * 0.292)
# NOTE(review): not referenced anywhere in this script; presumably copied
# from the server source.
CNT_MAX = 7
def get_smooth_n(bits=512, factor_bits=40):
    """Return a random even integer of exactly ``bits`` bits with only small factors.

    The result is 2 times a product of random ``factor_bits``-bit integers,
    times one final cofactor chosen so the total lands on exactly ``bits``
    bits. Every multiplied-in factor is below ``2**(factor_bits + 1)``, so all
    prime factors of the result are below that bound (i.e. the result is
    smooth, which is what makes n + 1 useful as a weak prime).

    Args:
        bits: exact bit length of the result (default 512, the original value).
        factor_bits: bit length of the bulk random factors (default 40).

    Raises:
        ValueError: if ``bits < 2`` or ``factor_bits < 2`` (no valid result /
            the loop could not make progress).
    """
    if bits < 2 or factor_bits < 2:
        raise ValueError('need bits >= 2 and factor_bits >= 2')
    n = 2
    while True:
        room = bits - n.bit_length()
        # A factor_bits-bit multiplier adds at most factor_bits bits, so only
        # multiply while a whole factor still fits. The previous threshold
        # (factor_bits - 1) could overshoot to room == -1 and then crash in
        # randrange(2**-1, ...) with a float argument.
        if room < factor_bits:
            break
        n *= randrange(2**(factor_bits - 1), 2**factor_bits)
    # Here 0 <= room < factor_bits. k = 2**room always yields exactly `bits`
    # bits, so this loop terminates with probability 1.
    while True:
        k = randrange(2**room, 2**(room + 2))
        if (n * k).bit_length() == bits:
            return n * k
def get_smooth_prime():
    """Return a prime p whose p - 1 is a smooth value from get_smooth_n().

    Draws candidates from get_smooth_n() until candidate + 1 passes
    Crypto.Util.number.isPrime.
    """
    while True:
        even_part = get_smooth_n()
        assert even_part % 2 == 0
        candidate = even_part + 1
        if not isPrime(candidate):
            continue
        return candidate
class myLCG:
    """Linear recurrence generator over the integers modulo ``p``.

    ``s`` is a sliding window of states. Each call to next() emits the oldest
    state and appends ``(sum(a[i] * s[i]) + b) % p`` to the end of the window.
    The window is rebuilt rather than mutated, so the caller's list is never
    modified.
    """

    def __init__(self, p, a, b, s):
        self.p, self.a, self.b, self.s = p, a, b, s

    def next(self):
        """Return the oldest state and advance the window by one step."""
        window = self.s
        emitted = window[0]
        feedback = sum(coef * state for coef, state in zip(self.a, window))
        self.s = window[1:] + [(feedback + self.b) % self.p]
        return emitted
def recover_e(p, q, pt, ct):
    """Return an exponent e with pt**e == ct (mod p*q), as a Python int.

    Delegates to Sage's discrete_log. This is presumably tractable here only
    because p - 1 and q - 1 are smooth (see get_smooth_prime).
    """
    modulus = p * q
    target = Mod(ct, modulus)
    base = Mod(pt, modulus)
    exponent = discrete_log(target, base)
    return int(exponent)
def recover_bob_e(p, q, pt, ct, lcg1, bob=None):
    """Recover Bob's stored exponent from one observed encryption.

    Assumes ct = pt**e_used (mod p*q), where e_used = e_bob ^ l_res_2 and
    l_res_2 is the second of the two LCG outputs consumed here. This is the
    relation the debug check below encodes; TODO confirm against the server.
    The discrete log only gives e_used modulo M, the multiplicative order of
    pt, so the representative closest to l_res_2 is chosen.

    Args:
        p, q: Bob's prime factors (must make the discrete log feasible).
        pt: base of the exponentiation.
        ct: pt raised to Bob's effective exponent, mod p*q.
        lcg1: generator exposing next(); advanced by exactly two outputs.
        bob: optional debug object with an ``e`` attribute. When given, the
            reconstructed exponent is asserted against it.

    Returns:
        (e_bob, l_res_1, l_res_2): the recovered stored exponent
        (guess ^ l_res_2) and the two LCG outputs consumed.
    """
    guess_e = recover_e(p, q, pt, ct)
    l_res_1 = lcg1.next()
    l_res_2 = lcg1.next()
    M = Zmod(p*q)(pt).multiplicative_order()
    # Nearest integer to (l_res_2 - guess_e) / M using exact integer
    # arithmetic. True division is only exact while M is a Sage Integer; with
    # plain Python ints it would be float division and lose precision on
    # these multi-hundred-bit values.
    guess_k = (2 * (l_res_2 - guess_e) + M) // (2 * M)
    guess_bob = guess_e + guess_k * M
    if bob is not None:
        bob_used = bob.e ^ l_res_2  # DEBUG
        assert guess_bob == bob_used, f'Missed bob used by {(guess_bob - bob_used)/M}'
    # Return Bob's actual stored e at this moment.
    return guess_bob ^ l_res_2, l_res_1, l_res_2
def _parse_int_list(raw):
    """Parse one server line (bytes) holding a Python list literal such as ``[1, 2, 3]``."""
    # SECURITY: this line comes from the remote service. The previous code
    # passed it to eval(), which executes arbitrary expressions;
    # ast.literal_eval accepts only Python literals.
    import ast

    return ast.literal_eval(raw.decode().strip())


def _recv_lcg(r):
    """Read four LCG parameter lines (p, a, b, s) from ``r`` and build a myLCG."""
    lcg_p = int(r.recvline().decode().strip())
    lcg_a = _parse_int_list(r.recvline())
    lcg_b = int(r.recvline().decode().strip())
    lcg_s = _parse_int_list(r.recvline())
    return myLCG(lcg_p, lcg_a, lcg_b, lcg_s)


def _solve_pow(challenge):
    """Brute-force the 4-character proof-of-work prefix; return it, or None if nothing matches.

    ``challenge`` has the form ``sha256(XXXX + <suffix>) == <hex digest>``. The
    prefix is searched over ASCII letters and digits.
    """
    suffix = challenge.split('+')[1].split(')')[0].strip()
    target = challenge.split('==')[1].strip()
    alphabet = string.ascii_letters + string.digits
    for chars in itertools.product(alphabet, repeat=4):
        prefix = ''.join(chars)
        if sha256((prefix + suffix).encode()).hexdigest() == target:
            return prefix
    return None


def _invert_exponent(e, lb, label):
    """Return e**-1 mod lb, or log a warning and return None when gcd(lb, e) != 1."""
    g = gcd(lb, e)
    if g != 1:
        log.warning(f'Fail: non-invertible {label} (gcd {g})')
        return None
    return inverse_mod(e, lb)


def _exploit(r):
    """Drive one exploit attempt over the already-open connection ``r``.

    Returns 0 after submitting the final guess (the session is then handed to
    r.interactive()). Returns -1 when this attempt cannot succeed and should be
    retried.
    """
    # Proof of work, e.g.
    # sha256(XXXX + DgWRlfIyTStdy2hf) == 43cfc8d954ff5bbde94432fe97f48eb862d8fad1604b46dc729c535faab0c78f
    challenge = r.recvline().decode().strip()
    tic = time.time()
    prefix = _solve_pow(challenge)
    if prefix is None:
        # Previously the last candidate was sent anyway; retry instead.
        log.warning('Fail: proof of work not solved')
        return -1
    log.info(f'Pow done in {round(time.time() - tic, 2)}s')
    r.sendline(prefix.encode())

    # Bob's weak key. Presumably generated offline with get_smooth_prime()
    # (p - 1 and q - 1 smooth), which is what makes the discrete logs in
    # recover_bob_e cheap; TODO confirm.
    p = 13336845925953361803357594752044388161869287904361366396293257057769121911528931798834188209225684333529866846558969839994804929306174382249410660761600001
    q = 12405668035357217661840880178150096566393776028224126375530501642567756061506053596805973121761525579531821248649872698336698685380755700773813485568000001
    n = p * q
    # Register p and q with PARI as known primes, so Sage factorisations that
    # involve them do not have to rediscover them.
    pari.addprimes(p)
    pari.addprimes(q)
    r.recvuntil(b'Give me your RSA key plz.')
    r.sendline(f'{p:0128x}'.encode())
    r.sendline(f'{q:0128x}'.encode())
    assert r.recvline().decode().strip() == ''
    alice_e = int(r.recvline().decode().strip())
    alice_n = int(r.recvline().decode().strip())
    lcg1 = _recv_lcg(r)

    # Round 1: replay Alice's exponent update locally (XOR with one LCG
    # output, truncate to E_BITS, XOR with the next output) to know exactly
    # which value Bob encrypts.
    msg = 0x69
    alice_e_noise = (alice_e ^ lcg1.next()) % (2**E_BITS)
    msg_a = pow(msg, alice_e_noise ^ lcg1.next(), alice_n)
    r.recvuntil(b'pt:')
    r.sendline(f'{msg:0256x}'.encode())
    r.recvuntil(b'ct:')
    ct = int(r.recvline().strip(), 16)
    bob_e, _, _ = recover_bob_e(p, q, msg_a, ct, lcg1)
    # Advance Bob's exponent the same way for his next encryption.
    bob_e = (bob_e ^ lcg1.next()) % (2**E_BITS)
    bob_e = bob_e ^ lcg1.next()
    lb = carmichael_lambda(n)
    bob_d = _invert_exponent(bob_e, lb, 'bob')
    if bob_d is None:
        return -1

    # Leave the loop (pt = 0) and strip Bob's layer off the secret.
    r.recvuntil(b'pt:')
    r.sendline(b'0')
    r.recvuntil(b'secrets_ct:')
    secrets_ct = int(r.recvline().decode().strip(), 16)
    alice_encoded_secret = pow(secrets_ct, bob_d, n)
    log.info(f'{alice_encoded_secret = }')

    # Round 2: the server sends fresh LCG parameters.
    lcg1 = _recv_lcg(r)
    r.recvuntil(b'ct:')
    r.sendline(f'{alice_encoded_secret:0256x}'.encode())
    r.recvuntil(b'pt:')
    pt1 = int(r.recvline().decode().strip(), 16)  # Bob-encoded, Alice-decoded secret
    lcg1.next()  # skip
    res_2 = lcg1.next()  # used by Bob to encode the secret
    new_msg = pow(0x69, alice_e, alice_n)
    r.recvuntil(b'ct:')
    r.sendline(f'{new_msg:0256x}'.encode())
    r.recvuntil(b'pt:')
    # 0x69 under Bob's current exponent (recover_bob_e below treats it so).
    pt2 = int(r.recvline().decode().strip(), 16)
    bob_e, l_res_1, _ = recover_bob_e(p, q, 0x69, pt2, lcg1)
    # Step Bob's exponent back across this round's update (XOR with l_res_1,
    # truncated to E_BITS), then XOR with res_2 to get the exponent he
    # applied to the secret.
    bob_old_e = (int(bob_e) ^ int(l_res_1)) % (2**E_BITS)
    bob_used_old = bob_old_e ^ res_2
    bob_d = _invert_exponent(bob_used_old, lb, 'second bob')
    if bob_d is None:
        return -1
    guess_secret_final = pow(pt1, bob_d, n)
    r.recvuntil(b'ct:')
    r.sendline(b'0')
    r.sendline(hex(guess_secret_final).encode())
    r.interactive()
    return 0


def dotti():
    """Run one full exploit attempt against the 0CTF 2023 service.

    Returns 0 once the final guess has been submitted, or -1 when the attempt
    failed and should be retried. The connection is always closed, including
    when an exception propagates.
    """
    r = remote('eu.chall.ctf.0ops.sjtu.cn', 32226)
    try:
        return _exploit(r)
    finally:
        r.close()
if __name__ == '__main__':
    # dotti() returns -1 for a failed attempt and 0 on completion; keep
    # reconnecting until an attempt completes.
    while dotti() != 0:
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment