Skip to content

Instantly share code, notes, and snippets.

@maple3142
Created July 29, 2022 13:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maple3142/6f4d3a6c19a58c1630e7f1834c4b3651 to your computer and use it in GitHub Desktop.
Save maple3142/6f4d3a6c19a58c1630e7f1834c4b3651 to your computer and use it in GitHub Desktop.
# Based on https://utaha1228.github.io/ctf-note/2021/11/21/BalsnCTF-2021/
import random
from tqdm import tqdm
class MT19937:
    """Pure-Python 32-bit Mersenne Twister (MT19937).

    The class attributes are the standard MT19937 parameters (word size,
    state size, twist offset, tempering shifts and masks, seeding
    multiplier).
    """

    (w, n, m, r) = (32, 624, 397, 31)
    a = 0x9908B0DF
    (u, d) = (11, 0xFFFFFFFF)
    (s, b) = (7, 0x9D2C5680)
    (t, c) = (15, 0xEFC60000)
    l = 18
    lowerMask = (1 << r) - 1
    mask = (1 << w) - 1
    upperMask = mask ^ lowerMask
    f = 1812433253

    def __init__(self, seed):
        """Initialise the 624-word state from the integer *seed*.

        Uses the ``init_genrand`` recurrence
        ``s[i] = (f * (s[i-1] ^ (s[i-1] >> 30)) + i) mod 2**32``.
        *seed* is first reduced to its low 32 bits, so every state word,
        including ``states[0]``, is a valid 32-bit value.
        """
        self.states = [seed & self.mask]
        for i in range(1, self.n):
            prev = self.states[i - 1]
            self.states.append(
                self.mask & (i + self.f * (prev ^ (prev >> (self.w - 2))))
            )
        # index == n means "state exhausted": the first rand() call twists.
        self.index = MT19937.n

    def temper(self, num):
        """Return the tempered 32-bit output for the raw state word *num*."""
        num = num ^ ((num >> MT19937.u) & MT19937.d)
        num = num ^ ((num << MT19937.s) & MT19937.b)
        num = num ^ ((num << MT19937.t) & MT19937.c)
        num = num ^ (num >> MT19937.l)
        return num

    def rand(self):
        """Return the next 32-bit output, twisting when the state is used up."""
        if self.index >= MT19937.n:
            self.twist()
        y = self.states[self.index]
        self.index += 1
        return self.temper(y)

    def twist(self):
        """Regenerate all n state words in place and reset the index."""
        for i in range(MT19937.n):
            # Top bit of word i joined with the low 31 bits of word i + 1.
            # The masks are disjoint, so XOR acts as OR here.
            x = (self.states[i] & MT19937.upperMask) ^ (
                self.states[(i + 1) % MT19937.n] & MT19937.lowerMask
            )
            xA = x >> 1
            if x & 1:
                xA = xA ^ MT19937.a
            # In-place update: later i read words already replaced in this
            # pass, exactly as the sequential reference algorithm does.
            self.states[i] = self.states[(i + MT19937.m) % MT19937.n] ^ xA
        self.index = 0
class bitwiseMT19937:
    """Symbolic MT19937 whose state bits are GF(2) linear forms.

    A state word is a list of 32 Python ints, where entry ``k`` describes bit
    ``k`` of the word (``k == 0`` is the least significant bit). Each int is a
    bitmask over the ``MT19937.n * 32`` unknown initial state bits. Bit
    ``32 * word + k`` of the mask stands for bit ``k`` of initial word
    ``word``, and a set bit means that unknown is XORed into this position.
    Tempering and twisting are expressed with shifts, masks and XOR, so
    ``rand()`` yields, for every output bit, the set of initial bits it is
    the XOR of.
    """

    def __init__(self):
        self.states = []
        for word in range(MT19937.n):
            self.states.append([1 << (32 * word + k) for k in range(32)])
        self.index = MT19937.n

    def temper(self, num):
        """Return a tempered copy of the symbolic word *num* (input untouched)."""
        out = list(num)
        # y ^= (y >> u) & d: ascending, so out[k + u] still holds the input bit.
        for k in range(32 - MT19937.u):
            if (MT19937.d >> k) & 1:
                out[k] ^= out[k + MT19937.u]
        # y ^= (y << s) & b: descending, so out[k - s] is not yet updated.
        for k in range(31, MT19937.s - 1, -1):
            if (MT19937.b >> k) & 1:
                out[k] ^= out[k - MT19937.s]
        # y ^= (y << t) & c
        for k in range(31, MT19937.t - 1, -1):
            if (MT19937.c >> k) & 1:
                out[k] ^= out[k - MT19937.t]
        # y ^= y >> l
        for k in range(32 - MT19937.l):
            out[k] ^= out[k + MT19937.l]
        return out

    def rand(self):
        """Return the symbolic (per-bit) form of the next 32-bit output."""
        if not self.index < MT19937.n:
            self.twist()
        pos = self.index
        self.index = pos + 1
        return self.temper(self.states[pos])

    def twist(self):
        """Apply the MT19937 twist symbolically, rewriting each word in place."""
        n, m = MT19937.n, MT19937.m
        for i in range(n):
            # Low 31 bits from word i + 1, top bit from word i.
            x = self.states[(i + 1) % n][:-1] + self.states[i][-1:]
            lo_bit = x[0]
            # x >> 1: bit k takes bit k + 1 and the top bit becomes zero.
            x_a = x[1:] + [0]
            # "if x & 1: xA ^= a" becomes an XOR of the low bit into a's set bits.
            for k in range(32):
                if (MT19937.a >> k) & 1:
                    x_a[k] ^= lo_bit
            partner = self.states[(i + m) % n]
            # Slice assignment keeps the same list object, as the reference
            # implementation mutates the word in place.
            self.states[i][:] = [partner[k] ^ x_a[k] for k in range(32)]
        self.index = 0
def count(x, nbits=20000):
    """Return the parity (XOR of all bits) of the lowest *nbits* bits of *x*.

    Args:
        x: any int. Negative values are read in two's complement, so each of
            the *nbits* low bits is counted as ``(x >> i) & 1``.
        nbits: number of low bits to include (default 20000). Values <= 0
            give parity 0.

    Returns:
        0 or 1.
    """
    if nbits <= 0:
        return 0
    # Masking selects exactly bits 0..nbits-1 (two's complement for negative
    # x). Counting ones in C avoids nbits big-int shifts, which cost
    # O(nbits * size(x)).
    return bin(x & ((1 << nbits) - 1)).count("1") & 1
TOTAL_BITS = 624 * 32  # one unknown per bit of the 624-word, 32-bit MT19937 state
class Untwister:
    """Recover an MT19937 state from (partially) known 32-bit outputs.

    Each ``submit`` call consumes the next symbolic output of a
    ``bitwiseMT19937`` and records one GF(2) linear equation per known bit.
    The equations are kept in echelon form in ``linear_base``:

    * entry ``i`` is ``(-1, -1)`` when no stored equation has its highest
      variable at bit ``i``;
    * otherwise it is ``(bits, output)``, where ``bits`` is the variable mask
      (highest set bit ``i``) and ``output`` is the XOR of those variables
      (0 or 1).
    """

    def __init__(self):
        self.linear_base = [(-1, -1) for _ in range(TOTAL_BITS)]
        self.total = 0  # number of independent equations stored (rank)
        self.bitRng = bitwiseMT19937()
        self.num_sub = 0  # number of 32-bit outputs submitted so far
        self.solved = False

    def add(self, bits, output):
        """Insert the equation ``XOR(vars in bits) == output`` into the basis.

        The equation is reduced against existing pivots. If it reduces to
        zero it was linearly dependent and is dropped. NOTE: if the reduced
        output is 1 at that point, the equation contradicts earlier ones
        (the data did not come from this generator); it is dropped silently
        as well.
        """
        while bits:
            idx = bits.bit_length() - 1
            if self.linear_base[idx] == (-1, -1):
                self.linear_base[idx] = (bits, output)
                self.total += 1
                return
            bits ^= self.linear_base[idx][0]
            output ^= self.linear_base[idx][1]

    def submit(self, bits):
        """Record one 32-bit output given as a bit string, MSB first.

        Each character is "0", "1", or "?" for an unknown bit. Every call
        advances the symbolic generator by one output, so an output with no
        known bits should still be submitted (as ``"?" * 32``) to keep later
        submissions aligned.

        Raises:
            Exception: if called after solving, or if *bits* is malformed.
        """
        if self.solved:
            raise Exception("Must not submit after solved")
        if len(bits) != 32:
            raise Exception("Invalid bits length")
        if not all(x in ("0", "1", "?") for x in bits):
            raise Exception("Invalid bits")
        symbits = self.bitRng.rand()
        # bits[0] is the MSB, while symbits[k] describes bit k (LSB = 0).
        for i, ch in enumerate(bits):
            if ch != "?":
                self.add(symbits[31 - i], int(ch))
        self.num_sub += 1

    def solve(self):
        """Back-substitute the basis into one concrete solution (idempotent).

        Afterwards ``linear_base[i] == (1 << i, value_i)`` for every bit.
        Variables without a pivot are free and are set to 0. The instance is
        then marked solved, so further ``submit`` calls are rejected and
        repeated ``solve`` calls are no-ops.
        """
        if self.solved:
            return
        size = len(self.linear_base)
        low31 = (1 << 31) - 1
        # Sanity check: the low 31 bits of initial word 0 (variables 0..30)
        # must not occur in any equation. The symbolic twist reads only the
        # top bit of word 0 before overwriting it.
        for row in self.linear_base:
            if row != (-1, -1):
                assert (row[0] & low31) == 0
        for i in tqdm(range(size)):
            if self.linear_base[i] == (-1, -1):
                self.linear_base[i] = (1 << i, 0)  # free variable -> 0
                continue
            row_bits, row_out = self.linear_base[i]
            rest = row_bits ^ (1 << i)  # non-pivot variables, all below i
            while rest:
                idx = rest.bit_length() - 1
                # Rows below i are already reduced to (1 << idx, value).
                row_bits ^= self.linear_base[idx][0]
                row_out ^= self.linear_base[idx][1]
                rest ^= 1 << idx
            self.linear_base[i] = (row_bits, row_out)
        for i in range(size):
            assert self.linear_base[i][0] == (1 << i)
        self.solved = True

    def get_random(self, skip_states=True):
        """Return a ``random.Random`` loaded with the recovered state.

        Args:
            skip_states: if true (the default), advance the generator past the
                ``num_sub`` submitted outputs, so its next output is the one
                following the last submission.

        The recovered ``states`` words could equally be loaded into an
        ``MT19937`` instance (``rng.states = states[:]``) and advanced by
        calling ``rng.rand()`` ``num_sub`` times.
        """
        self.solve()
        values = [row[1] for row in self.linear_base]
        # Variable 32 * k + j is bit j of state word k.
        states = [
            sum(values[32 * k + j] << j for j in range(32))
            for k in range(MT19937.n)
        ]
        pyr = random.Random()
        # Position n (624) marks the state as exhausted, matching
        # bitwiseMT19937, which starts with index == n and twists first.
        pyr.setstate((3, tuple(states) + (MT19937.n,), None))
        if skip_states:
            for _ in range(self.num_sub):
                pyr.getrandbits(32)
        return pyr
if __name__ == "__main__":
    # Self-test: leak only 8 bits per output of a seeded random.Random,
    # submitted as the top bits of each 32-bit word. Then recover the
    # state and check that the clone predicts the victim's future outputs.
    victim = random.Random(1234)
    solver = Untwister()
    leaks = 624 * (32 // 8)  # 2496 leaks x 8 known bits = 19968 equations
    for _ in range(leaks):
        top8 = victim.getrandbits(8)
        solver.submit(format(top8, "08b") + "?" * 24)
    clone = solver.get_random()
    for _ in range(128):
        assert clone.getrandbits(32) == victim.getrandbits(32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment