-
-
Save nneonneo/367240ae2d8e705bb9173a49a7c8b0cd to your computer and use it in GitHub Desktop.
Godzilla solver at DEF CON CTF Quals 2017
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Crypto.Util.number import inverse | |
import random | |
import multiprocessing | |
import signal | |
def mean(x):
    """Return the arithmetic mean of the numbers in *x* as a float.

    Raises ZeroDivisionError when *x* is empty.
    """
    total = sum(x)
    count = float(len(x))
    return total / count
class Montgomery:
    """Montgomery modular arithmetic modulo an odd *n* with R = 2**nbits.

    Besides the arithmetic result, ``montmul`` and ``modpow`` report whether
    REDC's final conditional subtraction ("extra reduction") happened.
    """

    def __init__(self, nbits, n):
        """Precompute the REDC constants for modulus *n*.

        :param nbits: word size in bits; R = 2**nbits.
        :param n: modulus; must be odd, positive and smaller than R.
        :raises ValueError: if *n* does not satisfy those constraints.
        """
        self.nbits = nbits
        self.n = n
        self.r = 1 << nbits
        # Explicit checks instead of `assert`, which `python -O` strips.
        # REDC also needs gcd(n, R) == 1, i.e. an odd modulus.
        if self.r <= self.n:
            raise ValueError("modulus n must be smaller than 2**nbits")
        if self.n <= 0 or self.n % 2 == 0:
            raise ValueError("modulus n must be a positive odd integer")
        self.rmask = self.r - 1
        self.rinv = inverse(self.r, self.n)
        # r*rinv - 1 is a multiple of n; the quotient n' satisfies
        # n*n' == -1 (mod r), which is the constant REDC multiplies by.
        self.factor = (self.r * self.rinv - 1) // self.n
        # Montgomery representation of 1 (i.e. R mod n).
        self.m_1 = self.r % n

    def montmul(self, m_a, m_b):
        """Montgomery-multiply two values in Montgomery form.

        Inputs are assumed to be reduced (< n).

        :returns: ``(a*b*R**-1 mod n, flag)`` where ``flag`` is 1 if the final
            conditional subtraction of n was needed and 0 otherwise.
        """
        prod = m_a * m_b
        tmp = ((prod & self.rmask) * self.factor) & self.rmask
        # prod + tmp*n is divisible by R by construction of self.factor.
        out = (prod + tmp * self.n) >> self.nbits
        if out >= self.n:
            return out - self.n, 1
        else:
            return out, 0

    def modpow(self, a, d):
        """Compute ``a**d mod n`` by left-to-right square-and-multiply.

        Exactly ``nbits`` bits of *d* are scanned from the top; any higher
        bits are ignored. Every bit costs a squaring, and set bits also cost
        a multiplication.

        :returns: ``(result, count)`` where ``count`` is the total number of
            extra reductions performed.
        """
        m_a = self.to_mont(a)
        m_x = self.m_1
        ct = 0
        for i in reversed(range(self.nbits)):
            m_x, c = self.montmul(m_x, m_x)
            ct += c
            if d & (1 << i):
                m_x, c = self.montmul(m_x, m_a)
                ct += c
        # Leave Montgomery form: multiply by R**-1 mod n.
        return (m_x * self.rinv) % self.n, ct

    def to_mont(self, a):
        """Return the Montgomery form of *a*, i.e. ``a*R mod n``."""
        return (a << self.nbits) % self.n
def classify(mont, m_a, pcache, known_d):
    """Simulate one more exponent bit for a single sample.

    ``pcache`` is ``(bits_done, reductions, m_x)``: the state reached after
    replaying the first ``bits_done`` bits of a previous prefix. Only the
    bits of ``known_d`` past ``bits_done`` are replayed here. This keeps the
    running reduction count exact, which reduces variance in the
    measurements. The cache must come from a call whose prefix is a prefix
    of ``known_d``.

    Returns ``(new_pcache, reductions, oc, o1c, o2c)``. ``reductions`` is the
    extra-reduction count over the known prefix. ``oc`` is the
    extra-reduction flag of the multiply that follows the next squaring if
    the next bit is 1. ``o1c`` and ``o2c`` are the flags of the following
    squaring when the next bit is 1 and 0, respectively.
    """
    start, reductions, acc = pcache
    for bit in known_d[start:]:
        acc, extra = mont.montmul(acc, acc)
        reductions += extra
        if bit:
            acc, extra = mont.montmul(acc, m_a)
            reductions += extra
    new_cache = (len(known_d), reductions, acc)

    # Square for the next (unknown) bit; whether the multiply follows is open.
    squared = mont.montmul(acc, acc)[0]
    multiplied, mult_flag = mont.montmul(squared, m_a)
    square_if_one = mont.montmul(multiplied, multiplied)[1]
    square_if_zero = mont.montmul(squared, squared)[1]
    return new_cache, reductions, mult_flag, square_if_one, square_if_zero
def worker_function(tid, nbits, n, data, qin, qout):
    """Worker-process loop serving classification requests over queues.

    Each message read from ``qin`` is either a known exponent prefix or None.
    For a prefix, every ``(a, t)`` sample in ``data`` is classified, and
    ``(tid, [(t, (curctr, oc, o1c, o2c)), ...])`` is put on ``qout``. On None,
    the worker prints a message, puts ``(tid, None)`` on ``qout`` and returns.
    A per-sample cache carries each sample's simulation state across
    requests.
    """
    # Ignore Ctrl-C here; the parent shuts workers down through the queue.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    mont = Montgomery(nbits, n)
    samples = [(t, mont.to_mont(a)) for (a, t) in data]
    caches = [(0, 0, mont.m_1)] * len(samples)

    while True:
        known_d = qin.get()
        if known_d is None:
            print("tid %d exiting" % tid)
            qout.put((tid, None))
            return
        replies = []
        refreshed = []
        for (t, m_a), cache in zip(samples, caches):
            cache, curctr, oc, o1c, o2c = classify(mont, m_a, cache, known_d)
            refreshed.append(cache)
            replies.append((t, (curctr, oc, o1c, o2c)))
        caches = refreshed
        qout.put((tid, replies))
class MontgomeryHacker:
    """Recover an RSA private exponent from Montgomery extra-reduction timings.

    Samples are split across worker processes running ``worker_function``.
    Each round sends the currently known exponent prefix to every worker,
    buckets the replies by the simulated extra-reduction flags, and picks the
    next bit from the square-attack statistics.

    NOTE(review): ``guess_d`` shuts the workers down when it finishes, so a
    second call would presumably block waiting on exited workers. Treat an
    instance as single-use.
    """

    def __init__(self, nbits, n, e, data, ncpu=None):
        """Spawn worker processes over ``data``.

        :param nbits: Montgomery word size in bits (R = 2**nbits).
        :param n: RSA modulus.
        :param e: RSA public exponent, used only to validate candidates.
        :param data: list of ``(a, t)`` samples. ``a`` is the exponentiated
            input and ``t`` its measured cost (presumably in the same units
            as the simulated reduction count; verify against the data source).
        :param ncpu: number of worker processes; defaults to the CPU count.
        """
        nproc = ncpu if ncpu is not None else multiprocessing.cpu_count()
        self.timings = dict((sample[0], sample[1]) for sample in data)
        self.workers = []
        self.qout = multiprocessing.Queue()
        # Ceiling division so every sample lands in exactly one chunk.
        per_worker = (len(data) + nproc - 1) // nproc
        for tid in range(nproc):
            start = tid * per_worker
            inbox = multiprocessing.Queue()
            proc = multiprocessing.Process(
                target=worker_function,
                args=(tid, nbits, n, data[start:start + per_worker],
                      inbox, self.qout),
            )
            proc.start()
            self.workers.append((proc, inbox))
        self.n = n
        self.e = e
        self.nbits = nbits

    def check_d(self, d):
        """Test whether *d* plus one more bit is a working private exponent.

        Both values of the final bit are brute-forced. A candidate ``dd`` is
        accepted when ``(2**e)**dd == 2 (mod n)``. On success the bit is
        appended to *d* in place and True is returned; otherwise *d* is left
        untouched and False is returned.
        """
        probe = pow(2, self.e, self.n)
        prefix = ''.join(str(bit) for bit in d)
        for last in (0, 1):
            if pow(probe, int(prefix + str(last), 2), self.n) == 2:
                d.append(last)
                return True
        return False

    def guess_d(self, known_d=None):
        """Extend *known_d* bit by bit until it forms a valid private exponent.

        :param known_d: list of already known top bits (default ``[1]``).
            The list is extended in place.
        :returns: the recovered exponent as an int, or None once the prefix
            grows past ``nbits`` bits without validating.

        Worker processes are always shut down before returning or raising.
        """
        if known_d is None:
            known_d = [1]
        try:
            while not self.check_d(known_d):
                if len(known_d) > self.nbits:
                    return None
                known_d.append(self._guess_next_bit(known_d))
                print(''.join(map(str, known_d)))
        finally:
            self._shutdown_workers()
        return int(''.join(map(str, known_d)), 2)

    def _guess_next_bit(self, known_d):
        """Run one classification round across all workers; return 0 or 1."""
        snapshot = tuple(known_d)
        for _, inbox in self.workers:
            inbox.put(snapshot)
        # Buckets are indexed by the simulated extra-reduction flag (0 or 1).
        mult = ([], [])
        sq_if_one = ([], [])
        sq_if_zero = ([], [])
        for _ in range(len(self.workers)):
            _, replies = self.qout.get()
            for t, (curctr, oc, o1c, o2c) in replies:
                # Subtract reductions already explained by the known prefix.
                residual = t - curctr
                mult[oc].append(residual)
                sq_if_one[o1c].append(residual)
                sq_if_zero[o2c].append(residual)
        # All means are computed before anything is printed.
        mult_gap = mean(mult[1]) - mean(mult[0])
        one_gap = mean(sq_if_one[1]) - mean(sq_if_one[0])
        zero_gap = mean(sq_if_zero[1]) - mean(sq_if_zero[0])
        count = len(known_d)
        print("%s mult attack: %s" % (count, mult_gap))
        print("%s square attack: %s %s" % (count, one_gap, zero_gap))
        print("%s square delta: %s" % (count, abs(one_gap - zero_gap)))
        # The hypothesis whose flags separate the timings more strongly wins.
        bit = 1 if one_gap > zero_gap else 0
        print("%s guess %d" % (count, bit))
        return bit

    def _shutdown_workers(self):
        """Ask every worker to exit, then join each as its goodbye arrives.

        Stale classification replies still in flight are discarded.
        """
        for _, inbox in self.workers:
            inbox.put(None)
        remaining = len(self.workers)
        while remaining:
            tid, reply = self.qout.get()
            if reply is None:
                remaining -= 1
                self.workers[tid][0].join()
def read_csv(n, fn):
    """Load timing samples from a CSV file of ``a,ts1,cnt,ts2`` rows.

    Rows are skipped, with a message naming the 1-based row number, when
    any of these hold:

    - the row does not have exactly four fields;
    - a field is not an integer (e.g. a header row);
    - ``a >= n``;
    - ``cnt`` is not a multiple of 200;
    - a timestamp does not fit in 32 bits.

    :param n: modulus; the input ``a`` must be strictly below it.
    :param fn: path of the CSV file.
    :returns: list of ``(a, cnt // 200)`` tuples in file order.
    """
    print("loading %s" % fn)
    out = []
    # `with` closes the file even if parsing raises; the original leaked it.
    with open(fn, 'r') as f:
        for i, row in enumerate(f):
            lineno = i + 1
            fields = row.split(',')
            if len(fields) != 4:
                print("skip row %d: bad field count" % lineno)
                continue
            try:
                # int() tolerates surrounding whitespace such as the newline.
                a, ts1, cnt, ts2 = [int(v) for v in fields]
            except ValueError:
                print("skip row %d: bad int" % lineno)
                continue
            if a >= n:
                print("skip row %d: bad n" % lineno)
                continue
            if cnt % 200 != 0:
                print("skip row %d: bad cnt" % lineno)
                continue
            if ts1 >= (1 << 32) or ts2 >= (1 << 32):
                print("skip row %d: bad ts" % lineno)
                continue
            out.append((a, cnt // 200))
    return out
def hack_real():
    """Run the timing attack against the challenge modulus.

    Loads every ``*.csv`` file in ``data/real`` (in sorted order, so runs are
    reproducible). It then builds a MontgomeryHacker over all samples and
    prints the recovered private exponent, or None if recovery failed.
    """
    import os
    n = 1003103838556651507628555636330026033778617920156717988356542246694938166737814566792763905093451568623751209393228473104621241127455927948500155303095577513801000908445368656518814002954652859078574695890342113223231421454500402449
    e = 0x10001
    nbits = 768
    data_dir = os.path.join('data', 'real')
    data = []
    for fn in sorted(os.listdir(data_dir)):
        if fn.endswith('.csv'):
            data += read_csv(n, os.path.join(data_dir, fn))
    print("got %d rows" % len(data))
    m = MontgomeryHacker(nbits, n, e, data)
    print(m.guess_d())


# Guard the entry point. multiprocessing's spawn/forkserver start methods
# re-import the main module in every worker, which would otherwise re-run the
# whole attack recursively. The guard also makes importing this file
# side-effect free.
if __name__ == "__main__":
    hack_real()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment