-
-
Save hellman/524ae5cb00c6c80b68bb7458ccadfa4d to your computer and use it in GitHub Desktop.
Balsn CTF 2019 - unpredictable
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys
import os
import hashlib
import random

# Announce the interpreter version on a single line (newlines become spaces).
version = ' '.join(sys.version.split('\n'))
print(f'Python {version}')

# Seed from 1337 bytes of OS entropy, then leak 0x1337 values below 3133731337.
random.seed(os.urandom(1337))
for _ in range(0x1337):
    print(random.randrange(3133731337))

# Encrypt flag: the key is SHA-512 over the decimal strings of the next 1000
# 32-bit outputs. Hashing the concatenation in one call yields the same digest
# as feeding each string to update() in turn.
key = hashlib.sha512(
    ''.join(str(random.getrandbits(32)) for _ in range(1000)).encode('ascii')
).digest()
with open('../flag.txt', 'rb') as flag_file:
    flag = flag_file.read()
# XOR against the 64-byte digest; zip() stops at the shorter of the two.
enc = bytes(p ^ k for p, k in zip(flag, key))
print('Encrypted:', enc.hex())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#-*- coding:utf-8 -*-
'''
The idea is to find the missing outputs (values >= 3133731337, which never show
up among the leaked randrange(3133731337) results) that satisfy at least two
relations of the MT19937 generation scheme.
When no such number can be found, add a new number based on only one relation,
but only when its position is very likely determined (i.e. it lies between two
numbers with known relations).

Usage: wu.py ID LEVEL
  Reads output.ID.txt (plus output.ID.txt_levelLEVEL when LEVEL >= 1) and,
  unless it quits early, writes output.ID.txt_level<LEVEL+1>.

$ py3 wu.py 3 0
... ~1 hour later
$ py3 wu.py 3 1
GO ID 3 LEVEL 1
chk 0
match! 4506
decrypted: Balsn{T0_4cCept_0r_tO_r3jeCt__7hat_1s_tHe_Qu3sT1On}
'''
import random
import sys
# https://github.com/kmyk/mersenne-twister-predictor
from mtpred import MT19937Predictor
from mtpred import MATRIX_A, UPPER_MASK, LOWER_MASK, mag01, tempering, untempering, N, M

# The leaked values are randrange(MOD) results, i.e. all < MOD; reconstructed
# words are only accepted when their tempered value is >= MOD.
MOD = 3133731337
ID = int(sys.argv[1])
LEVEL = int(sys.argv[2])


def _read_text(path):
    """Return the whole contents of the text file at *path*, closing it afterwards."""
    with open(path) as fh:
        return fh.read()


# The original leak is always needed: its last value anchors the key search.
f0 = _read_text("output.%d.txt" % ID).splitlines()
if LEVEL == 0:
    f = f0  # level 0 refines the original leak itself; read it only once
else:
    f = _read_text("output.%d.txt_level%d" % (ID, LEVEL)).splitlines()
print("GO ID", ID, "LEVEL", LEVEL)

# Drop the header and trailer lines; every remaining line is one decimal value.
nums0 = [int(line) for line in f0[1:-1]]
# Untemper into raw MT19937 state words; all further reasoning works on these.
nums = [untempering(int(line)) for line in f[1:-1]]
# Word -> index in nums (for a repeated word the last index wins).
rev = {v: i for i, v in enumerate(nums)}
# Admissible index gaps j - i (RM) and k - i (RN) of a twist relation (i, j, k).
# Missing outputs shorten the observed gaps, hence the wide ranges (bounds
# presumably chosen empirically).
RM = range(230, M+1)
RN = range(400, N+1)

sec = None
if ID == 4:
    # Instance with a known state: ground-truth words given in hex, untempered
    # here. The main loop aborts as soon as it builds a word not among them.
    sec = [untempering(int(v, 16)) for v in _read_text("secret.txt").split()]
    # NOTE(review): `skips` is not used anywhere in this script.
    skips = map(int, _read_text("skips.txt").split())
if LEVEL >= 1:
    import hashlib

    # Tempered (i.e. raw output) values present in the current reconstruction.
    known_outputs = {tempering(word) for word in nums}
    ciphertext = bytes.fromhex("0630543474e31b91c2b048ba3bf20226f0f55cec646d306233fb0fcc13a1dd6b3320cec24ec7571259fc70e29ebd405cdd61d6e8")
    for start in range(len(nums) - N):
        if not start % 1000:
            print("chk", start)
        # Clone the generator from 624 consecutive recovered words.
        predictor = MT19937Predictor()
        for word in nums[start:start + 624]:
            predictor.setrandbits(tempering(word), 32)
        # Accept the window only if its next 700 predictions were all observed.
        predicted = {predictor.getrandbits(32) for _ in range(700)}
        if len(predicted & known_outputs) < 700:
            continue
        # Fast-forward to the last leaked output; the 1000 key words follow it.
        # The scan bound presumably leaves headroom for rejected draws.
        for step in range(0x1337 * 3 // 2):
            if predictor.getrandbits(32) != nums0[-1]:
                continue
            print("match!", step)
            key = hashlib.sha512(
                ''.join(str(predictor.getrandbits(32)) for _ in range(1000)).encode('ascii')
            ).digest()
            plaintext = bytes(c ^ k for c, k in zip(ciphertext, key))
            print('decrypted:', plaintext.decode("ascii"))
            # Balsn{T0_4cCept_0r_tO_r3jeCt__7hat_1s_tHe_Qu3sT1On}
            quit()
        # The window predicted well but never reached the last leaked output.
        quit()
# Refinement strategies are selected by PRIO; a lower PRIO is more reliable.
# When a pass produces no new values, PRIO is raised so that less reliable
# strategies get applied; a pass that adds values resets it to 0.
#   - check_useful() demands MINUSE = 3 / 2 / 1 supporting relations for
#     PRIO 0 / 1 / >= 2.
#   - PRIO >= 3 additionally enables the "force insert" branches.
# Each pass bumps PRIO from 0 to 2 right after collecting relations, so in
# practice a pass starts at PRIO 2 (or 3 after a pass without progress).
PRIO = 0
prevnums = len(nums)  # size after the previous pass, used to detect progress
for itr in range(10000):
    # Debug instance only: abort as soon as a word outside the known secret appears.
    if sec:
        extra = set(nums) - set(sec)
        if extra:
            print("fail :(")
            print(["%08x" % v for v in extra])
            quit()
    # Collect every twist relation among the current words as (i, j, k, flag):
    #   nums[k] == nums[j] ^ twist(y),  twist(y) = (y >> 1) ^ mag01[y & 1],
    #   y = (flag << 31) | (nums[i] & LOWER_MASK),
    # with j - i in RM and k - i in RN. flag is the guessed top bit of y; after
    # the ">> 1" it lands on bit 30, hence the "1 << 30" toggle below.
    ms = []
    for i, out in enumerate(nums[:-1]):
        v1 = nums[i] & LOWER_MASK
        v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
        for off_m in RM:
            if i + off_m >= len(nums): break
            v2 = nums[i + off_m]
            v3 = v2 ^ v1
            for flag in range(2):
                if flag:
                    v3 ^= 1 << 30
                if v3 not in rev: continue
                off_n = rev[v3] - i
                if off_n not in RN: continue
                # Redundant sanity checks (both already enforced above).
                assert off_n in RN
                assert off_m in RM
                ms.append((i, i+off_m, i+off_n, flag))
    ms.sort()
    print("nums", len(nums), "matches", len(ms), "ratio", len(ms)/len(nums), "sub", len(nums)-len(ms))
    print("PRIO", PRIO)
    if PRIO == 0:
        PRIO = 2
    # Indices already acting as the i / j / k member of some relation.
    used1 = set()
    used2 = set()
    used3 = set()
    for i, j, k, flag in ms:
        used1.add(i)
        used2.add(j)
        used3.add(k)
    def interok(i, j, k):
        """Return True if the relation (i, j, k) fits between the relations of
        `ms` that bracket index i, component-wise.

        Binary-searches `ms` (sorted by i) for a relation with ii < i (m1) and
        one with ii > i (m2), then requires i1 <= i <= i2, j1 <= j <= j2 and
        k1 <= k <= k2.
        """
        # NOTE(review): both searches loop `while l < r` and shrink with
        # `r = mid - 1`, so the last candidate index is never examined; m1/m2
        # are not guaranteed to be the nearest bracketing relations (e.g. ii
        # values [0, 5] with i = 10 give m1 = 0). This only loosens or shifts
        # the heuristic filter; bisect would give exact bounds but would change
        # which candidates get accepted. Raises IndexError if `ms` is empty.
        l = 0
        r = len(ms)-1
        last = l
        while l < r:
            mid = (l + r) // 2
            ii,jj,kk,_ = ms[mid]
            if ii < i:
                last = mid
                l = mid + 1
            else:
                r = mid - 1
        m1 = last
        l = 0
        r = len(ms)-1
        last = r
        while l < r:
            mid = (l + r) // 2
            ii,jj,kk,_ = ms[mid]
            if ii > i:
                last = mid
                r = mid - 1
            else:
                l = mid + 1
        m2 = last
        i1, j1, k1, _ = ms[m1]
        i2, j2, k2, _ = ms[m2]
        if not (i1 <= i <= i2): return False
        if not (j1 <= j <= j2): return False
        if not (k1 <= k <= k2): return False
        return True
    def check_useful(i_insert, val, as1=True, as2=True, as3=True, debug=False):
        """Return True if word `val`, imagined at index i_insert, would take part
        in at least MINUSE new relations; otherwise return None (falsy).

        as1 / as2 / as3 enable looking for relations in which `val` is the
        i (lower-bits source) / j (XOR partner) / k (result) member. Relations
        whose other members are already in used1/used2/used3, or that fail
        interok(), are not counted. Indices refer to the current `nums`; `val`
        has not been inserted yet. MINUSE is derived from the global PRIO at
        call time.
        """
        usefulcnt = 0
        if PRIO >= 0: MINUSE = 3
        if PRIO >= 1: MINUSE = 2
        if PRIO >= 2: MINUSE = 1
        if debug: print()
        # check as start: val supplies the lower bits of y (both top-bit guesses)
        if as1:
            i1 = i_insert
            for flag in range(2):
                v1 = val
                v1 = v1 & LOWER_MASK
                v1 |= flag << 31
                v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
                for off_m in RM:
                    i2 = i1 + off_m
                    if not (0 <= i2 < len(nums)): break
                    v2 = nums[i2]
                    v3 = v1 ^ v2
                    if v3 not in rev: continue
                    off_n = rev[v3] - i1
                    if off_n not in RN: continue
                    if (i1+off_m) in used2: continue
                    if (i1+off_n) in used3: continue
                    if not interok(i1, i1+off_m, i1+off_n): continue
                    if debug: print("us1", i1, i1+off_m, i1+off_n, "deltas", off_m, off_n)
                    usefulcnt += 1
                    if usefulcnt >= MINUSE: return True
        # check as mid: val is the XOR partner nums[j]
        if as2:
            i2 = i_insert
            for off_m in RM:
                i1 = i2 - off_m
                if not (0 <= i1 < len(nums)): break
                for flag in range(2):
                    v1 = nums[i1]
                    v1 = v1 & LOWER_MASK
                    v1 |= flag << 31
                    v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
                    v2 = val
                    v3 = v2 ^ v1
                    if v3 not in rev: continue
                    off_n = rev[v3] - i1
                    if off_n not in RN: continue
                    if (i1) in used1: continue
                    if (i1+off_n) in used3: continue
                    if not interok(i1, i1+off_m, i1+off_n): continue
                    if debug: print("us2", i1, i1+off_m, i1+off_n, "deltas", off_m, off_n)
                    usefulcnt += 1
                    if usefulcnt >= MINUSE: return True
        # check as end: val is the result word nums[k]
        if as3:
            i3 = i_insert
            for off_n in RN:
                i1 = i3 - off_n
                if not (0 <= i1 < len(nums)): break
                for flag in range(2):
                    v1 = nums[i1] & LOWER_MASK
                    v1 |= flag << 31
                    v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
                    v3 = val
                    v2 = v3 ^ v1
                    if v2 not in rev: continue
                    off_m = rev[v2] - i1
                    if off_m not in RM: continue
                    if (i1) in used1: continue
                    if (i1+off_m) in used2: continue
                    if not interok(i1, i1+off_m, i1+off_n): continue
                    if debug: print("us3", i1, i1+off_m, i1+off_n, "deltas", off_m, off_n)
                    usefulcnt += 1
                    if usefulcnt >= MINUSE: return True
    # Walk adjacent relation pairs from the highest indices down and try to
    # insert one new word into the gap between them.
    nforce = 1  # force inserts allowed before PRIO drops back to 0 for this pass
    found = 0   # set after an insertion; the next pair (sharing mat1) is skipped
    # Lowest index inserted at during this pass. nums.insert() shifts all later
    # indices while `ms` is not recomputed until the next pass, so pairs with any
    # index at or beyond mininsert-N-10 are skipped as stale.
    mininsert = len(nums)+N+10
    for mat1, mat2 in reversed(list(zip(ms, ms[1:]))):
        if found:
            found = 0
            continue
        i1, m1, n1, flag1 = mat1
        i2, m2, n2, flag2 = mat2
        assert i1 <= i2, (mat1, mat2)
        assert m1 <= m2, (mat1, mat2)
        assert n1 <= n2, (mat1, mat2)
        if i1 >= mininsert-N-10: continue
        if i2 >= mininsert-N-10: continue
        if m1 >= mininsert-N-10: continue
        if m2 >= mininsert-N-10: continue
        if n1 >= mininsert-N-10: continue
        if n2 >= mininsert-N-10: continue
        di = i2 - i1
        dm = m2 - m1
        dn = n2 - n1
        # No-op: kept as a placeholder.
        if not (di == dm == dn == 1):
            pass
        # Force insert of a k word (PRIO >= 3): exactly one i (i1+1) and one j
        # (m1+1) lie between the pair, so the result word is computed from
        # them (top bit of y taken from nums[i1]) and inserted at n2 without
        # check_useful(). It must temper to >= MOD, i.e. be a value that could
        # not appear in the leak.
        if PRIO >= 3 and di == 2 and dm == 2 and dn == 1:
            flag = nums[i1] >> 31
            i = i1 + 1
            j = m1 + 1
            v1 = nums[i] & LOWER_MASK
            v1 |= flag << 31
            v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
            v2 = nums[j]
            v3 = v1 ^ v2
            k = i_insert = n2
            if tempering(v3) < MOD: continue
            if (j - i) not in RM: continue
            if (k - i) not in RN: continue
            print(mat1)
            print(mat2)
            print("force insert k", i, j, k, ":", "%08x" % v3)
            # PRIO = 0
            nforce -= 1
            if nforce == 0:
                # Back to the strictest check_useful() for the rest of the pass.
                PRIO = 0
            # Shift: drop the suffix from rev, insert, then re-index the suffix.
            for v in nums[i_insert:]:
                del rev[v]
            nums.insert(i_insert, v3)
            mininsert = min(mininsert, i_insert)
            for i in range(i_insert, len(nums)):
                assert nums[i] not in rev
                rev[nums[i]] = i
            continue
        # Force insert of a j word (PRIO >= 3): one i (i1+1) and one k (n1+1)
        # lie between the pair; the XOR partner is solved for and inserted at m2.
        if PRIO >= 3 and di == 2 and dm == 1 and dn == 2:
            flag = nums[i1] >> 31
            i = i1 + 1
            k = n1 + 1
            v1 = nums[i] & LOWER_MASK
            v1 |= flag << 31
            v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
            v3 = nums[k]
            v2 = v1 ^ v3
            j = i_insert = m2
            if tempering(v2) < MOD: continue
            if (j - i) not in RM: continue
            if (k - i) not in RN: continue
            print(mat1)
            print(mat2)
            print("force insert j", i, j, k, ":", "%08x" % v2)
            nforce -= 1
            if nforce == 0:
                PRIO = 0
            for v in nums[i_insert:]:
                del rev[v]
            nums.insert(i_insert, v2)
            mininsert = min(mininsert, i_insert)
            for i in range(i_insert, len(nums)):
                assert nums[i] not in rev
                rev[nums[i]] = i
            continue
        # Force insert of an i word (PRIO >= 3): one j (m1+1) and one k (n1+1)
        # lie between the pair; y is recovered by inverting the twist on
        # nums[j] ^ nums[k] (top bit of the XOR reveals y & 1 via mag01[1]).
        # Its top bit must equal that of the preceding word nums[i1]; the word
        # is inserted at i2.
        if PRIO >= 3 and di == 1 and dm == 2 and dn == 2:
            flag = nums[i1] >> 31
            j = m1 + 1
            k = n1 + 1
            v2 = nums[j]
            v3 = nums[k]
            v1 = v2 ^ v3
            if v1 >> 31:
                v1 ^= mag01[1]
                v1 <<= 1
                v1 |= 1
            else:
                v1 <<= 1
            if v1 >> 31 != flag: continue
            i = i_insert = i2
            if tempering(v1) < MOD: continue
            if (j - i) not in RM: continue
            if (k - i) not in RN: continue
            print(mat1)
            print(mat2)
            print("force insert i", i, j, k, ":", "%08x" % v1)
            nforce -= 1
            if nforce == 0:
                PRIO = 0
            for v in nums[i_insert:]:
                del rev[v]
            nums.insert(i_insert, v1)
            mininsert = min(mininsert, i_insert)
            for i in range(i_insert, len(nums)):
                assert nums[i] not in rev
                rev[nums[i]] = i
            continue
        # k members adjacent (presumably a result word is missing between them):
        # try every i and j strictly inside the other gaps, compute the result
        # word, and insert it at n2 if check_useful() finds enough relations in
        # which it acts as an i member.
        if dn == 1:
            found = 0
            for i in range(i1+1, i2):
                if found: break
                for flag in range(2):
                    if found: break
                    v1 = nums[i] & LOWER_MASK
                    v1 |= flag << 31
                    v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
                    for j in range(m1+1, m2):
                        if found: break
                        # if i == i1 and j == m1: continue
                        # if i == i2 and j == m2: continue
                        v2 = nums[j]
                        v3 = v1 ^ v2
                        if tempering(v3) < MOD: continue
                        k = i_insert = n2
                        if (j - i) not in RM: continue
                        if (k - i) not in RN: continue
                        assert v3 not in rev
                        if check_useful(i_insert, v3, as3=False, as2=False):
                            print(mat1)
                            print(mat2)
                            print("useful new k", i, j, k, "%08x (%08x)" % (v3, tempering(v3)), j-i, k-i)
                            for v in nums[i_insert:]:
                                del rev[v]
                            nums.insert(i_insert, v3)
                            mininsert = min(mininsert, i_insert)
                            # NOTE(review): this rebinds the enclosing loop
                            # variable i (and di in the debug loop); harmless
                            # because the loops exit via `found` right after, or
                            # the assert aborts.
                            for i in range(i_insert, len(nums)):
                                if nums[i] in rev:
                                    # Debug aid: show neighbours of a duplicate
                                    # before the assert below fires.
                                    for di in range(-5, 5):
                                        print(i + di, "%08x" % nums[i+di])
                                assert nums[i] not in rev
                                rev[nums[i]] = i
                            found = 1
                            break
            if found: continue
        # i members adjacent: try every k and j strictly inside the other gaps,
        # invert the twist to recover an i word (both top-bit choices), and
        # insert it at i2 if check_useful() finds enough relations in which it
        # acts as a j or k member.
        if di == 1:
            # print()
            # print(mat1)
            # print(mat2)
            found = 0
            for k in range(n1+1, n2):
                if found: break
                v3 = nums[k]
                for j in range(m1+1, m2):
                    # if k == n1 and j == m1: continue
                    # if k == n2 and j == m2: continue
                    if found: break
                    v2 = nums[j]
                    for flag in range(2):
                        if found: break
                        v1 = v3 ^ v2
                        if v1 >> 31:
                            v1 ^= mag01[1]
                            v1 <<= 1
                            v1 |= 1
                        else:
                            v1 <<= 1
                        v1 ^= flag << 31
                        if tempering(v1) < MOD: continue
                        i = i_insert = i2
                        if (j - i) not in RM: continue
                        if (k - i) not in RN: continue
                        assert v1 not in rev
                        if check_useful(i_insert, v1, as1=False):
                            print(mat1)
                            print(mat2)
                            print("useful new i", i, j, k, "%08x" % v1)
                            for v in nums[i_insert:]:
                                del rev[v]
                            nums.insert(i_insert, v1)
                            mininsert = min(mininsert, i_insert)
                            for i in range(i_insert, len(nums)):
                                if nums[i] in rev:
                                    for di in range(-5, 5):
                                        print(i + di, "%08x" % nums[i+di])
                                assert nums[i] not in rev, (i, "%08x" % nums[i], rev[nums[i]])
                                rev[nums[i]] = i
                            found = 1
                            break
            if found: continue
        # j members adjacent: try every k and i strictly inside the other gaps,
        # solve for the XOR partner, and insert it at m2 if check_useful() finds
        # enough relations in which it acts as an i member.
        if dm == 1:
            found = 0
            for k in range(n1+1, n2):
                if found: break
                v3 = nums[k]
                for i in range(i1+1, i2):
                    # if k == n1 or i == i1: continue
                    # if k == n2 or i == i2: continue
                    if found: break
                    for flag in range(2):
                        if found: break
                        v1 = nums[i]
                        v1 = v1 & LOWER_MASK
                        v1 |= flag << 31
                        v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
                        v2 = v1 ^ v3
                        if tempering(v2) < MOD: continue
                        j = i_insert = m2
                        if (j - i) not in RM: continue
                        if (k - i) not in RN: continue
                        assert v2 not in rev
                        if check_useful(i_insert, v2, as2=False, as3=False):
                            print(mat1)
                            print(mat2)
                            print("useful new j", i, j, k, "%08x (%08x)" % (v2, tempering(v2)))
                            for v in nums[i_insert:]:
                                del rev[v]
                            nums.insert(i_insert, v2)
                            mininsert = min(mininsert, i_insert)
                            for i in range(i_insert, len(nums)):
                                if nums[i] in rev:
                                    for di in range(-5, 5):
                                        print(i + di, "%08x" % nums[i+di])
                                assert nums[i] not in rev, (i, "%08x" % nums[i], rev[nums[i]])
                                rev[nums[i]] = i
                            found = 1
                            break
            if found: continue
    print("nums", len(nums))
    print()
    # No progress: escalate to less reliable strategies, or stop once PRIO 3
    # (force inserts) also failed. Any progress resets PRIO.
    if len(nums) == prevnums:
        print("nothing new...")
        if PRIO >= 3:
            break
        PRIO += 1
    else:
        PRIO = 0
    prevnums = len(nums)
# Persist the reconstruction as the next level's input: a "first" header line,
# one tempered value per line, and a "last" trailer (the loader strips the
# first and last lines).
with open("output.%d.txt_level%d" % (ID, LEVEL+1), "w") as level_out:
    level_out.write("first\n")
    level_out.writelines("%s\n" % str(tempering(v)) for v in nums)
    level_out.write("last")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment