Skip to content

Instantly share code, notes, and snippets.

@hellman
Last active October 9, 2019 13:47
Show Gist options
  • Save hellman/524ae5cb00c6c80b68bb7458ccadfa4d to your computer and use it in GitHub Desktop.
Save hellman/524ae5cb00c6c80b68bb7458ccadfa4d to your computer and use it in GitHub Desktop.
Balsn CTF 2019 - unpredictable
import sys
import os
import hashlib
import random
# Announce the interpreter version on a single line.
version = ' '.join(sys.version.split('\n'))
print(f'Python {version}')

# Seed from OS entropy, then leak 0x1337 bounded outputs.
random.seed(os.urandom(1337))
for _ in range(0x1337):
    print(random.randrange(3133731337))

# Encrypt flag: the key is the SHA-512 digest of the decimal text of the
# next 1000 32-bit generator outputs, fed to the hash in order.
sha512 = hashlib.sha512()
for _ in range(1000):
    sha512.update(b'%d' % random.getrandbits(32))
key = sha512.digest()

with open('../flag.txt', 'rb') as flag_file:
    flag = flag_file.read()

# XOR byte by byte; zip() stops at the shorter of flag and key.
enc = bytes(plain ^ pad for plain, pad in zip(flag, key))
print('Encrypted:', enc.hex())
#-*- coding:utf-8 -*-
'''
The idea is to recover the missing outputs (raw values at or above 3133731337)
by finding numbers that satisfy at least two relations of the MT generation scheme.
When no such number can be found, add a new number based on only one relation,
but only when its position is very likely determined (i.e. it lies between two numbers with known relations).
$ py3 wu.py 3 0
... ~1 hour later
$ py3 wu.py 3 1
GO ID 3 LEVEL 1
chk 0
match! 4506
decrypted: Balsn{T0_4cCept_0r_tO_r3jeCt__7hat_1s_tHe_Qu3sT1On}
'''
import random, sys
# https://github.com/kmyk/mersenne-twister-predictor
from mtpred import MT19937Predictor
from mtpred import MATRIX_A, UPPER_MASK, LOWER_MASK, mag01, tempering, untempering, N, M
# Bound used by the challenge's randrange() call; per the module docstring the
# outputs missing from the transcript are the ones that are not below it.
MOD = 3133731337
# Command line: <instance id> <level>.
ID = int(sys.argv[1])
LEVEL = int(sys.argv[2])


def _read_lines(path):
    """Return the lines of the text file at *path* and close it right away."""
    with open(path) as handle:
        return handle.read().splitlines()


# f0 is always the original transcript.  f is the working transcript: the
# original one at level 0, otherwise the file written by an earlier run.
f0 = _read_lines("output.%d.txt" % ID)
if LEVEL == 0:
    # Same file as f0, so reuse its contents instead of reading it again.
    f = list(f0)
else:
    f = _read_lines("output.%d.txt_level%d" % (ID, LEVEL))
print("GO ID", ID, "LEVEL", LEVEL)

# The first and the last line of each file are not numbers and are dropped.
nums0 = [int(line) for line in f0[1:-1]]
# The working list stores untempered values.
nums = [untempering(int(line)) for line in f[1:-1]]
# Reverse index: value -> position in nums (the last position wins on duplicates).
rev = {v: i for i, v in enumerate(nums)}
# Accepted index distances between the members of a relation.
# NOTE(review): presumably ranges rather than exact offsets because an unknown
# number of outputs is missing between any two positions -- confirm.
RM = range(230, M+1)
RN = range(400, N+1)
# Optional ground-truth data used to sanity-check the recovery; `sec` stays
# None unless the run is for ID 4.
sec = None
if ID == 4:
# secret.txt is parsed as whitespace-separated hexadecimal integers.
sec = [int(v, 16) for v in open("secret.txt").read().split()]
# Untemper so the values are directly comparable with the entries of `nums`.
sec = list(map(untempering, sec))
# NOTE(review): indentation was lost in this copy, so it is unclear whether the
# next line belongs to the `if ID == 4` branch.  `skips` is a lazy map over the
# whitespace-separated integers of skips.txt and is not read anywhere in the
# visible code.
skips = map(int, open("skips.txt").read().split())
# Verification / decryption pass, run only for LEVEL >= 1, i.e. on a working
# transcript that an earlier run of this script has written.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments below is inferred from the statement order and should be confirmed.
if LEVEL >= 1:
# Tempered form of every value in the working list, for fast membership tests.
numset = set(tempering(v) for v in nums)
# Slide a window over the working list and try each start position.
for i in range(0, len(nums)-N):
if i % 1000 == 0:
print("chk", i)
predictor = MT19937Predictor()
# Feed 624 consecutive values starting at position i into a fresh predictor.
for j in range(624):
predictor.setrandbits(tempering(nums[i+j]), 32)
# Predict the next 700 outputs; the window is accepted only when all 700 are
# distinct and every one of them occurs in the working list.
test = [predictor.getrandbits(32) for _ in range(700)]
# print(i, off, len(test & numset))
if len(set(test) & numset) >= 700:
# Advance the predictor until it emits the last number of the original
# transcript, giving up after 0x1337 * 3 // 2 outputs.
for xi in range(0x1337 * 3 // 2):
if predictor.getrandbits(32) == nums0[-1]:
import hashlib
print("match!", xi)
# Rebuild the key exactly like the challenge script does: SHA-512 over the
# decimal text of the next 1000 32-bit outputs.
sha512 = hashlib.sha512()
for _ in range(1000):
rnd = predictor.getrandbits(32)
sha512.update(str(rnd).encode('ascii'))
key = sha512.digest()
# Hard-coded ciphertext (hex) that gets XORed with the key.
enc = bytes.fromhex("0630543474e31b91c2b048ba3bf20226f0f55cec646d306233fb0fcc13a1dd6b3320cec24ec7571259fc70e29ebd405cdd61d6e8")
pt = bytes(a ^ b for a, b in zip(enc, key))
print('decrypted:', pt.decode("ascii"))
# Balsn{T0_4cCept_0r_tO_r3jeCt__7hat_1s_tHe_Qu3sT1On}
quit()
# NOTE(review): it cannot be told from this copy which block this second
# quit() closes (the accepted-window branch or the whole LEVEL >= 1 branch).
quit()
# Main recovery loop.  Several insertion techniques are used, ranked by PRIO:
# a lower PRIO means a more reliable technique.  When a pass produces no new
# values, PRIO is raised so that less reliable techniques get applied.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments below is inferred from the statement order.
PRIO = 0
# Length of `nums` after the previous pass, used to detect progress.
prevnums = len(nums)
for itr in range(10000):
# Debug check, active only when `sec` was loaded: every value in `nums` must
# also occur in `sec`; otherwise the offending values are printed and the run
# is aborted.
if sec:
extra = set(nums) - set(sec)
if extra:
print("fail :(")
print(["%08x" % v for v in extra])
quit()
# Collect every triple of positions (i, j, k) with j - i in RM and k - i in RN
# whose values satisfy  nums[k] == nums[j] ^ t,  where t is derived from the
# low bits of nums[i] as (x >> 1) ^ mag01[x & 1], optionally with bit 30
# flipped.  `flag` records whether that flip was needed.
ms = []
for i, out in enumerate(nums[:-1]):
v1 = nums[i] & LOWER_MASK
v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
for off_m in RM:
if i + off_m >= len(nums): break
v2 = nums[i + off_m]
v3 = v2 ^ v1
for flag in range(2):
if flag:
v3 ^= 1 << 30
# `rev` maps a value to its position in `nums`.
if v3 not in rev: continue
off_n = rev[v3] - i
if off_n not in RN: continue
assert off_n in RN
assert off_m in RM
ms.append((i, i+off_m, i+off_n, flag))
# Tuples sort by i first, then j, then k.
ms.sort()
print("nums", len(nums), "matches", len(ms), "ratio", len(ms)/len(nums), "sub", len(nums)-len(ms))
print("PRIO", PRIO)
# A pass that would start at PRIO 0 is run at PRIO 2 instead.
if PRIO == 0:
PRIO = 2
# Positions that already take part in a match as first / second / third member.
used1 = set()
used2 = set()
used3 = set()
for i, j, k, flag in ms:
used1.add(i)
used2.add(j)
used3.add(k)
def interok(i, j, k):
    """Check that the triple (i, j, k) lies between two already-known matches.

    Two approximate binary searches over the global sorted list ``ms`` pick one
    match whose first index is below ``i`` and one whose first index is above
    ``i``.  The triple is accepted only when each of its components is
    bracketed (inclusively) by the corresponding components of those matches.

    The searches stop as soon as their bounds meet, so they do not necessarily
    return the nearest neighbours; when nothing qualifies they fall back to
    the first / last entry of ``ms``.  ``ms`` must not be empty.

    Returns True when the triple is bracketed, False otherwise.
    """
    top = len(ms) - 1

    # Look for an entry whose first index is strictly below i.
    lo, hi, below = 0, top, 0
    while lo < hi:
        probe = (lo + hi) // 2
        first, _, _, _ = ms[probe]
        if first < i:
            below = probe
            lo = probe + 1
        else:
            hi = probe - 1

    # Look for an entry whose first index is strictly above i.
    lo, hi, above = 0, top, top
    while lo < hi:
        probe = (lo + hi) // 2
        first, _, _, _ = ms[probe]
        if first > i:
            above = probe
            hi = probe - 1
        else:
            lo = probe + 1

    lo_i, lo_j, lo_k, _ = ms[below]
    hi_i, hi_j, hi_k, _ = ms[above]
    return (lo_i <= i <= hi_i) and (lo_j <= j <= hi_j) and (lo_k <= k <= hi_k)
def check_useful(i_insert, val, as1=True, as2=True, as3=True, debug=False):
    """Tell whether a candidate value is corroborated by further relations.

    The candidate ``val`` is treated as if it sat at position ``i_insert`` and
    is tried in up to three roles of the relation ``third == second ^ t(first)``:
    as the first member (``as1``), the second (``as2``) and the third (``as3``).
    A relation counts as support when the value it implies is already known
    (present in ``rev``), the index distances fall in ``RM`` / ``RN``, the
    other two positions are not already used in the same role by an existing
    match, and ``interok`` accepts the triple.

    The number of supporting relations required depends on the global ``PRIO``:
    3 at PRIO 0, 2 at PRIO 1, 1 at PRIO 2 and above.

    Returns True as soon as enough support is found, otherwise None.
    With ``debug`` set, every supporting relation is printed.
    """
    # The last rule that applies wins, exactly like a chain of plain ifs.
    for level, need in ((0, 3), (1, 2), (2, 1)):
        if PRIO >= level:
            required = need

    if debug:
        print()

    def mix(word, top_bit):
        # Low bits of `word` combined with an assumed top bit, then the
        # shift-and-conditional-xor step.
        y = (word & LOWER_MASK) | (top_bit << 31)
        return (y >> 1) ^ mag01[y & 0x1]

    size = len(nums)
    supported = 0

    # Candidate as the first member of a relation.
    if as1:
        a = i_insert
        for top_bit in range(2):
            mixed = mix(val, top_bit)
            for gap_m in RM:
                b = a + gap_m
                if b < 0 or b >= size:
                    break
                want = mixed ^ nums[b]
                if want not in rev:
                    continue
                gap_n = rev[want] - a
                c = a + gap_n
                if gap_n not in RN or b in used2 or c in used3:
                    continue
                if not interok(a, b, c):
                    continue
                if debug:
                    print("us1", a, b, c, "deltas", gap_m, gap_n)
                supported += 1
                if supported >= required:
                    return True

    # Candidate as the second member of a relation.
    if as2:
        b = i_insert
        for gap_m in RM:
            a = b - gap_m
            if a < 0 or a >= size:
                break
            for top_bit in range(2):
                want = val ^ mix(nums[a], top_bit)
                if want not in rev:
                    continue
                gap_n = rev[want] - a
                c = a + gap_n
                if gap_n not in RN or a in used1 or c in used3:
                    continue
                if not interok(a, b, c):
                    continue
                if debug:
                    print("us2", a, b, c, "deltas", gap_m, gap_n)
                supported += 1
                if supported >= required:
                    return True

    # Candidate as the third member of a relation.
    if as3:
        c = i_insert
        for gap_n in RN:
            a = c - gap_n
            if a < 0 or a >= size:
                break
            for top_bit in range(2):
                want = val ^ mix(nums[a], top_bit)
                if want not in rev:
                    continue
                gap_m = rev[want] - a
                b = a + gap_m
                if gap_m not in RM or a in used1 or b in used2:
                    continue
                if not interok(a, b, c):
                    continue
                if debug:
                    print("us3", a, b, c, "deltas", gap_m, gap_n)
                supported += 1
                if supported >= required:
                    return True

    return None
# Scan of neighbouring matches, looking for gaps where a value can be inserted.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments below is inferred from the statement order.
# Budget of forced insertions for this pass; PRIO is reset to 0 once it hits 0.
nforce = 1
# Set to 1 right after an insertion so that the following pair is skipped.
found = 0
# Smallest position at which a value has been inserted during this pass.  The
# initial value is larger than any position, so the guards below start inactive.
mininsert = len(nums)+N+10
# Walk consecutive entries of the sorted match list from the end to the start.
for mat1, mat2 in reversed(list(zip(ms, ms[1:]))):
if found:
found = 0
continue
i1, m1, n1, flag1 = mat1
i2, m2, n2, flag2 = mat2
# Consecutive matches are expected to be ordered in all three positions.
assert i1 <= i2, (mat1, mat2)
assert m1 <= m2, (mat1, mat2)
assert n1 <= n2, (mat1, mat2)
# Skip pairs that reach to within N + 10 positions of an earlier insertion.
# NOTE(review): presumably because the positions stored in `ms` are stale from
# the insertion point on -- confirm.
if i1 >= mininsert-N-10: continue
if i2 >= mininsert-N-10: continue
if m1 >= mininsert-N-10: continue
if m2 >= mininsert-N-10: continue
if n1 >= mininsert-N-10: continue
if n2 >= mininsert-N-10: continue
# Distances between the two matches in each of the three positions.
di = i2 - i1
dm = m2 - m1
dn = n2 - n1
# No-op: pairs that are adjacent in all three positions need no special handling.
if not (di == dm == dn == 1):
pass
# Forced insertions (only at PRIO >= 3): exactly one element lies between the
# two matches in two of the three positions and none in the remaining one, so
# the missing member is computed from that single relation and inserted without
# asking check_useful() for corroboration.
# NOTE(review): indentation was lost in this copy; in particular it cannot be
# told here which statements are governed by `if nforce == 0:`.

# Case 1: the third member (k) is missing.
if PRIO >= 3 and di == 2 and dm == 2 and dn == 1:
# Top bit of the element at i1 is used as the top bit of the combined word.
flag = nums[i1] >> 31
i = i1 + 1
j = m1 + 1
v1 = nums[i] & LOWER_MASK
v1 |= flag << 31
v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
v2 = nums[j]
v3 = v1 ^ v2
k = i_insert = n2
# Only values whose tempered form is at least MOD are candidates for insertion.
if tempering(v3) < MOD: continue
if (j - i) not in RM: continue
if (k - i) not in RN: continue
print(mat1)
print(mat2)
print("force insert k", i, j, k, ":", "%08x" % v3)
# PRIO = 0
nforce -= 1
if nforce == 0:
PRIO = 0
# Every element from i_insert on shifts by one: drop the stale reverse-index
# entries, insert, then re-index the shifted tail.
for v in nums[i_insert:]:
del rev[v]
nums.insert(i_insert, v3)
mininsert = min(mininsert, i_insert)
for i in range(i_insert, len(nums)):
assert nums[i] not in rev
rev[nums[i]] = i
continue
# Case 2: the second member (j) is missing.
if PRIO >= 3 and di == 2 and dm == 1 and dn == 2:
flag = nums[i1] >> 31
i = i1 + 1
k = n1 + 1
v1 = nums[i] & LOWER_MASK
v1 |= flag << 31
v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
v3 = nums[k]
v2 = v1 ^ v3
j = i_insert = m2
if tempering(v2) < MOD: continue
if (j - i) not in RM: continue
if (k - i) not in RN: continue
print(mat1)
print(mat2)
print("force insert j", i, j, k, ":", "%08x" % v2)
nforce -= 1
if nforce == 0:
PRIO = 0
for v in nums[i_insert:]:
del rev[v]
nums.insert(i_insert, v2)
mininsert = min(mininsert, i_insert)
for i in range(i_insert, len(nums)):
assert nums[i] not in rev
rev[nums[i]] = i
continue
# Case 3: the first member (i) is missing.
if PRIO >= 3 and di == 1 and dm == 2 and dn == 2:
flag = nums[i1] >> 31
j = m1 + 1
k = n1 + 1
v2 = nums[j]
v3 = nums[k]
v1 = v2 ^ v3
# Undo the (x >> 1) ^ mag01[x & 1] step: a set top bit means the mag01[1]
# constant was xored in, i.e. the low bit of x was 1.
if v1 >> 31:
v1 ^= mag01[1]
v1 <<= 1
v1 |= 1
else:
v1 <<= 1
# The recovered top bit has to agree with the top bit of the element at i1.
if v1 >> 31 != flag: continue
i = i_insert = i2
if tempering(v1) < MOD: continue
if (j - i) not in RM: continue
if (k - i) not in RN: continue
print(mat1)
print(mat2)
print("force insert i", i, j, k, ":", "%08x" % v1)
nforce -= 1
if nforce == 0:
PRIO = 0
for v in nums[i_insert:]:
del rev[v]
nums.insert(i_insert, v1)
mininsert = min(mininsert, i_insert)
for i in range(i_insert, len(nums)):
assert nums[i] not in rev
rev[nums[i]] = i
continue
# The two matches are adjacent in the third position (dn == 1): look for a
# missing third member.  Every first-member candidate strictly between i1 and
# i2 is combined with every second-member candidate strictly between m1 and m2
# and with both values of the assumed top bit.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments is inferred from the statement order.
if dn == 1:
found = 0
for i in range(i1+1, i2):
if found: break
for flag in range(2):
if found: break
v1 = nums[i] & LOWER_MASK
v1 |= flag << 31
v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
for j in range(m1+1, m2):
if found: break
# if i == i1 and j == m1: continue
# if i == i2 and j == m2: continue
v2 = nums[j]
v3 = v1 ^ v2
# Only values whose tempered form is at least MOD are candidates for insertion.
if tempering(v3) < MOD: continue
# The new value would be placed at the position of the later match's third member.
k = i_insert = n2
if (j - i) not in RM: continue
if (k - i) not in RN: continue
assert v3 not in rev
# Insert only when the value, used as a first member, is corroborated by
# further relations.
if check_useful(i_insert, v3, as3=False, as2=False):
print(mat1)
print(mat2)
print("useful new k", i, j, k, "%08x (%08x)" % (v3, tempering(v3)), j-i, k-i)
# Every element from i_insert on shifts by one: drop the stale reverse-index
# entries, insert, then re-index the shifted tail.
for v in nums[i_insert:]:
del rev[v]
nums.insert(i_insert, v3)
mininsert = min(mininsert, i_insert)
for i in range(i_insert, len(nums)):
# Debug aid: print the neighbourhood before the assertion below fails.
if nums[i] in rev:
for di in range(-5, 5):
print(i + di, "%08x" % nums[i+di])
assert nums[i] not in rev
rev[nums[i]] = i
found = 1
break
# After an insertion, move on to the next pair of matches.
if found: continue
# The two matches are adjacent in the first position (di == 1): look for a
# missing first member.  Every third-member candidate strictly between n1 and
# n2 is combined with every second-member candidate strictly between m1 and m2.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments is inferred from the statement order.
if di == 1:
# print()
# print(mat1)
# print(mat2)
found = 0
for k in range(n1+1, n2):
if found: break
v3 = nums[k]
for j in range(m1+1, m2):
# if k == n1 and j == m1: continue
# if k == n2 and j == m2: continue
if found: break
v2 = nums[j]
for flag in range(2):
if found: break
v1 = v3 ^ v2
# Undo the (x >> 1) ^ mag01[x & 1] step: a set top bit means the mag01[1]
# constant was xored in, i.e. the low bit of x was 1.
if v1 >> 31:
v1 ^= mag01[1]
v1 <<= 1
v1 |= 1
else:
v1 <<= 1
# Try both values of the top bit of the reconstructed word.
v1 ^= flag << 31
# Only values whose tempered form is at least MOD are candidates for insertion.
if tempering(v1) < MOD: continue
# The new value would be placed at the position of the later match's first member.
i = i_insert = i2
if (j - i) not in RM: continue
if (k - i) not in RN: continue
assert v1 not in rev
# Insert only when the value, used as a second or third member, is
# corroborated by further relations.
if check_useful(i_insert, v1, as1=False):
print(mat1)
print(mat2)
print("useful new i", i, j, k, "%08x" % v1)
# Every element from i_insert on shifts by one: drop the stale reverse-index
# entries, insert, then re-index the shifted tail.
for v in nums[i_insert:]:
del rev[v]
nums.insert(i_insert, v1)
mininsert = min(mininsert, i_insert)
for i in range(i_insert, len(nums)):
# Debug aid: print the neighbourhood before the assertion below fails.
if nums[i] in rev:
for di in range(-5, 5):
print(i + di, "%08x" % nums[i+di])
assert nums[i] not in rev, (i, "%08x" % nums[i], rev[nums[i]])
rev[nums[i]] = i
found = 1
break
# After an insertion, move on to the next pair of matches.
if found: continue
# The two matches are adjacent in the second position (dm == 1): look for a
# missing second member.  Every third-member candidate strictly between n1 and
# n2 is combined with every first-member candidate strictly between i1 and i2
# and with both values of the assumed top bit.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments is inferred from the statement order.
if dm == 1:
found = 0
for k in range(n1+1, n2):
if found: break
v3 = nums[k]
for i in range(i1+1, i2):
# if k == n1 or i == i1: continue
# if k == n2 or i == i2: continue
if found: break
for flag in range(2):
if found: break
v1 = nums[i]
v1 = v1 & LOWER_MASK
v1 |= flag << 31
v1 = (v1 >> 1) ^ mag01[v1 & 0x1]
v2 = v1 ^ v3
# Only values whose tempered form is at least MOD are candidates for insertion.
if tempering(v2) < MOD: continue
# The new value would be placed at the position of the later match's second member.
j = i_insert = m2
if (j - i) not in RM: continue
if (k - i) not in RN: continue
assert v2 not in rev
# Insert only when the value, used as a first member, is corroborated by
# further relations.
if check_useful(i_insert, v2, as2=False, as3=False):
print(mat1)
print(mat2)
print("useful new j", i, j, k, "%08x (%08x)" % (v2, tempering(v2)))
# Every element from i_insert on shifts by one: drop the stale reverse-index
# entries, insert, then re-index the shifted tail.
for v in nums[i_insert:]:
del rev[v]
nums.insert(i_insert, v2)
mininsert = min(mininsert, i_insert)
for i in range(i_insert, len(nums)):
# Debug aid: print the neighbourhood before the assertion below fails.
if nums[i] in rev:
for di in range(-5, 5):
print(i + di, "%08x" % nums[i+di])
assert nums[i] not in rev, (i, "%08x" % nums[i], rev[nums[i]])
rev[nums[i]] = i
found = 1
break
# After an insertion, move on to the next pair of matches.
if found: continue
# End of one pass: report the size of the working list and adapt PRIO.
# NOTE(review): indentation was lost in this copy; the nesting implied by the
# comments is inferred from the statement order.
print("nums", len(nums))
print()
# No growth since the previous pass: escalate to the next, less reliable
# technique, or stop when PRIO has already reached 3.
if len(nums) == prevnums:
print("nothing new...")
if PRIO >= 3:
break
PRIO += 1
else:
# Progress was made: go back to the most reliable setting.
PRIO = 0
prevnums = len(nums)
# Save the working list for the next level.  The values are written in
# tempered form, one per line, framed by a "first" and a "last" line so that
# the loader's [1:-1] slice keeps exactly the numbers.  The context manager
# closes the file even if a write or tempering() raises.
with open("output.%d.txt_level%d" % (ID, LEVEL+1), "w") as ftest:
    ftest.write("first\n")
    for v in nums:
        ftest.write("%s\n" % str(tempering(v)))
    ftest.write("last")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment