Skip to content

Instantly share code, notes, and snippets.

@nneonneo
Created May 1, 2017 00:16
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save nneonneo/367240ae2d8e705bb9173a49a7c8b0cd to your computer and use it in GitHub Desktop.
Save nneonneo/367240ae2d8e705bb9173a49a7c8b0cd to your computer and use it in GitHub Desktop.
Godzilla solver at DEF CON CTF Quals 2017
from Crypto.Util.number import inverse
import random
import multiprocessing
import signal
def mean(x):
    """Return the arithmetic mean of the numbers in ``x`` as a float.

    Raises ZeroDivisionError when ``x`` is empty.
    """
    count = float(len(x))
    return sum(x) / count
class Montgomery:
    """Montgomery modular arithmetic modulo ``n`` with radix ``r = 2**nbits``.

    Besides the arithmetic results, ``montmul`` and ``modpow`` report whether
    (or how often) the conditional final subtraction of ``n`` -- the "extra
    reduction" -- was performed.  That data-dependent step is what the timing
    attack in this file measures.
    """

    def __init__(self, nbits, n):
        """Precompute the Montgomery constants for modulus ``n``.

        Raises ValueError if ``2**nbits <= n`` or if ``n`` is not a positive
        odd integer: Montgomery reduction requires ``r > n`` and
        ``gcd(r, n) == 1``.
        """
        self.nbits = nbits
        self.n = n
        self.r = 1 << nbits
        # Explicit checks rather than `assert`, which `python -O` strips.
        if self.r <= n:
            raise ValueError("2**nbits must be larger than the modulus")
        if n <= 0 or n % 2 == 0:
            raise ValueError("Montgomery modulus must be a positive odd integer")
        self.rmask = self.r - 1
        self.rinv = inverse(self.r, self.n)
        # From r*rinv - n*factor == 1 it follows that factor == -n^-1 (mod r),
        # the constant REDC needs.
        self.factor = (self.r * self.rinv - 1) // self.n
        # Montgomery representation of 1 (r mod n).
        self.m_1 = self.r % n

    def montmul(self, m_a, m_b):
        """Return ``(REDC(m_a * m_b), extra)`` for Montgomery-form operands.

        ``extra`` is 1 if the final conditional subtraction of ``n`` was
        needed and 0 otherwise.  For operands in ``[0, n)`` (as produced by
        ``to_mont`` and by this method) the pre-subtraction value is below
        ``2n``, so a single subtraction suffices.
        """
        prod = m_a * m_b
        tmp = ((prod & self.rmask) * self.factor) & self.rmask
        # prod + tmp*n is divisible by r by construction of `factor`.
        out = (prod + tmp * self.n) >> self.nbits
        if out >= self.n:
            return out - self.n, 1
        return out, 0

    def modpow(self, a, d):
        """Return ``(a**d % n, extra_count)``.

        Uses left-to-right square-and-multiply over exactly ``nbits`` exponent
        bits (bits of ``d`` at position ``nbits`` or above are ignored).
        ``extra_count`` is the total number of extra reductions performed.
        """
        m_a = self.to_mont(a)
        m_x = self.m_1
        ct = 0
        # MSB first; range(...) instead of reversed(xrange(...)) works on
        # both Python 2 and 3.
        for i in range(self.nbits - 1, -1, -1):
            m_x, c = self.montmul(m_x, m_x)
            ct += c
            if d & (1 << i):
                m_x, c = self.montmul(m_x, m_a)
                ct += c
        # Leave Montgomery form: multiply by r^-1 mod n.
        return (m_x * self.rinv) % self.n, ct

    def to_mont(self, a):
        """Return ``a * r mod n``, the Montgomery representation of ``a``."""
        return (a << self.nbits) % self.n
def classify(mont, m_a, pcache, known_d):
    """Simulate the target's exponentiation for one sample and report the
    extra-reduction flags that depend on the next (unknown) exponent bit.

    Parameters:
        mont: Montgomery instance for the target modulus.
        m_a: the sample's base in Montgomery form (``mont.to_mont(a)``).
        pcache: ``(dind, curctr, m_x)`` carried over from the previous call
            for this sample.  It holds the number of prefix bits already
            replayed, the extra reductions counted so far, and the
            Montgomery-form accumulator after those bits.  The initial value
            used by worker_function is ``(0, 0, mont.m_1)``.
        known_d: known exponent prefix as 0/1 values, most significant bit
            first.  Bits before index ``dind`` are not re-read, so this
            assumes the prefix only grows by appending (as guess_d does).

    Returns:
        ``(pcache, curctr, oc, o1c, o2c)``.
        - ``pcache`` is updated to cover all of ``known_d``.
        - ``curctr`` is the extra-reduction count for the known prefix only.
        - ``oc`` is the extra-reduction flag of the multiply by ``m_a`` that
          the target performs if the next bit is 1.
        - ``o1c`` / ``o2c`` are the flags of the square that follows when the
          next bit is 1 / 0, respectively.
    """
    # keep track of how many reductions we've done so far
    # this reduces variance in the measurements
    dind, curctr, m_x = pcache
    # Replay only the bits added since the last call, mirroring
    # Montgomery.modpow: square, then multiply by a when the bit is set.
    # (Leading zero bits of modpow's fixed-width exponent only square m_1,
    # which does not depend on a, so starting at the first known bit loses
    # no per-sample information.)
    for bit in known_d[dind:]:
        m_x, c = mont.montmul(m_x, m_x)
        curctr += c
        if bit:
            m_x, c = mont.montmul(m_x, m_a)
            curctr += c
    pcache = (len(known_d), curctr, m_x)
    # Square for the next bit: performed whatever that bit's value is.
    # Its reduction flag is deliberately not added to curctr.
    m_x, _ = mont.montmul(m_x, m_x)
    # at this point, the next multiplication may or may not happen
    m_y, oc = mont.montmul(m_x, m_a)
    # Square the target performs next if the unknown bit is 1 (after the multiply)...
    _, o1c = mont.montmul(m_y, m_y)
    # ...and the one it performs if the unknown bit is 0 (no multiply).
    _, o2c = mont.montmul(m_x, m_x)
    return pcache, curctr, oc, o1c, o2c
def worker_function(tid, nbits, n, data, qin, qout):
    """Worker-process loop that classifies one chunk of timing samples.

    Each message read from ``qin`` is either an exponent prefix (a sequence of
    0/1 bits, MSB first) or ``None``.
    - For a prefix, every sample ``(a, t)`` in ``data`` is run through
      ``classify``, and ``(tid, [(t, (curctr, oc, o1c, o2c)), ...])`` is put on
      ``qout``.
    - For ``None``, ``(tid, None)`` is put on ``qout`` and the worker returns.

    Per-sample ``classify`` state is kept between messages, so each call only
    replays the bits appended since the previous prefix.
    """
    # Ctrl-C is left to the parent process, which stops workers via None.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    mont = Montgomery(nbits, n)
    samples = [(base, timing, mont.to_mont(base)) for base, timing in data]
    caches = [(0, 0, mont.m_1)] * len(samples)
    while True:
        prefix = qin.get()
        if prefix is None:
            print("tid %d exiting" % tid)
            qout.put((tid, None))
            return
        results = []
        for idx, (_base, timing, m_base) in enumerate(samples):
            state = classify(mont, m_base, caches[idx], prefix)
            caches[idx] = state[0]
            results.append((timing, state[1:]))
        qout.put((tid, results))
class MontgomeryHacker:
    """Recover a private exponent from Montgomery extra-reduction timings.

    ``data`` is a list of ``(a, t)`` samples, each a base ``a`` and its timing
    measurement ``t`` (hack_real builds it with read_csv).

    The constructor splits ``data`` across worker processes running
    ``worker_function``.  Those workers are shut down when ``guess_d`` returns
    or raises, so each instance supports a single ``guess_d`` call; a second
    call would block waiting for results from exited workers.
    """

    def __init__(self, nbits, n, e, data, ncpu=None):
        """Start the worker processes.

        Raises ValueError if ``data`` is empty.  Without samples, every
        statistic computed in ``guess_d`` would be undefined.
        """
        if not data:
            raise ValueError("MontgomeryHacker needs at least one (a, t) sample")
        if ncpu is None:
            ncpu = multiprocessing.cpu_count()
        # Starting more workers than there are samples only adds overhead.
        ncpu = max(1, min(ncpu, len(data)))
        # Not read anywhere in this class; kept for callers that inspect it.
        self.timings = {d[0]: d[1] for d in data}
        self.workers = []
        self.qout = multiprocessing.Queue()
        # Ceiling division: every sample lands in exactly one chunk.
        csize = (len(data) + ncpu - 1) // ncpu
        for i in range(ncpu):
            qin = multiprocessing.Queue()
            chunk = data[i * csize:(i + 1) * csize]
            proc = multiprocessing.Process(
                target=worker_function,
                args=(i, nbits, n, chunk, qin, self.qout))
            proc.start()
            self.workers.append((proc, qin))
        self.n = n
        self.e = e
        self.nbits = nbits

    def check_d(self, d):
        """Test whether ``d`` plus one final bit is a working private exponent.

        ``d`` is a list of 0/1 bits, MSB first.  The last bit is tried as 0,
        then 1.  A candidate ``dd`` passes when ``(2**e)**dd == 2 (mod n)``.
        On success the winning bit is appended to ``d`` in place and True is
        returned; otherwise ``d`` is left unchanged and False is returned.
        """
        # Loop-invariant: encrypt 2 once.
        c = pow(2, self.e, self.n)
        for last in (0, 1):
            dd = int(''.join(map(str, d + [last])), 2)
            if pow(c, dd, self.n) == 2:
                d.append(last)
                return True
        return False

    def guess_d(self, known_d=None):
        """Extend an exponent prefix bit by bit until ``check_d`` accepts it.

        ``known_d`` is a list of 0/1 bits, MSB first, defaulting to ``[1]``.
        The caller's list is not modified.

        Returns the recovered exponent as an int.  Returns None once the
        prefix has grown past ``nbits`` bits without verifying.  The worker
        processes are always shut down before returning, including when an
        exception or Ctrl-C interrupts the search.
        """
        # Private copy: the loop below appends to it.
        known_d = [1] if known_d is None else list(known_d)
        try:
            while not self.check_d(known_d):
                if len(known_d) > self.nbits:
                    return None
                known_d.append(self._guess_next_bit(known_d))
                print(''.join(map(str, known_d)))
        finally:
            self._shutdown_workers()
        return int(''.join(map(str, known_d)), 2)

    def _guess_next_bit(self, known_d):
        """Run one classification round for prefix ``known_d``; return 0 or 1."""
        # In each pair, index 0/1 holds samples whose corresponding operation
        # did not / did need an extra reduction.
        mc = ([], [])   # multiply by a (happens only if the next bit is 1)
        m1c = ([], [])  # following square if the next bit is 1
        m2c = ([], [])  # following square if the next bit is 0
        prefix = tuple(known_d)
        for _proc, qin in self.workers:
            qin.put(prefix)
        for _ in self.workers:
            _tid, res = self.qout.get()
            for t, (curctr, oc, o1c, o2c) in res:
                # Remove reductions already explained by the known prefix.
                residual = t - curctr
                mc[oc].append(residual)
                m1c[o1c].append(residual)
                m2c[o2c].append(residual)
        nbit = len(known_d)
        mult_delta = mean(mc[1]) - mean(mc[0])
        sq1_delta = mean(m1c[1]) - mean(m1c[0])
        sq0_delta = mean(m2c[1]) - mean(m2c[0])
        print("%d mult attack: %s" % (nbit, mult_delta))
        print("%d square attack: %s %s" % (nbit, sq1_delta, sq0_delta))
        print("%d square delta: %s" % (nbit, abs(sq1_delta - sq0_delta)))
        # Only the square the target really performs should correlate with
        # the timings, so the partition with the larger gap picks the bit.
        bit = 1 if sq1_delta > sq0_delta else 0
        print("%d guess %d" % (nbit, bit))
        return bit

    def _shutdown_workers(self):
        """Ask every worker to exit and join each one once it acknowledges."""
        for _proc, qin in self.workers:
            qin.put(None)
        # Results from an interrupted round may still be queued ahead of the
        # (tid, None) exit acknowledgements; discard them.
        remaining = len(self.workers)
        while remaining > 0:
            tid, res = self.qout.get()
            if res is not None:
                continue
            remaining -= 1
            self.workers[tid][0].join()
def read_csv(n, fn):
    """Load timing samples from the CSV file ``fn``.

    Each row must have four integer fields ``a,ts1,cnt,ts2``.  A row yields
    the sample ``(a, cnt // 200)``.  Malformed rows are skipped with a
    message naming the 1-based row number.  A row is malformed when:
    - it does not have exactly 4 fields;
    - any field is not an integer;
    - ``a >= n``;
    - ``cnt`` is not a multiple of 200;
    - ``ts1`` or ``ts2`` is ``>= 2**32``.
    (Presumably ``cnt`` sums 200 repetitions per sample -- confirm against
    the data collector.)

    Returns the list of ``(a, cnt // 200)`` tuples in file order.
    """
    print("loading %s" % fn)
    out = []
    # `with` guarantees the file handle is closed.
    with open(fn, 'r') as f:
        for lineno, row in enumerate(f, 1):
            r = row.split(',')
            if len(r) != 4:
                print("skip row %d: bad field count" % lineno)
                continue
            try:
                # int() tolerates surrounding whitespace, incl. the newline.
                a, ts1, cnt, ts2 = [int(field) for field in r]
            except ValueError:
                print("skip row %d: bad number" % lineno)
                continue
            if a >= n:
                print("skip row %d: bad n" % lineno)
                continue
            if cnt % 200 != 0:
                print("skip row %d: bad cnt" % lineno)
                continue
            if ts1 >= (1 << 32) or ts2 >= (1 << 32):
                print("skip row %d: bad ts" % lineno)
                continue
            out.append((a, cnt // 200))
    return out
def hack_real(data_dir='data/real'):
    """Run the timing attack on every ``*.csv`` file in ``data_dir``.

    The modulus, public exponent and 768-bit Montgomery width are the fixed
    parameters of the challenge target.  The recovered exponent (or None) is
    printed.
    """
    import os
    n = 1003103838556651507628555636330026033778617920156717988356542246694938166737814566792763905093451568623751209393228473104621241127455927948500155303095577513801000908445368656518814002954652859078574695890342113223231421454500402449
    e = 0x10001
    nbits = 768
    data = []
    # os.listdir order is arbitrary; sort so runs are reproducible.
    for fn in sorted(os.listdir(data_dir)):
        if fn.endswith('.csv'):
            data += read_csv(n, os.path.join(data_dir, fn))
    print("got %d rows" % len(data))
    m = MontgomeryHacker(nbits, n, e, data)
    print(m.guess_d())
# Guard the entry point.  Under multiprocessing's "spawn" start method (the
# default on Windows, and on macOS since Python 3.8), each worker re-imports
# this module.  An unguarded call would restart the attack in every child.
if __name__ == '__main__':
    hack_real()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment