Created
November 9, 2022 01:46
-
-
Save y011d4/0b79a71d0a15f6796ea7fc3759274e01 to your computer and use it in GitHub Desktop.
N1CTF brand_new_checkin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from Crypto.Util.number import bytes_to_long, long_to_bytes | |
from z3 import * | |
# MT19937 parameters (identical to the ones CPython's `random` module uses).
N, M = 624, 397  # state size in 32-bit words; offset of the "middle" word in the twist
MATRIX_A = 0x9908_B0DF  # twist matrix constant, XORed in when the low bit is set
UPPER_MASK = 0x8000_0000  # most significant bit of a 32-bit word
LOWER_MASK = 0x7FFF_FFFF  # remaining 31 low bits
def bit_shift_right_xor_rev(x, shift):
    """Invert ``y ^= y >> shift`` on a 32-bit word.

    Given ``x == y ^ (y >> shift)``, return ``y``. ``x`` may be a Python int or
    a z3 bit-vector. For bit-vectors a logical shift (``LShR``) is used, because
    z3's ``>>`` operator is an arithmetic (sign-extending) shift.

    Each iteration recovers another ``shift`` high-order bits, so the loop runs
    while ``i * shift < 32``.

    Raises:
        ValueError: if ``shift`` is not positive (the loop would never end).
    """
    if shift <= 0:
        raise ValueError(f"shift must be positive, got {shift}")
    i = 1
    y = x
    while i * shift < 32:
        z = y >> shift if isinstance(y, int) else LShR(y, shift)
        y = x ^ z
        i += 1
    return y
def bit_shift_left_xor_rev(x, shift, mask):
    """Invert ``y ^= (y << shift) & mask`` on a 32-bit word.

    Given ``x == y ^ ((y << shift) & mask)``, return ``y``. Works on Python
    ints and on z3 bit-vectors, because ``<<``, ``&`` and ``^`` behave the same
    on both.

    Each iteration recovers another ``shift`` low-order bits, so the loop runs
    while ``i * shift < 32``. ``mask`` keeps the result within the mask's width
    even though ``y << shift`` grows for Python ints.

    Raises:
        ValueError: if ``shift`` is not positive (the loop would never end).
    """
    if shift <= 0:
        raise ValueError(f"shift must be positive, got {shift}")
    i = 1
    y = x
    while i * shift < 32:
        y = x ^ ((y << shift) & mask)
        i += 1
    return y
def untemper(x):
    """Map one MT19937 output word back to the internal state word it came from.

    MT19937 tempering applies, in order: ``>> 11``, ``<< 7 & 0x9D2C5680``,
    ``<< 15 & 0xEFC60000`` and ``>> 18``. Those steps are undone here in
    reverse order.
    """
    word = bit_shift_right_xor_rev(x, 18)
    for amount, bitmask in ((15, 0xEFC60000), (7, 0x9D2C5680)):
        word = bit_shift_left_xor_rev(word, amount, bitmask)
    return bit_shift_right_xor_rev(word, 11)
def update_mt(mt):
    """Apply one MT19937 "twist" to ``mt`` in place.

    This is the state regeneration that produces the next 624 outputs. ``mt``
    is a mutable sequence of 624 words. Elements may be Python ints (concrete
    twist) or z3 32-bit bit-vectors (symbolic twist, used to solve for an
    earlier state).

    Words are updated in order 0..N-1, and later words read already-updated
    neighbours, exactly as the reference implementation does.

    Returns None; ``mt`` is modified in place.
    """
    # Local copies keep the function usable on its own.
    N = 624
    M = 397
    MATRIX_A = 0x9908B0DF
    UPPER_MASK = 0x80000000
    LOWER_MASK = 0x7FFFFFFF

    def twist(hi_word, lo_word, far_word):
        """Return the new value from the top bit of hi_word, the low 31 bits of lo_word, and far_word."""
        y = (hi_word & UPPER_MASK) | (lo_word & LOWER_MASK)
        # z3's `>>` is arithmetic; the twist needs a logical shift.
        shifted = y >> 1 if isinstance(y, int) else LShR(y, 1)
        return far_word ^ shifted ^ (y % 2) * MATRIX_A

    # (kk + M) % N and (kk + 1) % N reproduce the three original index cases:
    # kk + M, kk + M - N, and the wrap of the last word onto mt[0] / mt[M - 1].
    for kk in range(N):
        mt[kk] = twist(mt[kk], mt[(kk + 1) % N], mt[(kk + M) % N])
# Decode each ciphertext pair (c1, c2) back into one plaintext byte.
# NOTE(review): `pubkey` and `c` are not defined in this file; presumably they
# are pasted in from the challenge output. TODO confirm.
s, t, n = pubkey
enc = b""
for idx, (c1, c2) in enumerate(c):
    if c1 == c2 == 0:
        # A (0, 0) pair decodes to a zero byte.
        enc += b"\x00"
        continue
    for m in range(1, 256):
        m_inv = pow(m, -1, n)
        # Presumably c1 = m*r^s and c2 = m*r^t (mod n); for the correct m both
        # sides below equal r^(s*t).
        if pow(c1 * m_inv, t, n) == pow(c2 * m_inv, s, n):
            enc += long_to_bytes(m)
            print(enc)
            break
    else:
        # Silently skipping would shift every later enc[i] against c[i].
        raise ValueError(f"no byte in 1..255 decodes ciphertext pair {idx}")
# For each non-zero plaintext byte, strip m from the ciphertext pair to get
# (r^s, r^t) mod n, then recover r with Euclid's algorithm run on the exponents.
rs = []
for i, ci in enumerate(c):
    m = enc[i]
    if m == 0:
        # A zero byte decodes from (0, 0) and carries no usable power of r;
        # processing stops at the first one.
        break
    m_inverse = pow(m, -1, n)
    power_a = int(ci[0] * m_inverse % n)  # r^s mod n
    power_b = int(ci[1] * m_inverse % n)  # r^t mod n
    # Invariant: power_a == r^exp_a and power_b == r^exp_b (mod n).
    # Each step maps (exp_a, exp_b) -> (exp_b, exp_a mod b) and the powers to match.
    exp_a, exp_b = s, t
    while exp_a % exp_b != 0:
        quotient = exp_a // exp_b
        power_a, power_b = power_b, int(power_a * pow(power_b, -quotient, n) % n)
        exp_a, exp_b = exp_b, int(exp_a % exp_b)
    # gcd(s, t) must be 1, otherwise only r^gcd is recovered.
    assert exp_b == 1
    rs.append(int(power_b))
def state_idx_to_rs_idx(idx):
    """Locate MT output number ``idx`` among the 1024-bit values in ``rs``.

    Every ``rs`` entry spans 32 consecutive 32-bit outputs. Word ``k`` of an
    entry is bits ``32*k .. 32*k+31``, counting from the least significant end.
    The result is ``(entry index, word index within that entry)``::

        0  -> (0, 0)
        31 -> (0, 31)
        32 -> (1, 0)
    """
    return divmod(idx, 32)
# r_n_list[j] will hold the multiple k such that the full 1024-bit value behind
# rs[j] equals rs[j] + k * n (rs[j] itself is only known modulo n). Sized to
# cover every rs index reached below ((623 + N) // 32 == 38).
r_n_list = [None] * (N // 32 * 2 + 2)
for state_idx in range(624):
    print(state_idx)
    # The MT19937 twist computes word state_idx + N from words state_idx,
    # state_idx + 1 and state_idx + M (see update_mt).
    positions = [state_idx_to_rs_idx(state_idx + off) for off in (0, 1, M, N)]
    # For each position, list every lift rs[blk] + k*n that fits in 1024 bits,
    # together with the state word it implies. getrandbits(1024) is at most
    # 2**1024 - 1, hence the "- 1".
    options = [
        [
            (k, untemper(((rs[blk] + k * n) >> (32 * word)) % 2 ** 32))
            for k in range((2 ** 1024 - 1 - rs[blk]) // n + 1)
        ]
        for blk, word in positions
    ]
    # Only positions 0 and 1 can fall in the same rs entry; the other pairs are
    # at least N - M words apart. A shared entry must use a single k.
    shared = positions[0][0] == positions[1][0]
    cands = []
    for i0, state_0 in options[0]:
        for i1, state_1 in options[1]:
            if shared and i0 != i1:
                continue
            y = (state_0 & UPPER_MASK) | (state_1 & LOWER_MASK)
            for i2, state_2 in options[2]:
                for i3, state_3 in options[3]:
                    if state_3 == state_2 ^ (y >> 1) ^ (y % 2) * MATRIX_A:
                        cands.append((i0, i1, i2, i3))
                        print("found", i0, i1, i2, i3)
    if len(cands) == 1:
        # Check against every k fixed by earlier iterations before recording any.
        for (blk, _), k in zip(positions, cands[0]):
            if r_n_list[blk] is not None:
                assert r_n_list[blk] == k
        for (blk, _), k in zip(positions, cands[0]):
            r_n_list[blk] = k
# Rebuild the untempered MT words for outputs 0 .. 2N-1 from the lifted rs
# values, cross-checking the twist relation and every previously stored word.
state = [None] * (N * 2)
OFFSETS = (0, 1, M, N)
for state_idx in range(N):
    print(state_idx)
    full_values = []
    for offset in OFFSETS:
        blk, word_no = state_idx_to_rs_idx(state_idx + offset)
        full_values.append((rs[blk] + r_n_list[blk] * n, word_no))
    words = [untemper((value >> (32 * word_no)) % 2 ** 32) for value, word_no in full_values]
    state_0, state_1, state_2, state_3 = words
    y = (state_0 & UPPER_MASK) | (state_1 & LOWER_MASK)
    assert state_3 == state_2 ^ (y >> 1) ^ (y % 2) * MATRIX_A
    slots = [state_idx + offset for offset in OFFSETS]
    for slot, word in zip(slots, words):
        if state[slot] is not None:
            assert state[slot] == word
    for slot, word in zip(slots, words):
        state[slot] = word
# Keep the first N words: the full state right after one twist.
state = state[:N]
# Undo one twist: solve for the 624-word state whose update_mt() yields `state`.
prev_state = [BitVec(f"x{i}", 32) for i in range(N)]
current_state = prev_state.copy()
update_mt(current_state)
# A dedicated name avoids clobbering the public exponent `s`.
solver = Solver()
for i in range(N):
    solver.add(current_state[i] == state[i])
if solver.check() != sat:
    raise RuntimeError("no MT19937 state twists into the recovered outputs")
model = solver.model()
# The low 31 bits of the old word 0 never reach update_mt's result, so the
# constraints leave them free. model_completion guarantees every variable
# still evaluates to a concrete value.
prev_state = [model.eval(v, model_completion=True).as_long() for v in prev_state]
# CPython state tuple: (version 3, 624 words + next-output index, gauss_next).
# Index N - 64 makes the next 64 words drawn come from prev_state[N - 64:].
random.setstate((3, tuple(prev_state[:N] + [N - 64]), None))
# With the generator rewound to index N - 64, the next two 1024-bit draws
# replay the last 64 words of the previous state block. These are presumably
# the values drawn just before the nonces: first `a`, then `s`. This reassigns
# the module-level name `s`.
a = random.getrandbits(1024)
s = random.getrandbits(1024)

# Key relation (phi is presumably phi(n)):
#   b*t == -s*a (mod phi), with b == phi + 1 - a
#   => (1 - a)*t == -s*a (mod phi)
#   => s*a + (1 - a)*t == k*phi for some integer k
# T is therefore a multiple of phi, so an inverse of the exponent 0x10001
# modulo T also works as a private exponent modulo phi.
T = s * a + (1 - a) * t
d = pow(0x10001, -1, T)
flag = long_to_bytes(pow(bytes_to_long(enc), d, n))
print(flag)
# n1ctf{5255840f-9140-4479-950f-a3c03fe7f4cd}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment