Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Fast Correlation Attacks - Algorithm A
# Register width in bits: the default `length` of an LFSR and the register
# size handed to crack() by the script entry point.
BLOCK = 48
class LFSR:
    """Linear feedback shift register whose state is kept in one integer.

    The state is `length` bits wide. Each step shifts the state left by one
    bit and inserts the feedback bit (the parity of the state bits selected
    by `mask`) at bit 0; that feedback bit is also the output of the step.
    """

    def __init__(self, init, mask, length=BLOCK):
        """Create a register.

        Args:
            init: Starting state; bits above `length` are discarded.
            mask: Feedback tap mask; bits above `length` are discarded.
            length: Register width in bits.
        """
        self.length = length
        self.lengthmask = 2 ** length - 1
        self.mask = mask & self.lengthmask
        # Truncate the seed the same way as the mask: step_back() only shifts
        # right, so stray bits above the register width would otherwise
        # survive and produce a state wider than `length`.
        self.init = init & self.lengthmask

    def next(self):
        """Advance the register one step and return the output bit (0 or 1)."""
        output = parity(self.init & self.mask)
        self.init = ((self.init << 1) & self.lengthmask) ^ output
        return output

    def step_back(self):
        """Rewind the register by one step, updating the state in place.

        The bit that next() shifted out of the top is rebuilt from the
        feedback equation. That equation only determines the lost bit when
        `mask` taps the top position (bit `length - 1`); with any other mask
        this is not an exact inverse of next().
        """
        output = self.init & 1  # the feedback bit next() inserted at bit 0
        predata = self.init >> 1  # previous state without its top bit
        high_bit = parity(predata & self.mask) ^ output
        self.init = (high_bit << (self.length - 1)) | predata
def parity(x):
    """Return 1 if `x` has an odd number of set bits, otherwise 0.

    Args:
        x: Non-negative integer.

    Raises:
        ValueError: If `x` is negative (a negative int has no finite set-bit
            count; clearing low bits one by one would never terminate).
    """
    if x < 0:
        raise ValueError("parity() requires a non-negative integer")
    return bin(x).count("1") & 1
def get_data():
    """Read the file "keystream" and return its contents as a tuple of bits.

    The file is decoded as text with the default codec of ``bytes.decode()``
    (UTF-8), and each decoded character contributes the low eight bits of its
    code point, most significant bit first.

    Returns:
        Tuple of 0/1 integers, eight per decoded character.
    """
    # The context manager closes the handle even if reading or decoding fails.
    with open("keystream", "rb") as stream:
        text = stream.read().decode()
    return tuple(
        (ord(char) >> shift) & 1
        for char in text
        for shift in range(7, -1, -1)
    )
def bit_stream_to_int(a):
    """Interpret a sequence of 0/1 values as a big-endian binary number.

    Args:
        a: Iterable of bits, most significant first.

    Returns:
        The integer value; 0 for an empty sequence.

    Raises:
        ValueError: If an element does not render as a binary digit.
    """
    digits = ''.join(map(str, a))
    # int('', 2) raises, so an empty bit stream is mapped to 0 explicitly.
    return int(digits, 2) if digits else 0
def S(p, t):
    """Evaluate the recurrence S(p, 1) = p,
    S(p, t) = p * S(p, t - 1) + (1 - p) * (1 - S(p, t - 1)).

    The recurrence is unrolled into a loop, so the cost is linear in `t`
    instead of doubling with every level of recursion.

    Args:
        p: Base probability.
        t: Number of terms; must be an integer >= 1.

    Returns:
        The value of the recurrence at `t`.

    Raises:
        ValueError: If `t` is smaller than 1 (the recurrence has no base case
            below 1).
    """
    if t < 1:
        raise ValueError("t must be at least 1")
    value = p
    for _ in range(t - 1):
        value = p * value + (1 - p) * (1 - value)
    return value
def C(n, m):
    """Return the binomial coefficient "n choose m" as an exact integer.

    Args:
        n: Size of the set.
        m: Number of elements chosen.

    Returns:
        The number of m-element subsets; 0 when `m` is negative or exceeds
        `n`.
    """
    if m < 0 or n < m:
        return 0
    # Use the symmetric, smaller index; integer-only so huge n stays exact.
    m = min(m, n - m)
    res = 1
    for i in range(m):
        # res equals C(n, i) here, so this division is always exact.
        res = res * (n - i) // (i + 1)
    return res
def calc_eq(loc, mask, z, n, p):
    """Count the parity checks through position `loc` that hold on `z`.

    A base check is built from the set bits of `mask` (offset ``n - i - 1``
    for bit ``i``) plus the offset ``n``. Further checks are obtained by
    doubling every offset while the largest doubled offset stays below
    ``len(z)``. Each check is then shifted so that each of its terms in turn
    lands on `loc`; shifts that leave the range of `z` are dropped.

    Args:
        loc: Index into `z` that every counted check must contain.
        mask: Feedback tap mask.
        z: Sequence of observed bits.
        n: Register width in bits.
        p: Base probability forwarded to S().

    Returns:
        Tuple ``(m, h, score)``: `m` is the number of checks through `loc`,
        `h` the number whose bits XOR to 0, and ``score = p1 / (p1 + p0)``
        with ``p1 = C(m, h) * s**h * (1 - s)**(m - h)``,
        ``p0 = C(m, h) * s**(m - h) * (1 - s)**h`` and ``s = S(p, t)`` for
        `t` set mask bits. ``(0, 0, 0)`` when no check fits.
    """
    total = len(z)

    set_bits = [bit for bit in range(n) if (mask >> bit) & 1]
    weight = len(set_bits)
    # Ascending offsets of the base check; the last entry is always n.
    base = [n - bit - 1 for bit in reversed(set_bits)] + [n]

    # Keep doubling the offsets until the largest one would leave z.
    templates = [base]
    while (templates[-1][-1] << 1) < total:
        templates.append([pos << 1 for pos in templates[-1]])

    # Slide each template so that each of its terms lands on loc.
    checks = []
    for template in templates:
        lowest, highest = template[0], template[-1]
        for anchor in template:
            offset = loc - anchor
            if lowest + offset >= 0 and highest + offset < total:
                checks.append([pos + offset for pos in template])

    m = len(checks)
    if m == 0:
        return 0, 0, 0

    h = 0
    for check in checks:
        acc = 0
        for pos in check:
            acc ^= z[pos]
        if acc == 0:
            h += 1

    s = S(p, weight)
    ways = C(m, h)
    p1 = ways * pow(s, h) * pow(1 - s, m - h)
    p0 = ways * pow(s, m - h) * pow(1 - s, h)
    return m, h, p1 / (p1 + p0)
def gen_linear_eq(mask, n, length):
    """Express output positions as XOR combinations of the first `n` ones.

    Entry ``k`` of the result is an integer bitmask: bit ``j`` is set when
    position ``j`` (``j < n``) takes part in the XOR that yields position
    ``k``. The first `n` entries are the unit masks ``1 << k``; every later
    entry is the XOR of the entries ``i + 1`` places back for each set bit
    ``i`` of `mask`.

    Args:
        mask: Feedback tap mask.
        n: Register width in bits.
        length: Number of positions wanted; at least `n` are always produced.

    Returns:
        List of ``max(length, n)`` integer bitmasks.
    """
    lags = [bit + 1 for bit in range(n) if (mask >> bit) & 1]
    forms = [1 << position for position in range(n)]
    for position in range(n, max(length, n)):
        combined = 0
        for lag in lags:
            combined ^= forms[position - lag]
        forms.append(combined)
    return forms
def solve(assume, n):
    """Solve a linear system over GF(2) by Gauss-Jordan elimination.

    Args:
        assume: Sequence of ``(coefficients, value)`` pairs. Bit ``j`` of
            `coefficients` is the coefficient of unknown ``j``; `value` is
            the right-hand side of that equation.
        n: Number of unknowns.

    Returns:
        List with the values of the `n` unknowns (also printed), or ``[]``
        when some column has no pivot, i.e. the equations do not determine
        every unknown.
    """
    # One integer per equation, taken bit by bit so that exactly the low n
    # coefficient bits are kept.
    rows = [
        sum(((entry[0] >> col) & 1) << col for col in range(n))
        for entry in assume
    ]
    rhs = [entry[1] for entry in assume]
    count = len(rows)

    for col in range(n):
        bit = 1 << col
        pivot = next((r for r in range(col, count) if rows[r] & bit), -1)
        if pivot == -1:
            return []
        rows[pivot], rows[col] = rows[col], rows[pivot]
        rhs[pivot], rhs[col] = rhs[col], rhs[pivot]
        # Clear this column from every other equation.
        for r in range(count):
            if r != col and rows[r] & bit:
                rhs[r] ^= rhs[col]
                rows[r] ^= rows[col]

    if not rows[n - 1]:
        return []
    print(rhs[:n])
    return rhs[:n]
def get_init_stat(locs, linear_eq, mask, n, z):
    """Recover the starting LFSR state from assumed-correct keystream bits.

    The first `n` candidates are turned into linear equations and solved;
    while solve() reports failure, one more candidate is added and the system
    is solved again. The solution is loaded into an LFSR, which is stepped
    back `n` times to obtain the state returned. The register is then run
    forward over ``len(z)`` steps to report how often it agrees with `z`.

    Args:
        locs: Sequence of ``(position, bit)`` pairs, most trusted first.
        linear_eq: Per-position equation masks; indexed by `position`.
        mask: Feedback tap mask of the LFSR.
        n: Register width in bits.
        z: Observed bit sequence.

    Returns:
        The LFSR state after stepping back `n` times from the solved state.

    Raises:
        ValueError: If every candidate has been used and the system still
            cannot be solved.
        AssertionError: If the regenerated sequence disagrees with one of the
            candidates that was fed to the solver.
    """
    assume = [(linear_eq[entry[0]], entry[1]) for entry in locs]

    print("----- try solve equations -----")
    used = n  # number of leading candidates handed to the solver
    b = solve(assume[:used], n)
    while not b:  # equations are linearly dependent: add one more and retry
        if used >= len(assume):
            # No candidates left; retrying the same slice would loop forever.
            raise ValueError("not enough independent equations to solve")
        used += 1
        b = solve(assume[:used], n)
    print("----- solve success -----")

    stat = bit_stream_to_int(b)
    print("----- genrate original LFSR -----")
    l = LFSR(stat, mask, n)
    for _ in range(n):
        l.step_back()
    init_stat = l.init
    print("init:", init_stat)
    print("----- genrate original LFSR finished -----")

    z_new = [l.next() for _ in range(len(z))]
    same_cnt = sum(int(seen == made) for seen, made in zip(z, z_new))
    # Check only the candidates that actually entered the solved system.
    for loc in locs[:used]:
        assert z_new[loc[0]] == loc[1], "solution contradicts a candidate bit"
    print("match rate:", same_cnt / len(z))
    return init_stat
def crack(mask, n, z, p):
    """Rank every position of `z` with calc_eq() and recover the LFSR state.

    Positions are sorted by descending ``(score, index, m, h)``; the observed
    bits at those positions, in that order, are handed to get_init_stat()
    together with the linear equations for the register.

    Args:
        mask: Feedback tap mask of the LFSR.
        n: Register width in bits.
        z: Observed bit sequence.
        p: Base probability forwarded to calc_eq().

    Returns:
        Whatever get_init_stat() returns for the ranked candidates.
    """
    print("----- select candidates -----")
    candidates = []
    for position in range(len(z)):
        m, h, p_star = calc_eq(position, mask, z, n, p)
        candidates.append((p_star, position, m, h))
    candidates.sort(reverse=True)
    print(candidates[:5])
    print("----- select candidates finished -----")

    # Use the caller's register width rather than the module-level default so
    # the equations match the `n` used everywhere else in this function.
    linear_eq = gen_linear_eq(mask, n, len(z))
    locs = [(cand[1], z[cand[1]]) for cand in candidates]
    return get_init_stat(locs, linear_eq, mask, n, z)
if __name__ == "__main__":
    # Script entry point: recover one state per feedback mask from the
    # keystream file, then print the SHA-256 of the concatenated states.
    import hashlib

    z = get_data()
    masks = (
        0b100000000000000000000000010000000000000000000000,
        0b100000000000000000000000000000000010000000000000,
        0b100000100000000000000000000000000000000000000000,
    )
    init = [crack(mask, BLOCK, z, 0.75) for mask in masks]
    print(init)

    # Big-endian bytes of each state, using as few bytes as the value needs.
    # Going through hex() instead fails whenever the value has an odd number
    # of hex digits, because bytes.fromhex() needs whole bytes.
    # NOTE(review): if the states are meant to be fixed-width (BLOCK // 8
    # bytes each), leading zero bytes must be kept - confirm and pass that
    # width to to_bytes() instead.
    init_bytes = b"".join(
        value.to_bytes((value.bit_length() + 7) // 8, "big") for value in init
    )
    print("flag{" + hashlib.sha256(init_bytes).hexdigest() + "}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment