Skip to content

Instantly share code, notes, and snippets.

@y011d4
Last active May 22, 2022 13:12
Show Gist options
  • Save y011d4/f21cf0139e5c4a537db768297ca6b935 to your computer and use it in GitHub Desktop.
Save y011d4/f21cf0139e5c4a537db768297ca6b935 to your computer and use it in GitHub Desktop.
my solver for zer0lfsr++ in 0CTF/TCTF 2021
import random
from z3 import *
from tqdm import tqdm
from functools import lru_cache
from collections import Counter
from itertools import combinations
import time
def _prod(L):
p = 1
for x in L:
p *= x
return p
def _sum(L):
    """XOR together every element of L; the empty XOR is 0."""
    acc = 0
    for item in L:
        acc = acc ^^ item
    return acc
def b2n(x):
    """Interpret the byte string x as a big-endian unsigned integer."""
    return int.from_bytes(x, byteorder="big")
def n2l(x, l):
    """Return the binary digits of x as a list of ints, MSB first, zero-padded to l digits."""
    digits = format(x, "0{}b".format(l))
    return [int(d) for d in digits]
def split(x, n, l):
    """Cut x into n chunks of l bits each, most significant chunk first."""
    chunks = [(x >> (k * l)) % 2 ** l for k in range(n)]
    chunks.reverse()
    return chunks
def combine(x, n, l):
    """Inverse of split: pack the first n l-bit chunks of x, most significant first."""
    total = 0
    for k in range(n):
        total += x[k] << (l * (n - k - 1))
    return total
class Generator1:
    """Local copy of zer0lfsr's Generator1: a 48-bit NFSR coupled with a 16-bit LFSR.

    key: list of 64 bits; key[:48] seeds the NFSR and key[48:] the LFSR.
    """

    def __init__(self, key: list):
        assert len(key) == 64
        self.NFSR, self.LFSR = key[:48], key[48:]
        # LFSR feedback positions.
        self.TAP = [0, 1, 12, 15]
        # Monomials (lists of NFSR positions) whose XOR is the NFSR feedback g().
        self.TAP2 = [
            [2], [5], [9], [15], [22], [26], [39],
            [26, 30], [5, 9],
            [15, 22, 26], [15, 22, 39],
            [9, 22, 26, 39],
        ]
        # Inputs of h(): four LFSR positions followed by one NFSR position.
        self.h_IN = [2, 4, 7, 15, 27]
        # Monomials of h(), indexing into the five inputs selected by h_IN.
        self.h_OUT = [[1], [3], [0, 3], [0, 1, 2], [0, 2, 3], [0, 2, 4], [0, 1, 2, 4]]

    def g(self):
        """NFSR feedback: XOR of the TAP2 monomials over the current NFSR."""
        state = self.NFSR
        return _sum(_prod(state[k] for k in mono) for mono in self.TAP2)

    def h(self):
        """Nonlinear function of four LFSR bits and one NFSR bit (see h_IN / h_OUT)."""
        *lfsr_pos, nfsr_pos = self.h_IN
        inputs = [self.LFSR[k] for k in lfsr_pos]
        inputs.append(self.NFSR[nfsr_pos])
        return _sum(_prod(inputs[k] for k in mono) for mono in self.h_OUT)

    def f(self):
        """Output bit of the current state: NFSR[0] XOR h()."""
        return _sum([self.NFSR[0], self.h()])

    def clock(self):
        """Return f() of the current state, then advance both registers by one step."""
        out = self.f()
        # Both feedback bits are computed from the pre-step state.
        nfsr_in = self.LFSR[0] ^^ self.g()
        lfsr_in = _sum(self.LFSR[k] for k in self.TAP)
        self.NFSR = self.NFSR[1:] + [nfsr_in]
        self.LFSR = self.LFSR[1:] + [lfsr_in]
        return out
class Generator2:
    """Local copy of zer0lfsr's Generator2: a 16-bit NFSR driven by a 48-bit LFSR.

    key: list of 64 bits; key[:16] seeds the NFSR and key[16:] the LFSR.
    """

    def __init__(self, key):
        assert len(key) == 64
        self.NFSR, self.LFSR = key[:16], key[16:]
        # LFSR feedback positions.
        self.TAP = [0, 35]
        # LFSR positions feeding f(), and f's monomials over those six inputs.
        self.f_IN = [0, 10, 20, 30, 40, 47]
        self.f_OUT = [
            [0, 1, 2, 3], [0, 1, 2, 4, 5], [0, 1, 2, 5], [0, 1, 2],
            [0, 1, 3, 4, 5], [0, 1, 3, 5], [0, 1, 3], [0, 1, 4], [0, 1, 5],
            [0, 2, 3, 4, 5], [0, 2, 3], [0, 3, 5],
            [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3, 5], [1, 2], [1, 3, 5],
            [1, 3], [1, 4], [1],
            [2, 4, 5], [2, 4], [2], [3, 4], [4, 5], [4], [5],
        ]
        # NFSR feedback monomials (NFSR positions).
        self.TAP2 = [[0, 3, 7], [1, 11, 13, 15], [2, 9]]
        # NFSR positions feeding h(), and h's monomials over those seven inputs.
        self.h_IN = [0, 2, 4, 6, 8, 13, 14]
        self.h_OUT = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 4, 6], [1, 3, 4]]

    def f(self):
        """Nonlinear function of the LFSR bits selected by f_IN."""
        picked = [self.LFSR[k] for k in self.f_IN]
        return _sum(_prod(picked[k] for k in mono) for mono in self.f_OUT)

    def h(self):
        """Nonlinear function of the NFSR bits selected by h_IN."""
        picked = [self.NFSR[k] for k in self.h_IN]
        return _sum(_prod(picked[k] for k in mono) for mono in self.h_OUT)

    def g(self):
        """NFSR feedback: XOR of the TAP2 monomials over the current NFSR."""
        state = self.NFSR
        return _sum(_prod(state[k] for k in mono) for mono in self.TAP2)

    def clock(self):
        """Advance the LFSR, then the NFSR, and return f() XOR h() of the new state."""
        lfsr_in = _sum(self.LFSR[k] for k in self.TAP)
        self.LFSR = self.LFSR[1:] + [lfsr_in]
        # The NFSR feedback reads LFSR[1] *after* the LFSR step above.
        nfsr_in = self.LFSR[1] ^^ self.g()
        self.NFSR = self.NFSR[1:] + [nfsr_in]
        return self.f() ^^ self.h()
class Generator3:
    """Local copy of zer0lfsr's Generator3: a 64-bit LFSR with a nonlinear output function.

    key: list of 64 bits, used directly as the LFSR state.
    """

    def __init__(self, key: list):
        assert len(key) == 64
        self.LFSR = key
        # Feedback bit = LFSR[0] XOR LFSR[55].
        self.TAP = [0, 55]
        # LFSR positions feeding f(), and f's monomials over those seven inputs.
        self.f_IN = [0, 8, 16, 24, 32, 40, 63]
        self.f_OUT = [[1], [6], [0, 1, 2, 3, 4, 5], [0, 1, 2, 4, 6]]

    def f(self):
        """XOR of the f_OUT monomials over the LFSR bits selected by f_IN."""
        taps = [self.LFSR[k] for k in self.f_IN]
        return _sum(_prod(taps[k] for k in mono) for mono in self.f_OUT)

    def clock(self):
        """Shift the LFSR one step, then return f() of the new state."""
        feedback = _sum(self.LFSR[k] for k in self.TAP)
        self.LFSR = self.LFSR[1:] + [feedback]
        return self.f()
class zer0lfsr:
    """Local copy of the combined generator: majority of Generator1, Generator2, Generator3.

    self.key holds three 64-bit integers from the global `random` module, one per generator.
    """

    def __init__(self):
        self.key = [random.getrandbits(64) for _ in range(3)]
        bits = [n2l(k, 64) for k in self.key]
        self.g1 = Generator1(bits[0])
        self.g2 = Generator2(bits[1])
        self.g3 = Generator3(bits[2])

    def next(self):
        """Return one keystream bit.

        Generator1 is clocked once, Generator2 twice and Generator3 three times; the
        last output of each is combined with the majority function.
        """
        o1 = self.g1.clock()
        for _ in range(2):
            o2 = self.g2.clock()
        for _ in range(3):
            o3 = self.g3.clock()
        return (o1 * o2) ^^ (o2 * o3) ^^ (o1 * o3)
# Keystream length in bits: 20 chunks x 1000 bytes x 8 bits, as sent by the server.
N = 20 * 1000 * 8
def gen_chall(idx):
    """Simulate one challenge instance locally.

    Returns (keys, z, hint): keys are the three generator keys as 64-element bit
    lists, z is the first N keystream bits, and hint is the top 16 bits of keys[idx].
    """
    lfsr = zer0lfsr()
    keys = list(map(lambda k: n2l(k, 64), lfsr.key))
    z = [lfsr.next() for _ in range(N)]
    return keys, z, keys[idx][:16]
# solve Generator3
@lru_cache
def S(p, t):
    """Return S(p, t) defined by S(p, 1) = p and S(p, t) = p*S(p, t-1) + (1-p)*(1-S(p, t-1)).

    This is the probability that the XOR of t independent bits, each equal to its
    true value with probability p, equals the true XOR.

    Raises:
        ValueError: if t < 1 (the previous recursive form never terminated there).
    """
    if t < 1:
        raise ValueError("t must be >= 1")
    # Iterative form: the same float operations as the recursion, without its depth limit.
    s = p
    for _ in range(t - 1):
        s = p * s + (1 - p) * (1 - s)
    return s
@lru_cache
def C(n, m):
    """Return the binomial coefficient "n choose m"; 0 when m < 0 or m > n."""
    from math import comb

    # The old hand-rolled loop returned 1 for negative m.
    if m < 0 or n < m:
        return 0
    return comb(n, m)
def gen_base_eqs(tap, length, m_step=1):
    """Build base parity-check equations from a linear relation by repeated doubling.

    Args:
        tap: positions of the relation; must start with 0, and every entry must be a
            multiple of m_step.
        length: exclusive upper bound on positions in the generated equations.
        m_step: decimation factor; positions are divided by it.

    Returns:
        A list of equations (sorted position lists). The first is tap // m_step.
        Each following one doubles every position of the previous one (squaring the
        relation over GF(2)). Generation stops once the largest position would
        reach `length`.
    """
    assert tap[0] == 0
    # The old `assert [...]` asserted a non-empty list and therefore never failed.
    assert all(t % m_step == 0 for t in tap), "every tap must be a multiple of m_step"
    eq = sorted(t // m_step for t in tap)
    eqs = [eq]
    while (eq[-1] << 1) < length:
        eq = [pos << 1 for pos in eq]
        eqs.append(eq)
    return eqs
def gen_shift_eqs(loc, base_eq, length):
    """Return every shift of every base equation that places one of its terms at `loc`.

    For each equation and each of its terms, the equation is shifted so that the term
    lands on `loc`. Shifts are kept only if all positions stay within [0, length).
    This relies on each equation being sorted, so eq[0] and eq[-1] are its extremes.
    """
    out = []
    for eq in base_eq:
        for anchor in eq:
            delta = loc - anchor
            if eq[0] + delta >= 0 and eq[-1] + delta < length:
                out.append([pos + delta for pos in eq])
    return out
def calc_eq(shift_eqs, z, p, check_none=False):
    """Score one keystream position from the parity-check equations that contain it.

    Args:
        shift_eqs: equations, each a list of indices into z.
        z: keystream bits. With check_none=True, equations touching a None entry are
            dropped first.
        p: probability passed to S() for each bit being correct.

    Returns:
        (m, h, p_star), where m is the number of equations used and h is how many of
        them hold (XOR of their bits is 0). p_star is
        C(m,h) s^h (1-s)^(m-h) / (that + C(m,h) s^(m-h) (1-s)^h), with
        s = S(p, len(first equation)). Returns (0, 0, 0) when no equation is usable.
    """
    if check_none:
        shift_eqs = [eq for eq in shift_eqs if all([z[k] is not None for k in eq])]
    m = len(shift_eqs)
    if m == 0:
        return 0, 0, 0
    h = 0
    for eq in shift_eqs:
        parity = 0
        for k in eq:
            parity = parity ^^ z[k]
        if parity == 0:
            h += 1
    s = S(p, len(shift_eqs[0]))
    weight = C(m, h)
    p1 = weight * pow(s, h) * pow(1 - s, m - h)
    p0 = weight * pow(s, m - h) * pow(1 - s, h)
    return m, h, p1 / (p1 + p0)
def check(lfsr, z, length, thres, m_step=1, verbose=False):
    """Test whether a candidate generator reproduces z well enough.

    Clocks `lfsr` m_step * length times, which advances it. Every m_step-th output,
    starting with the m_step-th, is compared position-wise with z. Returns True when
    at least `thres` positions agree. With verbose, the agreement tally is printed.
    """
    outputs = [lfsr.clock() for _ in range(m_step * length)]
    sampled = outputs[m_step - 1::m_step]
    tally = Counter([want == got for want, got in zip(z, sampled)])
    if verbose:
        print(tally)
    return tally[True] >= thres
# Generator1
def z3_prod(L):
    """Fold L with z3 And (the GF(2) product of Bools); None for an empty iterable.

    A single element is returned unchanged, without wrapping it in And.
    """
    acc = None
    for term in L:
        acc = term if acc is None else And(acc, term)
    return acc
def z3_sum(L):
    """Fold L with z3 Xor (the GF(2) sum of Bools); None for an empty iterable.

    A single element is returned unchanged, without wrapping it in Xor.
    """
    acc = None
    for term in L:
        acc = term if acc is None else Xor(acc, term)
    return acc
def g1(x):
    """Symbolic Generator1.g: NFSR feedback over x[:48], built with z3_sum / z3_prod."""
    nfsr = list(x[:48])
    monomials = [
        [2], [5], [9], [15], [22], [26], [39],
        [26, 30], [5, 9],
        [15, 22, 26], [15, 22, 39],
        [9, 22, 26, 39],
    ]
    return z3_sum(z3_prod(nfsr[k] for k in mono) for mono in monomials)
def h1(x):
    """Symbolic Generator1.h for a state laid out as x = NFSR (x[:48]) + LFSR (x[48:])."""
    lfsr_part = list(x[48:])
    nfsr_part = list(x[:48])
    # LFSR positions 2, 4, 7, 15 followed by NFSR position 27 (Generator1.h_IN).
    inputs = [lfsr_part[k] for k in (2, 4, 7, 15)] + [nfsr_part[27]]
    monomials = [[1], [3], [0, 3], [0, 1, 2], [0, 2, 3], [0, 2, 4], [0, 1, 2, 4]]
    return z3_sum(z3_prod(inputs[k] for k in mono) for mono in monomials)
def f1(x):
    """Symbolic Generator1.f: x[0] (NFSR[0]) XOR h1(x)."""
    head = x[0]
    return z3_sum((head, h1(x)))
def clock1(x):
    """Symbolic Generator1.clock.

    x: 64 entries, NFSR (x[:48]) followed by LFSR (x[48:]).
    Returns (next_state, out): next_state uses the same layout, and out is f1 of the
    state *before* the step, as in Generator1.clock.
    """
    out = f1(x)
    nfsr, lfsr = list(x[:48]), list(x[48:])
    nfsr_in = z3_sum([lfsr[0], g1(x)])
    lfsr_in = z3_sum(lfsr[k] for k in [0, 1, 12, 15])
    return nfsr[1:] + [nfsr_in] + lfsr[1:] + [lfsr_in], out
# Preparation: precompute the linear masks and parity-check equations used by the attack.

# Generator3's LFSR as a GF(2) transition matrix: new_state = M * state
# (shift by one; the new last bit is state[0] + state[55]).
M = matrix(GF(2), 64, 64)
for i in range(63):
    M[i, i+1] = 1
M[63, 0] = 1
M[63, 55] = 1
# Row vector selecting LFSR bits 8 and 63. These are Generator3's f_IN entries 1 and 6,
# which appear as the single-variable terms [1] and [6] of its f_OUT.
t = vector(GF(2), 64)
t[8] = 1
t[63] = 1
M3 = M ** 3
tmp = M ** 3
# linear_eq_3[i]: integer bitmask (bit j <-> initial-state bit j) of t * M^(3*(i+1));
# Generator3 is clocked 3 times per keystream bit.
linear_eq_3 = []
for i in tqdm(range(N)):
    res_list = (t * tmp).change_ring(ZZ).list()
    linear_eq_3.append(sum([res * 2**j for j, res in enumerate(res_list)]))
    tmp *= M3

# The same for Generator2's 48-bit LFSR (taps 0 and 35), clocked twice per keystream bit.
M = matrix(GF(2), 48, 48)
for i in range(47):
    M[i, i+1] = 1
M[47, 0] = 1
M[47, 35] = 1
# Selects LFSR bits 10, 20 and 47 (Generator2's f_IN entries 1, 2 and 5).
# NOTE(review): f_OUT also has the single-variable term [4] (LFSR bit 40), which is not
# included here - presumably a deliberate choice of linear approximation; confirm.
t = vector(GF(2), 48)
t[10] = 1
t[20] = 1
t[47] = 1
linear_eq_2 = []
M2 = M ** 2
tmp = M ** 2
for i in tqdm(range(N)):
    res_list = (t * tmp).change_ring(ZZ).list()
    linear_eq_2.append(sum([res * 2**j for j, res in enumerate(res_list)]))
    tmp *= M2
N = 160000

# Parity checks for Generator3, which is sampled every 3rd clock: we want a relation whose
# taps are all multiples of 3 (tap 0 is kept as is). Starting from s[k+64] = s[k] + s[k+55],
# every tap t not divisible by 3 is replaced using s[t] = s[t+55] + s[t+64].
base_tap = [0, 55, 64]
tap = []
while base_tap:
    t = base_tap.pop()
    if t % 3 == 0:
        tap.append(t)
    else:
        base_tap.append(t + 55)
        base_tap.append(t + 64)
# The expansion yields [0, 165, 174, 174, 174, 183, 183, 183, 192]. Over GF(2) a tap that
# occurs an even number of times cancels, so keep odd multiplicities: [0, 165, 174, 183, 192].
tap = sorted(pos for pos, cnt in Counter(tap).items() if cnt % 2 == 1)
base_eqs = gen_base_eqs(tap, N, m_step=3)
shift_eqs_list_3 = []
for i in tqdm(range(N)):
    shift_eqs = gen_shift_eqs(i, base_eqs, N)
    shift_eqs_list_3.append(shift_eqs)

# Parity checks for Generator2, which is sampled every 2nd clock. Squaring
# s[k+48] = s[k] + s[k+35] gives s[k+96] = s[k] + s[k+70], whose taps are all even.
tap = [0, 70, 96]
base_eqs = gen_base_eqs(tap, N, m_step=2)
shift_eqs_list_2 = []
for i in tqdm(range(N)):
    shift_eqs = gen_shift_eqs(i, base_eqs, N)
    shift_eqs_list_2.append(shift_eqs)
# The exploit starts here.
# Generator3
def solve_generator3(z):
    """Recover Generator3's initial state from keystream z with a fast correlation attack.

    Steps:
        1. Score each position i with calc_eq over shift_eqs_list_3[i].
        2. Take the best-scored positions as linear equations linear_eq_3[i] . key = z[i]
           until 64 independent ones are found.
        3. Try the solution and every right-hand side within Hamming distance 2, in that
           order, with check().

    Returns:
        The 64-bit key as a list, or None if no candidate passes check().
    """
    # NOTE(review): modelled probability that z[i] equals the linear approximation of
    # Generator3's output (3/4 from the majority combiner times 31/32) - confirm.
    p = float(0.75 * 31 / 32)
    candidates = []
    for i in tqdm(range(len(z))):
        m, h, p_star = calc_eq(shift_eqs_list_3[i], z, p)
        candidates.append((p_star, i, m, h))
    candidates.sort(reverse=True)
    print(candidates[:5])

    # Gather the most reliable positions whose masks are linearly independent.
    mat = matrix(GF(2), 64, 64)
    target = vector(GF(2), 64)
    idx = 0
    for _, pos, _, _ in candidates:
        mask = linear_eq_3[pos]
        for j in range(64):
            mat[idx, j] = (mask >> j) & 1
        target[idx] = z[pos]
        # Keep the row only if it raises the rank; otherwise the next one overwrites it.
        if mat[:idx + 1].rank() == idx + 1:
            idx += 1
            if idx == 64:
                break
    else:
        print("not found...")
        return None

    # mat is invertible, so flipping target bit k shifts the solution by column k of
    # mat^-1. This avoids re-solving the system for each of the ~2000 candidates.
    inv = mat.inverse()
    base = inv * target
    cols = inv.columns()
    length = 200
    thres = 130

    def _error_patterns():
        # Hamming distance 0, then every 1-bit flip, then every 2-bit flip.
        yield vector(GF(2), 64)
        for col in cols:
            yield col
        for a, b in combinations(range(64), r=2):
            yield cols[a] + cols[b]

    for delta in _error_patterns():
        key3_rec = (base + delta).change_ring(ZZ).list()
        if check(Generator3(key=key3_rec), z, length, thres, m_step=3, verbose=True):
            break
    else:
        print("not found...")
        return None
    print(key3_rec)
    # Ground truth exists only in the local simulation (global `keys` from the driver loop).
    if "keys" in globals():
        print(keys[2])
    return key3_rec
# mat * vector(GF(2), keys[2]) + target
# lfsr2 = Generator2(key=keys[1])
# z_2_test = [lfsr2.clock() for _ in range(2*N)][1::2]
# Generator2
# TODO: in the fast correlation attack, would also using 1 - a for positions whose parity
# checks all fail raise the probability of finding the solution?
def solve_generator2(z, key3_rec, hint):
    """Recover Generator2's key from z, the recovered Generator3 key and the 16-bit hint.

    Where z differs from Generator3's output, the majority combiner forces Generator2's
    output to equal z. That gives exact but sparse samples z_2 (None elsewhere).

    Steps:
        1. Rank positions with the parity checks in shift_eqs_list_2.
        2. Solve the 48 most reliable independent ones (masks from linear_eq_2) for the
           LFSR state.
        3. Try candidates within Hamming distance 2 with check(), using `hint` as the
           NFSR state.

    Returns:
        hint + the 48 LFSR bits, or None if nothing passes check().
    """
    lfsr3 = Generator3(key=key3_rec)
    z_3 = [lfsr3.clock() for _ in range(3*N)][2::3]
    # If z[i] != z_3[i] the majority output must equal o2, so z_2[i] = z[i]; else unknown.
    z_2 = [z[i] if z[i] != z_3[i] else None for i in range(N)]

    p = float(3/4 * 7/8)
    candidates = []
    for i in tqdm(range(len(z))):
        m, h, p_star = calc_eq(shift_eqs_list_2[i], z_2, p, check_none=True)
        candidates.append((p_star, i, m, h))
    candidates.sort(reverse=True)
    print(candidates[:5])

    # Gather the most reliable known positions whose masks are linearly independent.
    mat = matrix(GF(2), 48, 48)
    target = vector(GF(2), 48)
    idx = 0
    for _, pos, _, _ in candidates:
        bit = z_2[pos]
        if bit is None:
            # Generator2's output is unknown here, so this position gives no equation.
            continue
        mask = linear_eq_2[pos]
        for j in range(48):
            mat[idx, j] = (mask >> j) & 1
        target[idx] = bit
        if mat[:idx + 1].rank() == idx + 1:
            idx += 1
            if idx == 48:
                break
    else:
        print("not found...")
        return None

    # Flipping target bit k shifts the solution by column k of mat^-1.
    inv = mat.inverse()
    base = inv * target
    cols = inv.columns()
    # The hint is Generator2's initial NFSR (key[:16]). Checking with it instead of an
    # all-zero NFSR makes the candidate reproduce Generator2's output exactly.
    nfsr_init = list(hint)
    length = 200
    thres = 130

    def _error_patterns():
        # Hamming distance 0, then every 1-bit flip, then every 2-bit flip.
        yield vector(GF(2), 48)
        for col in cols:
            yield col
        for a, b in combinations(range(48), r=2):
            yield cols[a] + cols[b]

    for delta in _error_patterns():
        key2_lfsr_rec = (base + delta).change_ring(ZZ).list()
        if check(Generator2(key=nfsr_init + key2_lfsr_rec), z, length, thres, m_step=2, verbose=True):
            break
    else:
        print("not found...")
        return None

    key2_rec = nfsr_init + key2_lfsr_rec
    print(key2_lfsr_rec)
    print(key2_rec)
    # Ground truth exists only in the local simulation (global `keys` from the driver loop).
    if "keys" in globals():
        print(keys[1][16:])
        print(keys[1])
    return key2_rec
# Generator1
def solve_generator1(z, key3_rec, key2_rec):
    """Recover Generator1's key with z3, given the recovered Generator2 and Generator3 keys.

    Wherever Generator2's and Generator3's outputs differ, the majority output z[i]
    equals Generator1's output. Those positions constrain a symbolic unrolling of
    Generator1 (clock1) over M = 256 variables per register.

    Returns:
        64 bits (48 NFSR bits, then 16 LFSR bits), where a bit is None if the model
        does not fix it. Returns None if z3 does not report `sat`.
    """
    lfsr3 = Generator3(key=key3_rec)
    z_3 = [lfsr3.clock() for _ in range(3*N)][2::3]
    lfsr2 = Generator2(key=key2_rec)
    z_2 = [lfsr2.clock() for _ in range(2*N)][1::2]
    # Known Generator1 outputs: z[i] where o2 != o3, None (unknown) elsewhere.
    z_1 = [None if z_2[i] == z_3[i] else z[i] for i in range(N)]

    M = 256
    state_lfsr = [Bool(f"x{i}") for i in range(M)]
    state_nfsr = [Bool(f"y{i}") for i in range(M)]
    s = Solver()
    for i in range(M - 48 - 1):
        # State after i clocks: NFSR window y_i..y_{i+47}, LFSR window x_i..x_{i+15}.
        state = state_nfsr[i:i+48] + state_lfsr[i:i+16]
        next_state, o = clock1(state)
        # One clock must equal shifting each register window by one variable.
        for j in range(48):
            s.add(next_state[j] == state_nfsr[i+j+1])
        for j in range(16):
            s.add(next_state[48 + j] == state_lfsr[i+j+1])
        if z_1[i] is not None:
            s.add(o == bool(z_1[i]))
    # The result used to be ignored, so s.model() raised when there was no model.
    if s.check() != sat:
        print("Generator1: no model found")
        return None
    m = s.model()

    def _bit(var):
        try:
            return int(bool(m.eval(var)))
        except Z3Exception:
            # eval() could not reduce the variable to a constant (e.g. it is unconstrained).
            return None

    key1_rec = [_bit(v) for v in state_nfsr[:48]] + [_bit(v) for v in state_lfsr[:16]]
    print(key1_rec)
    # Ground truth exists only in the local simulation (global `keys` from the driver loop).
    if "keys" in globals():
        print(keys[0])
    return key1_rec
# Solve here: simulate challenges until Generators 3 and 2 are both recovered, then Generator1.
N = 160000
i = 0
while True:
    i += 1
    keys, z, hint = gen_chall(idx=1)
    now = time.time()
    key3_rec = solve_generator3(z)
    if key3_rec is None:
        continue
    key2_rec = solve_generator2(z, key3_rec, hint)
    if key2_rec is not None:
        break
key1_rec = solve_generator1(z, key3_rec, key2_rec)
# i counts the challenges tried; the elapsed time covers only the last attempt.
print(i, time.time() - now)
#!/usr/bin/env python3
import random
import signal
import socketserver
import string
from hashlib import sha256
from os import urandom
from secret import flag
flag_plus = b'flag{?????????????}'
def _prod(L):
p = 1
for x in L:
p *= x
return p
def _sum(L):
s = 0
for x in L:
s ^= x
return s
def b2n(x):
    """Decode bytes x as a big-endian unsigned integer."""
    return int.from_bytes(x, byteorder='big')
def n2l(x, l):
    """Return x as a list of binary digits, MSB first, zero-padded to length l."""
    return [int(bit) for bit in format(x, '0{}b'.format(l))]
def split(x, n, l):
    """Split x into n chunks of l bits, most significant chunk first."""
    modulus = 2**l
    return [(x >> (i * l)) % modulus for i in reversed(range(n))]
def combine(x, n, l):
    """Join the first n l-bit chunks of x (most significant first) into one integer."""
    total = 0
    for i in range(n):
        total += x[i] << (l * (n - i - 1))
    return total
class Generator1:
    """48-bit NFSR coupled with a 16-bit LFSR.

    key: list of 64 bits; the first 48 seed the NFSR, the last 16 the LFSR.
    """

    def __init__(self, key: list):
        assert len(key) == 64
        self.NFSR, self.LFSR = key[:48], key[48:]
        self.TAP = [0, 1, 12, 15]  # LFSR feedback positions
        # NFSR feedback monomials (NFSR positions).
        self.TAP2 = [[2], [5], [9], [15], [22], [26], [39], [26, 30], [5, 9], [15, 22, 26], [15, 22, 39], [9, 22, 26, 39]]
        # h inputs: LFSR positions 2, 4, 7, 15, then NFSR position 27.
        self.h_IN = [2, 4, 7, 15, 27]
        self.h_OUT = [[1], [3], [0, 3], [0, 1, 2], [0, 2, 3], [0, 2, 4], [0, 1, 2, 4]]

    def g(self):
        """NFSR feedback: XOR of the TAP2 monomials."""
        regs = self.NFSR
        return _sum(_prod(regs[pos] for pos in mono) for mono in self.TAP2)

    def h(self):
        """Nonlinear function of four LFSR bits and one NFSR bit."""
        picked = [self.LFSR[pos] for pos in self.h_IN[:-1]]
        picked.append(self.NFSR[self.h_IN[-1]])
        return _sum(_prod(picked[pos] for pos in mono) for mono in self.h_OUT)

    def f(self):
        """Output of the current state: NFSR[0] XOR h()."""
        return _sum([self.NFSR[0], self.h()])

    def clock(self):
        """Return the current output, then step both registers (feedbacks use the old state)."""
        out = self.f()
        new_n = self.LFSR[0] ^ self.g()
        new_l = _sum(self.LFSR[pos] for pos in self.TAP)
        self.NFSR = self.NFSR[1:] + [new_n]
        self.LFSR = self.LFSR[1:] + [new_l]
        return out
class Generator2:
    """16-bit NFSR driven by a 48-bit LFSR.

    key: list of 64 bits; the first 16 seed the NFSR, the remaining 48 the LFSR.
    """

    def __init__(self, key):
        assert len(key) == 64
        self.NFSR, self.LFSR = key[:16], key[16:]
        self.TAP = [0, 35]  # LFSR feedback positions
        self.f_IN = [0, 10, 20, 30, 40, 47]
        self.f_OUT = [[0, 1, 2, 3], [0, 1, 2, 4, 5], [0, 1, 2, 5], [0, 1, 2], [0, 1, 3, 4, 5], [0, 1, 3, 5], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3, 4, 5], [
                      0, 2, 3], [0, 3, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3, 5], [1, 2], [1, 3, 5], [1, 3], [1, 4], [1], [2, 4, 5], [2, 4], [2], [3, 4], [4, 5], [4], [5]]
        self.TAP2 = [[0, 3, 7], [1, 11, 13, 15], [2, 9]]  # NFSR feedback monomials
        self.h_IN = [0, 2, 4, 6, 8, 13, 14]
        self.h_OUT = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 4, 6], [1, 3, 4]]

    def f(self):
        """Nonlinear function of the LFSR bits at f_IN."""
        picked = [self.LFSR[pos] for pos in self.f_IN]
        return _sum(_prod(picked[pos] for pos in mono) for mono in self.f_OUT)

    def h(self):
        """Nonlinear function of the NFSR bits at h_IN."""
        picked = [self.NFSR[pos] for pos in self.h_IN]
        return _sum(_prod(picked[pos] for pos in mono) for mono in self.h_OUT)

    def g(self):
        """NFSR feedback: XOR of the TAP2 monomials."""
        regs = self.NFSR
        return _sum(_prod(regs[pos] for pos in mono) for mono in self.TAP2)

    def clock(self):
        """Step the LFSR, then the NFSR, and return f() ^ h() of the new state."""
        new_l = _sum(self.LFSR[pos] for pos in self.TAP)
        self.LFSR = self.LFSR[1:] + [new_l]
        # The NFSR feedback reads LFSR[1] *after* the LFSR step above.
        new_n = self.LFSR[1] ^ self.g()
        self.NFSR = self.NFSR[1:] + [new_n]
        return self.f() ^ self.h()
class Generator3:
    """64-bit LFSR (feedback LFSR[0] ^ LFSR[55]) with a nonlinear output function.

    key: list of 64 bits used directly as the register.
    """

    def __init__(self, key: list):
        assert len(key) == 64
        self.LFSR = key
        self.TAP = [0, 55]
        self.f_IN = [0, 8, 16, 24, 32, 40, 63]
        self.f_OUT = [[1], [6], [0, 1, 2, 3, 4, 5], [0, 1, 2, 4, 6]]

    def f(self):
        """XOR of the f_OUT monomials over the LFSR bits at f_IN."""
        picked = [self.LFSR[pos] for pos in self.f_IN]
        terms = (_prod(picked[pos] for pos in mono) for mono in self.f_OUT)
        return _sum(terms)

    def clock(self):
        """Step the register once and return f() of the new state."""
        new_bit = _sum(self.LFSR[pos] for pos in self.TAP)
        self.LFSR = self.LFSR[1:] + [new_bit]
        return self.f()
class zer0lfsr:
    """Keystream generator combining Generator1/2/3 with the majority function.

    self.key holds three 64-bit integers (from the global `random` module) seeding
    Generator1, Generator2 and Generator3 respectively.
    """

    def __init__(self):
        self.key = [random.getrandbits(64) for _ in range(3)]
        self.g1, self.g2, self.g3 = (
            Generator1(n2l(self.key[0], 64)),
            Generator2(n2l(self.key[1], 64)),
            Generator3(n2l(self.key[2], 64)),
        )

    def next(self):
        """Return one keystream bit.

        Generator1 is clocked once, Generator2 twice and Generator3 three times; the
        majority of their last outputs is returned.
        """
        o1 = self.g1.clock()
        self.g2.clock()
        o2 = self.g2.clock()
        self.g3.clock()
        self.g3.clock()
        o3 = self.g3.clock()
        return (o1 * o2) ^ (o2 * o3) ^ (o1 * o3)
class Task(socketserver.BaseRequestHandler):
    """Per-connection handler for the zer0lfsr++ challenge.

    Protocol implemented by handle():
      1. The client has 5 s to send the zer0lfsr+ flag, which must equal flag_plus.
      2. The alarm is reset to 50 s and a 4-character sha256 proof of work follows.
      3. The server sends 20 lines 'start:::<1000 keystream bytes>:::end'.
      4. The alarm is reset to 180 s. The client picks a hint index in {0, 1, 2} and
         receives that key's top 16 bits.
      5. The client sends k1, k2, k3 and gets the flag if they equal lfsr.key.
    """

    def proof_of_work(self):
        """Ask the client for XXXX with sha256(XXXX + suffix) == digest; True iff solved."""
        # NOTE(review): this reseeds the process-global `random` module. handle() later
        # uses it (through zer0lfsr) to draw the secret keys, so those keys derive from
        # only 8 bytes of OS entropy. Confirm this is intended.
        random.seed(urandom(8))
        proof = ''.join([random.choice(string.ascii_letters + string.digits + '!#$%&*-?') for _ in range(20)])
        digest = sha256(proof.encode()).hexdigest()
        self.dosend('sha256(XXXX + {}) == {}'.format(proof[4: ], digest))
        self.dosend('Give me XXXX:')
        x = self.request.recv(10)
        x = (x.strip()).decode('utf-8')
        if len(x) != 4 or sha256((x + proof[4: ]).encode()).hexdigest() != digest:
            return False
        return True

    def dosend(self, msg):
        """Send msg (latin-1 encoded) followed by a newline; failures are ignored."""
        try:
            self.request.sendall(msg.encode('latin-1') + b'\n')
        except Exception:
            # Best effort, e.g. the client may already have disconnected.
            pass

    def timeout_handler(self, signum, frame):
        """SIGALRM handler: abort the current step by raising TimeoutError."""
        raise TimeoutError

    def handle(self):
        """Run the challenge protocol described in the class docstring."""
        try:
            signal.signal(signal.SIGALRM, self.timeout_handler)
            signal.alarm(5)
            self.dosend('Input the flag of zer0lfsr+: ')
            guess = self.request.recv(100).strip()
            # Explicit checks instead of `assert`: asserts are stripped under
            # `python -O`, which would let anyone skip the zer0lfsr+ gate.
            if guess != flag_plus:
                raise ValueError('wrong zer0lfsr+ flag')
            signal.alarm(50)
            if not self.proof_of_work():
                self.dosend('You must pass the PoW!')
                return
            lfsr = zer0lfsr()
            for _ in range(20):
                chunk = []
                for _ in range(1000):
                    b = 0
                    for _ in range(8):
                        b = (b << 1) + lfsr.next()
                    chunk.append(chr(b))
                self.dosend('start:::' + ''.join(chunk) + ':::end')
            signal.alarm(180)
            self.dosend('hint: ')
            idx = int(self.request.recv(10).strip())
            if idx not in (0, 1, 2):
                raise ValueError('hint index out of range')
            # Top 16 bits of the chosen 64-bit key.
            self.dosend(str(lfsr.key[idx] >> 48))
            self.dosend('k1: ')
            k1 = int(self.request.recv(100).strip())
            self.dosend('k2: ')
            k2 = int(self.request.recv(100).strip())
            self.dosend('k3: ')
            k3 = int(self.request.recv(100).strip())
            if lfsr.key == [k1, k2, k3]:
                # NOTE(review): dosend() encodes with latin-1, so this assumes
                # secret.flag is a str; a bytes flag would silently not be sent.
                self.dosend(flag)
            else:
                self.dosend('Wrong ;(')
        except TimeoutError:
            self.dosend('Timeout!')
            self.request.close()
        except Exception:
            self.dosend('Wtf?')
            self.request.close()
        finally:
            # Cancel any pending alarm so it cannot fire after the handler returns.
            signal.alarm(0)
class ThreadedServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    """TCP server that serves each connection in a forked child process.

    Despite the name, requests are handled by forking (ForkingMixIn), not by threads.
    """

    # Must be a class attribute: TCPServer.__init__ binds the socket immediately, so
    # setting it on the instance after construction has no effect.
    allow_reuse_address = True
if __name__ == "__main__":
    HOST, PORT = '0.0.0.0', 13338
    # Defer binding: allow_reuse_address is read inside server_bind(), so it has to be
    # set before the socket is bound to take effect.
    server = ThreadedServer((HOST, PORT), Task, bind_and_activate=False)
    server.allow_reuse_address = True
    try:
        server.server_bind()
        server.server_activate()
        server.serve_forever()
    finally:
        server.server_close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment