Created July 29, 2022 13:25
-
-
Save maple3142/6f4d3a6c19a58c1630e7f1834c4b3651 to your computer and use it in GitHub Desktop.
MT19937 reverser based on https://utaha1228.github.io/ctf-note/2021/11/21/BalsnCTF-2021/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Based on https://utaha1228.github.io/ctf-note/2021/11/21/BalsnCTF-2021/
import random | |
from tqdm import tqdm | |
class MT19937:
    """Pure-Python 32-bit Mersenne Twister (MT19937).

    The class attributes are the standard MT19937 parameters. The symbolic
    ``bitwiseMT19937`` and the ``Untwister`` in this file read them as well.
    """

    (w, n, m, r) = (32, 624, 397, 31)
    a = 0x9908B0DF
    (u, d) = (11, 0xFFFFFFFF)
    (s, b) = (7, 0x9D2C5680)
    (t, c) = (15, 0xEFC60000)
    l = 18
    lowerMask = (1 << r) - 1  # low 31 bits of a word
    mask = (1 << w) - 1  # full 32-bit word
    upperMask = mask ^ lowerMask  # top bit only
    f = 1812433253

    def __init__(self, seed):
        """Seed the 624-word state with the reference ``init_genrand`` recurrence.

        Only the low 32 bits of ``seed`` are used, exactly as in the reference
        implementation, so every state word is a valid 32-bit value.
        """
        self.states = [seed & self.mask]
        # Start "exhausted" so the first rand() call twists the state first.
        self.index = MT19937.n
        for i in range(1, self.n):
            prev = self.states[i - 1]
            self.states.append(
                self.mask & (i + self.f * (prev ^ (prev >> (self.w - 2))))
            )

    def temper(self, num):
        """Apply the MT19937 output tempering transform to a 32-bit word."""
        num = num ^ ((num >> MT19937.u) & MT19937.d)
        num = num ^ ((num << MT19937.s) & MT19937.b)
        num = num ^ ((num << MT19937.t) & MT19937.c)
        num = num ^ (num >> MT19937.l)
        return num

    def rand(self):
        """Return the next tempered 32-bit output, twisting when exhausted."""
        if self.index >= MT19937.n:
            self.twist()
        y = self.states[self.index]
        self.index += 1
        return self.temper(y)

    def twist(self):
        """Regenerate all 624 state words in place and reset the read index.

        Updating in place with modular indexing means late iterations read
        words that were already regenerated, as the reference algorithm does.
        """
        for i in range(MT19937.n):
            x = (self.states[i] & MT19937.upperMask) ^ (
                self.states[(i + 1) % MT19937.n] & MT19937.lowerMask
            )
            xA = x >> 1
            if x & 1:
                xA = xA ^ self.a
            self.states[i] = self.states[(i + MT19937.m) % MT19937.n] ^ xA
        self.index = 0
class bitwiseMT19937:
    """Symbolic MT19937 over GF(2).

    Every bit of the initial state is an independent unknown. Bit ``j`` of
    word ``i`` is the variable ``1 << (32 * i + j)``. A symbolic word is a list
    of 32 ints indexed LSB-first. Each entry is the XOR-mask of the unknowns
    whose parity equals that bit. The methods mirror ``MT19937``, with XOR of
    masks standing in for XOR of concrete bits.
    """

    def __init__(self):
        word_count = MT19937.n
        self.states = [
            [1 << (32 * word + bit) for bit in range(32)]
            for word in range(word_count)
        ]
        # Start exhausted so that the first rand() twists, like MT19937.
        self.index = word_count

    @staticmethod
    def _shr_xor(bits, shift, mask):
        """In place: bits ^= (bits >> shift) & mask (symbolically).

        Walk upward so each source bits[pos + shift] is still unmodified.
        """
        for pos in range(32 - shift):
            if (mask >> pos) & 1:
                bits[pos] ^= bits[pos + shift]

    @staticmethod
    def _shl_xor(bits, shift, mask):
        """In place: bits ^= (bits << shift) & mask (symbolically).

        Walk downward so each source bits[pos - shift] is still unmodified.
        """
        for pos in range(31, shift - 1, -1):
            if (mask >> pos) & 1:
                bits[pos] ^= bits[pos - shift]

    def temper(self, num):
        """Return a tempered copy of the symbolic word ``num``."""
        out = list(num)
        self._shr_xor(out, MT19937.u, MT19937.d)
        self._shl_xor(out, MT19937.s, MT19937.b)
        self._shl_xor(out, MT19937.t, MT19937.c)
        self._shr_xor(out, MT19937.l, (1 << 32) - 1)
        return out

    def rand(self):
        """Return the symbolic tempered next word, twisting when exhausted."""
        if not self.index < MT19937.n:
            self.twist()
        word = self.states[self.index]
        self.index += 1
        return self.temper(word)

    def twist(self):
        """Symbolically regenerate every state word in place."""
        n, m = MT19937.n, MT19937.m
        taps = [pos for pos in range(32) if (MT19937.a >> pos) & 1]
        for i in range(n):
            cur = self.states[i]
            # Bits 0..30 come from the next word; bit 31 comes from this word.
            combined = self.states[(i + 1) % n][:31] + [cur[31]]
            low_bit = combined[0]
            shifted = combined[1:] + [0]
            # Conditional "^ a" becomes XOR of the LSB mask into a's set bits.
            for pos in taps:
                shifted[pos] ^= low_bit
            source = self.states[(i + m) % n]
            cur[:] = [src ^ sh for src, sh in zip(source, shifted)]
        self.index = 0
def count(x, nbits=20000):
    """Return the parity (0 or 1) of the lowest ``nbits`` bits of ``x``.

    The default of 20000 bits covers the 19968-variable masks used in this
    file. Negative ``x`` is read as its two's-complement bit pattern, just
    like testing ``(x >> i) & 1`` for each ``i`` below ``nbits``.
    """
    if nbits <= 0:
        return 0
    # Masking to nbits reproduces the bit-by-bit scan in one C-level popcount.
    low = x & ((1 << nbits) - 1)
    return bin(low).count("1") & 1
# One GF(2) unknown per bit of the 624-word, 32-bit state (= 19968).
TOTAL_BITS = 624 * 32
class Untwister:
    """Recover an MT19937 state from (partially) known consecutive outputs.

    Each ``submit`` call describes one 32-bit output of the target generator.
    Every known bit becomes a linear equation over GF(2) in the ``TOTAL_BITS``
    unknown state bits, which ``bitwiseMT19937`` tracks symbolically.
    ``get_random`` solves the system and returns a ``random.Random`` loaded
    with the recovered state.
    """

    def __init__(self):
        # linear_base[k] is (-1, -1) when no equation has pivot k yet, or
        # (mask, value): XOR of the state bits in ``mask`` equals ``value``,
        # and the highest set bit of ``mask`` is k.
        self.linear_base = [(-1, -1) for _ in range(TOTAL_BITS)]
        self.total = 0  # number of independent equations (pivots) so far
        self.bitRng = bitwiseMT19937()
        self.num_sub = 0  # number of outputs submitted so far
        self.solved = False

    def add(self, bits, output):
        """Insert the equation ``parity(state & bits) == output``.

        The equation is reduced against the existing pivots. It becomes a new
        pivot if something remains. It is dropped when it reduces to
        ``0 == 0`` (already implied).

        Raises:
            ValueError: if it reduces to ``0 == 1``, i.e. it contradicts the
                earlier equations, so the submitted outputs cannot come from
                one MT19937 stream.
        """
        while bits:
            idx = bits.bit_length() - 1
            pivot = self.linear_base[idx]
            if pivot == (-1, -1):
                self.linear_base[idx] = (bits, output)
                self.total += 1
                return
            bits ^= pivot[0]
            output ^= pivot[1]
        if output:
            raise ValueError(
                "Inconsistent equation: submitted outputs contradict each other"
            )

    def submit(self, bits):
        """Record one 32-bit output of the target generator.

        Args:
            bits: a 32-character string, MSB first. Each character is "0",
                "1", or "?" (unknown bit, which contributes no equation).

        Raises:
            Exception: if called after solving, or if ``bits`` is malformed.
            ValueError: if a known bit contradicts earlier submissions.
        """
        if self.solved:
            raise Exception("Must not submit after solved")
        if len(bits) != 32:
            raise Exception("Invalid bits length")
        if not all(ch in ("0", "1", "?") for ch in bits):
            raise Exception("Invalid bits")
        # Advance the symbolic generator even if every bit is unknown, so later
        # submissions stay aligned with the real output stream.
        symbits = self.bitRng.rand()
        for pos, ch in enumerate(bits):
            if ch != "?":
                # bits[0] is the MSB, whereas symbits is indexed LSB-first.
                self.add(symbits[31 - pos], int(ch))
        self.num_sub += 1

    def solve(self):
        """Reduce the equations so that each state bit has a fixed value.

        Bits without a pivot equation are free and are set to 0. Afterwards
        ``linear_base[i] == (1 << i, value_of_bit_i)`` for every ``i``.
        """
        # MT19937.twist reads only the top bit of word 0 before overwriting
        # it, so no equation can involve bits 0..30 of the initial state.
        low31 = (1 << 31) - 1
        for row in self.linear_base:
            if row != (-1, -1):
                assert (row[0] & low31) == 0
        for i in tqdm(range(TOTAL_BITS)):
            if self.linear_base[i] == (-1, -1):
                self.linear_base[i] = (1 << i, 0)  # free variable -> 0
                continue
            mask, value = self.linear_base[i]
            rest = mask ^ (1 << i)
            # Rows below i are already reduced to (1 << idx, bit_idx), so each
            # XOR clears exactly bit idx and introduces nothing new.
            while rest:
                idx = rest.bit_length() - 1
                lower_mask, lower_value = self.linear_base[idx]
                mask ^= lower_mask
                value ^= lower_value
                rest ^= 1 << idx
            self.linear_base[i] = (mask, value)
        for i in range(TOTAL_BITS):
            assert self.linear_base[i][0] == (1 << i)

    def get_random(self, skip_states=True):
        """Return a ``random.Random`` that reproduces the target generator.

        The system is solved on the first call. Free bits are set to 0, so if
        too few independent equations were collected (see ``self.total``), the
        recovered state may not match the target.

        Args:
            skip_states: when true, advance the returned generator past the
                ``num_sub`` outputs already submitted. Its next output is then
                the target's next unseen output.
        """
        if not self.solved:
            self.solve()
            self.solved = True
        stateLong = sum((1 << i) * self.linear_base[i][1] for i in range(TOTAL_BITS))
        word_mask = (1 << 32) - 1
        states = [(stateLong >> (32 * k)) & word_mask for k in range(MT19937.n)]
        pyr = random.Random()
        # Index 624 makes Python twist before its first output, matching the
        # symbolic generator, which also starts exhausted.
        pyr.setstate((3, tuple(states) + (624,), None))
        if skip_states:
            for _ in range(self.num_sub):
                pyr.getrandbits(32)
        return pyr
if __name__ == "__main__":
    # Demo: leak only the top byte of each 32-bit output, recover the full
    # state, then check that the clone predicts the next 128 outputs.
    reference = random.Random(1234)
    untwister = Untwister()
    leaks_per_word = 32 // 8
    for _ in range(624 * leaks_per_word):
        leaked = reference.getrandbits(8)
        untwister.submit(f"{leaked:08b}" + "?" * 24)
    clone = untwister.get_random()
    for _ in range(128):
        assert clone.getrandbits(32) == reference.getrandbits(32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.