Skip to content

Instantly share code, notes, and snippets.

@y011d4
Created November 9, 2022 01:46
Show Gist options
  • Save y011d4/0b79a71d0a15f6796ea7fc3759274e01 to your computer and use it in GitHub Desktop.
Save y011d4/0b79a71d0a15f6796ea7fc3759274e01 to your computer and use it in GitHub Desktop.
N1CTF brand_new_checkin
import random
from Crypto.Util.number import bytes_to_long, long_to_bytes
from z3 import *
# Mersenne Twister (MT19937) parameters; they mirror the local constants used
# inside update_mt() below.
N = 624  # number of 32-bit words in the generator state
M = 397  # offset of the state word mixed in while regenerating the state
MATRIX_A = 0x9908B0DF  # XORed in when the combined word's lowest bit is set
UPPER_MASK = 0x80000000  # most significant bit of a 32-bit word
LOWER_MASK = 0x7FFFFFFF  # least significant 31 bits of a 32-bit word
def bit_shift_right_xor_rev(x, shift):
    """Invert the 32-bit map ``y -> y ^ (y >> shift)``.

    Works on plain ``int`` values; any other operand (e.g. a z3 bit-vector
    expression) is shifted with z3's logical right shift ``LShR``.

    Args:
        x: Result of the forward map, a 32-bit value.
        shift: Positive shift distance used by the forward map.

    Returns:
        The value ``y`` for which ``y ^ (y >> shift) == x``.

    Raises:
        ValueError: If ``shift`` is not positive (a zero shift would never
            let the recovery loop terminate).
    """
    if shift <= 0:
        raise ValueError("shift must be a positive integer")
    y = x
    # The top `shift` bits of x already equal those of y; every pass makes
    # another `shift` bits below them correct.
    recovered_bits = shift
    while recovered_bits < 32:
        z = y >> shift if isinstance(y, int) else LShR(y, shift)
        y = x ^ z
        recovered_bits += shift
    return y
def bit_shift_left_xor_rev(x, shift, mask):
    """Invert the 32-bit map ``y -> y ^ ((y << shift) & mask)``.

    Args:
        x: Result of the forward map, a 32-bit value.
        shift: Positive shift distance used by the forward map.
        mask: 32-bit mask applied to the shifted value by the forward map.

    Returns:
        The value ``y`` for which ``y ^ ((y << shift) & mask) == x``.

    Raises:
        ValueError: If ``shift`` is not positive (a zero shift would never
            let the recovery loop terminate).
    """
    if shift <= 0:
        raise ValueError("shift must be a positive integer")
    y = x
    # The low `shift` bits of x already equal those of y; every pass makes
    # another `shift` bits above them correct.
    recovered_bits = shift
    while recovered_bits < 32:
        z = y << shift
        y = x ^ (z & mask)
        recovered_bits += shift
    return y
def untemper(x):
    """Undo the four shift/XOR tempering steps applied to a 32-bit output word.

    The steps are inverted in the reverse of the order listed here:
    ``>> 11``, ``<< 7 & 0x9D2C5680``, ``<< 15 & 0xEFC60000``, ``>> 18``
    (the MT19937 tempering constants).

    Args:
        x: Tempered 32-bit value (an ``int``, or a z3 bit-vector expression
            as accepted by the two helper functions).

    Returns:
        The untempered 32-bit word.
    """
    x = bit_shift_right_xor_rev(x, 18)
    x = bit_shift_left_xor_rev(x, 15, 0xEFC60000)
    x = bit_shift_left_xor_rev(x, 7, 0x9D2C5680)
    x = bit_shift_right_xor_rev(x, 11)
    return x
def update_mt(mt):
    """Regenerate a 624-word Mersenne Twister state in place.

    Each word ``mt[kk]`` is replaced, in increasing ``kk`` order, by
    ``mt[(kk + M) % N] ^ (y >> 1) ^ (y % 2) * MATRIX_A`` where ``y`` joins the
    top bit of ``mt[kk]`` with the low 31 bits of ``mt[(kk + 1) % N]``.
    Because the update is sequential and in place, later words read
    already-updated earlier words.

    Args:
        mt: Mutable sequence holding at least 624 state words. Words may be
            ``int`` values or z3 bit-vector expressions; non-``int`` words
            are shifted with z3's logical right shift ``LShR``.

    Returns:
        None. ``mt`` is modified in place.
    """
    N = 624
    M = 397
    MATRIX_A = 0x9908B0DF
    UPPER_MASK = 0x80000000
    LOWER_MASK = 0x7FFFFFFF
    for kk in range(N):
        # Modular indices cover the three ranges (kk < N - M, kk < N - 1 and
        # kk == N - 1) that are usually written as separate loops.
        y = (mt[kk] & UPPER_MASK) | (mt[(kk + 1) % N] & LOWER_MASK)
        shifted = y >> 1 if isinstance(y, int) else LShR(y, 1)
        mt[kk] = mt[(kk + M) % N] ^ shifted ^ (y % 2) * MATRIX_A
# --- Step 1: brute-force one byte per ciphertext pair -------------------------
# `pubkey` and `c` are expected to be defined before this point (they are not
# part of this listing).
s, t, n = pubkey

# Modular inverse of every candidate byte value, computed once instead of once
# per ciphertext pair.
byte_inverses = [(m, pow(m, -1, n)) for m in range(1, 256)]

enc = b""
for c1, c2 in c:
    # An all-zero pair is taken to encode a zero byte.
    if c1 == c2 == 0:
        enc += b"\x00"
        continue
    # A byte m is accepted when (c1 / m)**t == (c2 / m)**s (mod n).
    # Presumably the pairs are (m * r**s, m * r**t) mod n for a random r —
    # inferred from this check; confirm against the challenge source.
    for m, m_inv in byte_inverses:
        if pow(c1 * m_inv, t, n) == pow(c2 * m_inv, s, n):
            enc += bytes([m])
            print(enc)
            break
    else:
        # Silently skipping would shift every later byte of `enc` relative to
        # its ciphertext pair, so fail loudly instead.
        raise ValueError(
            f"no byte value matches ciphertext pair #{len(enc)}"
        )
# --- Step 2: recover r (mod n) for every non-zero byte -------------------------
def _recover_base(r_s, r_t, s_exp, t_exp, modulus):
    """Return ``r`` mod ``modulus`` given ``r**s_exp`` and ``r**t_exp``.

    Runs the Euclidean algorithm on the exponents while applying the matching
    multiplications/inversions to the powers, so the pair always holds
    ``(r**s_exp, r**t_exp)`` for the current exponents.

    Raises:
        ValueError: If the exponents are not coprime, in which case only
            ``r**gcd`` (not ``r`` itself) can be obtained this way.
    """
    while s_exp % t_exp != 0:
        quotient = s_exp // t_exp
        # r**(s mod t) == r**s * (r**t)**(-quotient)
        r_s, r_t = r_t, int(r_s * pow(r_t, -quotient, modulus) % modulus)
        s_exp, t_exp = t_exp, int(s_exp % t_exp)
    if t_exp != 1:
        raise ValueError("exponents are not coprime; cannot recover r")
    return int(r_t)


rs = []
for i, ci in enumerate(c):
    m = enc[i]
    # The first zero byte ends the run of pairs that carry an r value.
    if m == 0:
        break
    m_inv = pow(m, -1, n)
    r_s = int(ci[0] * m_inv % n)  # c1 / m
    r_t = int(ci[1] * m_inv % n)  # c2 / m
    rs.append(_recover_base(r_s, r_t, s, t, n))
def state_idx_to_rs_idx(idx):
    """Split a 32-bit word index into ``(index into rs, word offset)``.

    Every entry of ``rs`` is addressed as 32 consecutive 32-bit words:

        0  -> (0, 0)
        1  -> (0, 1)
        ...
        31 -> (0, 31)
        32 -> (1, 0)
        ...

    Args:
        idx: Position of the word in the stream of 32-bit words.

    Returns:
        Tuple ``(idx // 32, idx % 32)``.
    """
    return divmod(idx, 32)
# --- Step 3: find how many multiples of n each r lost -------------------------
# Only r mod n is known, so the full 1024-bit value is rs[k] + r_n_list[k] * n
# for a small unknown r_n_list[k]. The state-update relation between words
# state_idx, state_idx + 1, state_idx + M and state_idx + N is used to pick it.
def _word_candidates(r, word_pos):
    """List ``(wraps, untempered word)`` for every ``r + wraps * n <= 2**1024``.

    ``word_pos`` selects which 32-bit word of the candidate value is taken.
    """
    return [
        (wraps, untemper(((r + wraps * n) >> (32 * word_pos)) % 2 ** 32))
        for wraps in range((2 ** 1024 - r) // n + 1)
    ]


r_n_list = [None] * (N // 32 * 2 + 2)
for state_idx in range(N):
    print(state_idx)
    positions = [state_idx_to_rs_idx(state_idx + off) for off in (0, 1, M, N)]
    # Untemper every candidate word once, outside the nested search.
    words0, words1, words2, words3 = (
        _word_candidates(rs[blk], word_pos) for blk, word_pos in positions
    )
    # Words 0 and 1 are adjacent and usually live in the same r; in that case
    # they must share one wrap count, so mixed pairs are not real candidates.
    same_r = positions[0][0] == positions[1][0]
    cands = []
    for i0, state_0 in words0:
        for i1, state_1 in words1:
            if same_r and i0 != i1:
                continue
            y = (state_0 & UPPER_MASK) | (state_1 & LOWER_MASK)
            mixed = (y >> 1) ^ (y % 2) * MATRIX_A
            for i2, state_2 in words2:
                for i3, state_3 in words3:
                    if state_3 == state_2 ^ mixed:
                        cands.append((i0, i1, i2, i3))
                        print("found", i0, i1, i2, i3)
    # Only an unambiguous match is recorded; it must agree with whatever an
    # earlier state_idx already established for the same r.
    if len(cands) == 1:
        for (blk, _), wraps in zip(positions, cands[0]):
            if r_n_list[blk] is not None:
                assert r_n_list[blk] == wraps
            r_n_list[blk] = wraps
# --- Step 4: rebuild the untempered output words ------------------------------
state = [None] * (N * 2)
for state_idx in range(N):
    print(state_idx)
    offsets = (0, 1, M, N)
    words = []
    for off in offsets:
        blk, word_pos = state_idx_to_rs_idx(state_idx + off)
        if r_n_list[blk] is None:
            raise ValueError(f"wrap count for rs[{blk}] was never resolved")
        full_r = rs[blk] + r_n_list[blk] * n
        words.append(untemper((full_r >> (32 * word_pos)) % 2 ** 32))
    state_0, state_1, state_2, state_3 = words

    # Sanity check: the four words must satisfy the state-update relation.
    y = (state_0 & UPPER_MASK) | (state_1 & LOWER_MASK)
    assert state_3 == state_2 ^ (y >> 1) ^ (y % 2) * MATRIX_A

    # Store each word, checking it against any value stored by an earlier
    # iteration for the same slot.
    for off, word in zip(offsets, words):
        slot = state_idx + off
        if state[slot] is not None:
            assert state[slot] == word
        state[slot] = word

# Keep only the first N words for the solver step.
state = state[:N]
# --- Step 5: step the generator backwards with z3, then decrypt ---------------
# Symbolic previous state; applying update_mt to a copy must reproduce the
# words recovered above.
prev_state = [BitVec(f"x{i}", 32) for i in range(N)]
current_state = prev_state.copy()
update_mt(current_state)
solver = Solver()
for symbolic_word, known_word in zip(current_state, state):
    solver.add(symbolic_word == known_word)
if solver.check() != sat:
    raise ValueError("no previous generator state matches the recovered words")
model = solver.model()
# model_completion=True gives a concrete value even for a variable the solver
# left unassigned, instead of returning None.
prev_state = [
    model.eval(var, model_completion=True).as_long() for var in prev_state
]

# Rewind to 64 words before the end of the previous state and replay the two
# 1024-bit draws (2 * 32 words) made from there.
random.setstate((3, tuple(prev_state[:N] + [N - 64]), None))
a = random.getrandbits(1024)
s = random.getrandbits(1024)
# Author's derivation of a multiple of phi from a, s and the public t:
# bt = -sa mod phi
# b = phi + 1 - a
# (phi + 1 - a)t = -sa mod phi
# (1 - a)t = -sa
# sa + (1 - a) t = k phi
# T = k phi
T = s * a + (1 - a) * t
# Inverting e = 0x10001 modulo the multiple T of phi gives a working
# decryption exponent.
d = pow(0x10001, -1, T)
print(long_to_bytes(pow(bytes_to_long(enc), d, n)))
# n1ctf{5255840f-9140-4479-950f-a3c03fe7f4cd}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment