|
from collections import defaultdict, Counter |
|
from subprocess import Popen, PIPE |
|
|
|
# When True, attack a locally chosen key (see the KEY setup below);
# otherwise KEY is "omgwtfxz".
LOCAL = False
# FLAGS = 0, 0, 0, 1
# LOCAL = False
# FLAGS[k] enables attack stage k+1; a disabled stage falls back to the
# subkey chunks computed from the locally known KEY (its `else:` branch).
FLAGS = 1, 1, 1, 1

# Number of Feistel rounds used by the local implementation (the attacked
# cipher is the 8-round variant; the self-test temporarily uses 16).
ROUNDS = 8

# DES initial permutation: output bit k takes input bit IP[k] (1-based).
IP = [58, 50, 42, 34, 26, 18, 10, 2, 60, 52, 44, 36, 28, 20, 12, 4, 62, 54, 46, 38, 30, 22, 14, 6, 64, 56, 48, 40, 32, 24, 16, 8, 57, 49, 41, 33, 25, 17, 9, 1, 59, 51, 43, 35, 27, 19, 11, 3, 61, 53, 45, 37, 29, 21, 13, 5, 63, 55, 47, 39, 31, 23, 15, 7]

# DES final permutation applied after the last round.
IP_1 = [40, 8, 48, 16, 56, 24, 64, 32, 39, 7, 47, 15, 55, 23, 63, 31, 38, 6, 46, 14, 54, 22, 62, 30, 37, 5, 45, 13, 53, 21, 61, 29, 36, 4, 44, 12, 52, 20, 60, 28, 35, 3, 43, 11, 51, 19, 59, 27, 34, 2, 42, 10, 50, 18, 58, 26, 33, 1, 41, 9, 49, 17, 57, 25]

# Inverse permutations: applying X_inv after X restores the original order.
IP_inv = [IP.index(i)+1 for i in xrange(1, 65)]
IP_1_inv = [IP_1.index(i)+1 for i in xrange(1, 65)]

# Expansion permutation: 32-bit half block -> 48 bits (six bits per S-box).
E = [32, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 8, 9, 10, 11, 12, 13, 12, 13, 14, 15, 16, 17, 16, 17, 18, 19, 20, 21, 20, 21, 22, 23, 24, 25, 24, 25, 26, 27, 28, 29, 28, 29, 30, 31, 32, 1]

# The eight DES S-boxes, indexed SBOX[box][row][col].
SBOX = [[[14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7], [0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8], [4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0], [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13]], [[15, 1, 8, 14, 6, 11, 3, 4, 9, 7, 2, 13, 12, 0, 5, 10], [3, 13, 4, 7, 15, 2, 8, 14, 12, 0, 1, 10, 6, 9, 11, 5], [0, 14, 7, 11, 10, 4, 13, 1, 5, 8, 12, 6, 9, 3, 2, 15], [13, 8, 10, 1, 3, 15, 4, 2, 11, 6, 7, 12, 0, 5, 14, 9]], [[10, 0, 9, 14, 6, 3, 15, 5, 1, 13, 12, 7, 11, 4, 2, 8], [13, 7, 0, 9, 3, 4, 6, 10, 2, 8, 5, 14, 12, 11, 15, 1], [13, 6, 4, 9, 8, 15, 3, 0, 11, 1, 2, 12, 5, 10, 14, 7], [1, 10, 13, 0, 6, 9, 8, 7, 4, 15, 14, 3, 11, 5, 2, 12]], [[7, 13, 14, 3, 0, 6, 9, 10, 1, 2, 8, 5, 11, 12, 4, 15], [13, 8, 11, 5, 6, 15, 0, 3, 4, 7, 2, 12, 1, 10, 14, 9], [10, 6, 9, 0, 12, 11, 7, 13, 15, 1, 3, 14, 5, 2, 8, 4], [3, 15, 0, 6, 10, 1, 13, 8, 9, 4, 5, 11, 12, 7, 2, 14]], [[2, 12, 4, 1, 7, 10, 11, 6, 8, 5, 3, 15, 13, 0, 14, 9], [14, 11, 2, 12, 4, 7, 13, 1, 5, 0, 15, 10, 3, 9, 8, 6], [4, 2, 1, 11, 10, 13, 7, 8, 15, 9, 12, 5, 6, 3, 0, 14], [11, 8, 12, 7, 1, 14, 2, 13, 6, 15, 0, 9, 10, 4, 5, 3]], [[12, 1, 10, 15, 9, 2, 6, 8, 0, 13, 3, 4, 14, 7, 5, 11], [10, 15, 4, 2, 7, 12, 9, 5, 6, 1, 13, 14, 0, 11, 3, 8], [9, 14, 15, 5, 2, 8, 12, 3, 7, 0, 4, 10, 1, 13, 11, 6], [4, 3, 2, 12, 9, 5, 15, 10, 11, 14, 1, 7, 6, 0, 8, 13]], [[4, 11, 2, 14, 15, 0, 8, 13, 3, 12, 9, 7, 5, 10, 6, 1], [13, 0, 11, 7, 4, 9, 1, 10, 14, 3, 5, 12, 2, 15, 8, 6], [1, 4, 11, 13, 12, 3, 7, 14, 10, 15, 6, 8, 0, 5, 9, 2], [6, 11, 13, 8, 1, 4, 10, 7, 9, 5, 0, 15, 14, 2, 3, 12]], [[13, 2, 8, 4, 6, 15, 11, 1, 10, 9, 3, 14, 5, 0, 12, 7], [1, 15, 13, 8, 10, 3, 7, 4, 12, 5, 6, 11, 0, 14, 9, 2], [7, 11, 4, 1, 9, 12, 14, 2, 0, 6, 10, 13, 15, 3, 5, 8], [2, 1, 14, 7, 4, 10, 8, 13, 15, 12, 9, 0, 3, 5, 6, 11]]]

# Round-function output permutation P (32 -> 32 bits) and its inverse.
P = [16, 7, 20, 21, 29, 12, 28, 17, 1, 15, 23, 26, 5, 18, 31, 10, 2, 8, 24, 14, 32, 27, 3, 9, 19, 13, 30, 6, 22, 11, 4, 25]
P_inv = [P.index(i)+1 for i in xrange(1, 33)]

# Key schedule: PC-1 selects 56 of the 64 key bits (every 8th bit is dropped).
PC_1 = [57, 49, 41, 33, 25, 17, 9, 1, 58, 50, 42, 34, 26, 18, 10, 2, 59, 51, 43, 35, 27, 19, 11, 3, 60, 52, 44, 36, 63, 55, 47, 39, 31, 23, 15, 7, 62, 54, 46, 38, 30, 22, 14, 6, 61, 53, 45, 37, 29, 21, 13, 5, 28, 20, 12, 4]

# PC-2 compresses the rotated 56-bit state to a 48-bit round subkey.
PC_2 = [14, 17, 11, 24, 1, 5, 3, 28, 15, 6, 21, 10, 23, 19, 12, 4, 26, 8, 16, 7, 27, 20, 13, 2, 41, 52, 31, 37, 47, 55, 30, 40, 51, 45, 33, 48, 44, 49, 39, 56, 34, 53, 46, 42, 50, 36, 29, 32]

# Per-round left-rotation amounts; only the first ROUNDS entries are used.
R = [1,1,2,2,2,2,2,2,1,2,2,2,2,2,2,1]
|
|
|
def chr_to_bits(c):
    """Return the binary digits of ord(c), MSB first, zero-padded to 8."""
    return [int(digit) for digit in format(ord(c), "08b")]
|
|
|
def str_to_bits(s):
    """Flatten a byte string into one list of bits (8 per char, MSB first)."""
    return [bit for ch in s for bit in chr_to_bits(ch)]
|
|
|
def bits_to_chr(bits):
    """Read a bit sequence (MSB first) as an integer and return chr() of it."""
    digits = "".join(str(b) for b in bits)
    code = int(digits, 2)
    return chr(code)
|
|
|
def bits_to_str(bits):
    """Pack a flat bit list into a string, one char per 8-bit group (MSB first)."""
    return "".join(bits_to_chr(bits[off:off + 8])
                   for off in range(0, len(bits), 8))
|
|
|
def xor_bits(l,r):
    """Element-wise XOR of two bit lists, truncated to the shorter one."""
    out = []
    for x, y in zip(l, r):
        out.append(x ^ y)
    return out
|
|
|
def parity_mask(word, mask):
    """Return the parity (0/1) of the bits of `word` selected by `mask`."""
    acc = 0
    for w_bit, m_bit in zip(word, mask):
        acc ^= w_bit & m_bit
    # The low bit of an XOR accumulation equals the low bit of the sum.
    return acc & 1
|
|
|
def F(hblk, subkey):
    """DES round function.

    Expands the half block with E, XORs in the round subkey, substitutes
    each 6-bit group through its S-box (via the MY_SBOX_BITS lookup tables)
    and applies the P permutation.

    Args:
        hblk: indexable of at least 32 bits (0/1 ints), the right half block.
        subkey: 48 bits (0/1 ints), the round subkey.

    Returns:
        list of 32 bits.
    """
    mixed = xor_bits([hblk[x - 1] for x in E], subkey)
    sbox_out = []
    for i in range(0, len(mixed), 6):
        # `//` keeps the S-box index an int even under true division.
        sbox_out.extend(MY_SBOX_BITS[i // 6][tuple(mixed[i:i + 6])])
    return [sbox_out[x - 1] for x in P]
|
|
|
def encrypt_block(blk, subkeys):
    """Encrypt one 8-character block with ROUNDS Feistel rounds.

    subkeys[r] is the 48-bit subkey for round r (as produced by gen_subkey).
    Returns the 8-character ciphertext block.
    """
    assert len(blk)==8
    state = str_to_bits(blk)
    state = [state[pos - 1] for pos in IP]
    for rnd in range(ROUNDS):
        l_half = state[:32]
        r_half = state[32:]
        l_half = xor_bits(l_half, F(r_half, subkeys[rnd]))
        state = r_half + l_half
    # Undo the swap performed by the final round before the output permutation.
    state = l_half + r_half
    return bits_to_str([state[pos - 1] for pos in IP_1])
|
|
|
def gen_subkey(key):
    """Run the DES key schedule for ROUNDS rounds.

    key: 8-character byte string.
    Returns a list of ROUNDS subkeys, each a list of 48 bits: PC-1, then per
    round rotate both 28-bit halves left by R[round] and select with PC-2.
    """
    key_bits = str_to_bits(key)
    selected = [key_bits[pos - 1] for pos in PC_1]
    c_half, d_half = selected[:28], selected[28:]
    schedule = []
    for rnd in range(ROUNDS):
        shift = R[rnd]
        c_half = c_half[shift:] + c_half[:shift]
        d_half = d_half[shift:] + d_half[:shift]
        merged = c_half + d_half
        schedule.append([merged[pos - 1] for pos in PC_2])
    return schedule
|
|
|
def gen_subkey_dbg():
    """Run the key schedule over bit *positions* instead of bit values.

    Identical to gen_subkey, but starts from the indices 0..63. Entry
    [r][k] of the result is therefore the 0-based master-key bit that ends up
    in bit k of round r's subkey.
    """
    positions = range(64)
    selected = [positions[pos - 1] for pos in PC_1]
    c_half, d_half = selected[:28], selected[28:]
    schedule = []
    for rnd in range(ROUNDS):
        shift = R[rnd]
        c_half = c_half[shift:] + c_half[:shift]
        d_half = d_half[shift:] + d_half[:shift]
        merged = c_half + d_half
        schedule.append([merged[pos - 1] for pos in PC_2])
    return schedule
|
|
|
def encrypt(pt, key):
    """Encrypt `pt` (length a multiple of 8) block by block with the local cipher.

    Each 8-character block is encrypted independently with the same subkeys.
    """
    assert len(pt)%8==0
    schedule = gen_subkey(key)
    blocks = [encrypt_block(pt[pos:pos + 8], schedule)
              for pos in range(0, len(pt), 8)]
    return ''.join(blocks)
|
|
|
# Keep a handle on the local software implementation: `encrypt` is rebound
# further down to the pipe-backed oracle, while the final brute force still
# needs the local cipher to test candidate keys.
encrypt_local = encrypt
|
|
|
def decrypt_block(blk, subkeys):
    """Decrypt one 8-character block: the encryption network with subkeys reversed."""
    assert len(blk)==8
    state = str_to_bits(blk)
    state = [state[pos - 1] for pos in IP]
    for rnd in range(ROUNDS):
        l_half = state[:32]
        r_half = state[32:]
        l_half = xor_bits(l_half, F(r_half, subkeys[ROUNDS - 1 - rnd]))
        state = r_half + l_half
    # Undo the swap performed by the final round before the output permutation.
    state = l_half + r_half
    return bits_to_str([state[pos - 1] for pos in IP_1])
|
|
|
def decrypt(ct, key):
    """Decrypt `ct` (length a multiple of 8) block by block with the local cipher."""
    assert len(ct)%8==0
    schedule = gen_subkey(key)
    out = []
    for pos in range(0, len(ct), 8):
        out.append(decrypt_block(ct[pos:pos + 8], schedule))
    return ''.join(out)
|
|
|
from itertools import product |
|
|
|
# Precomputed S-box tables: MY_SBOX_BITS[si] maps every 6-bit input tuple to
# the 4 output bits (MSB first) of S-box si. The row is taken from the outer
# bits (0 and 5) and the column from the middle four, exactly as the
# commented-out arithmetic in F did. F and Fmy use these dictionaries.
# NOTE: this is top-level code, so si, s, new_s, num, bits, i, row, col,
# val and res remain as module globals afterwards.
MY_SBOX_BITS = []
for si, s in enumerate(SBOX):
    new_s = {}
    for num, bits in enumerate(product(range(2), repeat=6)):
        i = 0
        row = bits[i]*2+bits[i+5]
        col = bits[i+1]*8+bits[i+2]*4+bits[i+3]*2+bits[i+4]
        val = bin(SBOX[si][row][col])[2:]
        res = map(int, list(val.rjust(4,'0')))
        new_s[tuple(bits)] = res
    MY_SBOX_BITS.append(new_s)
|
|
|
# ROUNDS = 16 |
|
try: |
|
from Crypto.Cipher import DES |
|
c = DES.new("omgwtfxz") |
|
except: |
|
pass |
|
else: |
|
ROUNDS = 16 |
|
assert encrypt("abcdfuck", "omgwtfxz").encode("hex") == c.encrypt("abcdfuck").encode("hex") |
|
print "IMPL OK!" |
|
# quit() |
|
ROUNDS = 8 |
|
ROUNDS = 8 |
|
|
|
from random import * |
|
def xor(a, b):
    """XOR string `a` with `b`, repeating `b` cyclically; result has len(a)."""
    period = len(b)
    return "".join(chr(ord(ch) ^ ord(b[k % period])) for k, ch in enumerate(a))
|
|
|
|
|
|
|
|
|
def Fmy(hblk, sk, si):
    """Contribution of S-box `si` alone to the round function output.

    Takes the six E-expanded bits feeding S-box si from `hblk`, XORs them with
    the 6-bit key guess `sk`, and looks up the S-box. The 4 output bits are
    placed at positions si*4..si*4+3 of an all-zero 32-bit word, which is then
    passed through P. Returns a list of 32 bits.
    """
    first = si * 6
    sbox_in = xor_bits([hblk[pos - 1] for pos in E[first:first + 6]], sk)
    nibble = MY_SBOX_BITS[si][tuple(sbox_in)]
    word = [0] * 32
    word[si * 4:si * 4 + 4] = nibble
    return [word[pos - 1] for pos in P]
|
|
|
def str_perm(s, p):
    """Apply bit permutation `p` (1-based source positions) to byte string `s`."""
    src = str_to_bits(s)
    return bits_to_str([src[pos - 1] for pos in p])
|
|
|
# d = str_perm("\x00\x80\x82\x00", E) |
|
# d = str_perm("\x00\x80\x82\x00", E) |
|
# d = str_perm("\x00\x00\x02\x02", E) |
|
# d = str_perm("\x00\x00\x80\x00", E) |
|
# d = str_perm("\x60\x00\x00\x00", E) |
|
# d = str_to_bits(d) |
|
# for i in xrange(0, len(d), 6): |
|
# print d[i:i+6] |
|
# quit() |
|
# d = [0] * 32 |
|
# d[5*4:5*4+4] = [1] * 4 |
|
# d = [d[x-1] for x in P] |
|
# print bits_to_str(d).encode("hex") |
|
# d = [0] * 32 |
|
# d[7*4:7*4+4] = [1] * 4 |
|
# d = [d[x-1] for x in P] |
|
# print bits_to_str(d).encode("hex") |
|
# quit() |
|
# 10202008 |
|
# 08020820 |
|
|
|
# Choose the key whose local subkeys (SK) serve as reference/debug values.
if LOCAL:
    # Fixed test key with the low bit of every byte cleared.
    KEY = "".join(chr(ord(c) & 0xfe) for c in "\x97\xba\x95\xbd\x1b\xc5\xac\xd1")
else:
    KEY = "omgwtfxz"
SK = gen_subkey(KEY)
|
|
|
# pipes...
# Channel to the encryption oracle: requests are written to "p1" and
# ciphertext is read back from "p2" (see the pipe-backed encrypt below).
# NOTE(review): presumably named FIFOs served by a helper process; if so,
# the open order must match the helper's or both sides block -- confirm.
f1 = open("p1", "wb")
f2 = open("p2", "rb")

# f = Sock("111.186.56.54 10001")
# f.send_line("C(FL")
# Maximum number of plaintext bytes sent per oracle request.
BLOCK = 2**16
# BLOCK = (20000-8)/2
# while BLOCK % 8:
# BLOCK -= 1
|
from struct import pack |
|
|
|
def encrypt(pt, k): |
|
print "QUERY", len(pt) |
|
fullct = "" |
|
for i in xrange(0, len(pt), BLOCK): |
|
# f.read_until("plaintext(hex): ") |
|
# f.send_line(pt[i:i+BLOCK].encode("hex")) |
|
block = pt[i:i+BLOCK] |
|
print "query chunk", i, len(block) |
|
f1.write(pack("<Q", len(block)) + block) |
|
f1.flush() |
|
ct = b"" |
|
while len(ct) < len(block): |
|
ct += f2.read(len(block) - len(ct)) |
|
# ct = f.read_line().strip() |
|
# ct = ct.decode("hex") |
|
# assert len(pt[i:i+BLOCK]) == len(ct) |
|
fullct += ct |
|
print "GOT" |
|
return fullct |
|
|
|
# pt = "6f6d67777466787a".decode("hex") |
|
# ct = "d65100cf23f0106d".decode("hex") |
|
# key = "d6fafeba9e9cb2b2".decode("hex") |
|
# print encrypt(pt, key) == ct |
|
# print encrypt(pt, key).encode("hex") |
|
# quit |
|
|
|
|
|
# pt = "omgwtfxz" |
|
# # KEY = "\x00" * 8 |
|
# ROUNDS = 8 |
|
# ct = encrypt(pt, KEY) |
|
# print "pt = %s" % pt.encode("hex") |
|
# print "ct = %s" % ct.encode("hex") |
|
# print "key= %s" % KEY.encode("hex") |
|
# quit() |
|
|
|
# Filled in by stage 1 with (pt1, pt2, ct1, ct2) tuples.
pairs = []
# Plaintext difference expressed after the initial permutation. Stage 1 XORs
# it into random plaintexts and then applies IP_inv.
PT_DELTA_P = "0080820060000000".decode("hex")
# The same difference with IP_inv applied.
# NOTE(review): not referenced elsewhere in the visible code.
PT_DELTA_Pinv = str_perm(PT_DELTA_P, IP_inv)

# All 64 six-bit subkey-chunk guesses in counting order, so that index n is
# the tuple of n's bits (MSB first).
RANGE6 = list(product(range(2), repeat=6))

# 32-bit mask passed to parity_mask in stages 1 and 2.
LMASK = str_to_bits("21040080".decode("hex"))
|
|
|
def get_SK(i, j):
    """Return the 6-bit chunk of round i's subkey (from SK) for S-box j (0-based)."""
    offset = 6 * j
    return SK[i][offset:offset + 6]
|
def frombin(v):
    """Interpret a sequence of 0/1 values (MSB first) as an unsigned integer."""
    digits = "".join([str(bit) for bit in v])
    return int(digits, 2)
|
|
|
|
|
# SUBKEY_POS[r][k]: 0-based master-key bit that becomes bit k of round r's
# subkey. setkey() uses it to write recovered chunks back into the key.
SUBKEY_POS = gen_subkey_dbg()

# Reference subkey chunks computed from the locally known KEY (via SK). The
# disabled-stage `else:` branches use them as stand-ins, and they are printed
# for comparison with what the attack recovers. SKr_b is round r, S-box b
# (1-based); *_int is the same chunk as an integer 0..63.
SK0_1 = get_SK(0, 0)
SK0_1_int = ord(bits_to_chr(SK0_1))
SK7_1 = get_SK(7, 0)
SK7_1_int = ord(bits_to_chr(SK7_1))
SK7_5 = get_SK(7, 4)
SK7_5_int = ord(bits_to_chr(SK7_5))

SK0_2 = get_SK(0, 2-1)
SK0_2_int = ord(bits_to_chr(SK0_2))
SK0_3 = get_SK(0, 3-1)
SK0_3_int = ord(bits_to_chr(SK0_3))
SK0_4 = get_SK(0, 4-1)
SK0_4_int = ord(bits_to_chr(SK0_4))
SK0_5 = get_SK(0, 5-1)
SK0_5_int = ord(bits_to_chr(SK0_5))
SK0_6 = get_SK(0, 6-1)
SK0_6_int = ord(bits_to_chr(SK0_6))
SK0_8 = get_SK(0, 8-1)
SK0_8_int = ord(bits_to_chr(SK0_8))

print "sk0_1:", SK0_1_int, SK0_1
print "sk7_1:", SK7_1_int, SK7_1
print "sk7_5:", SK7_5_int, SK7_5
print "sk0_2:", SK0_2_int, SK0_2
print "sk0_3:", SK0_3_int, SK0_3
print "sk0_4:", SK0_4_int, SK0_4
print "sk0_5:", SK0_5_int, SK0_5
print "sk0_6:", SK0_6_int, SK0_6
print "sk0_8:", SK0_8_int, SK0_8
print
|
|
|
# Disabled exploration helper (set to `if 1:` to run). It pushes a fixed
# 32-bit input difference through F with a random key and prints the output
# difference. It then E-expands that difference and lists, per S-box, the
# resulting 6-bit input difference ("*" = active). Finally it prints a
# "submask" covering the P-permuted output nibbles of the active S-boxes,
# excluding boxes 1, 6 and 8.
if 0:
    diff = str_to_bits("\x40\x00\x00\x00")
    k = [randint(0, 1) for _ in xrange(48)]
    bits1 = [randint(0, 1) for _ in xrange(32)]
    bits2 = xor_bits(bits1, diff)
    out1 = F(bits1, k)
    out2 = F(bits2, k)
    delta = xor_bits(out1, out2)
    print bits_to_str(delta).encode("hex")

    delta = [delta[x-1] for x in E]
    bits = [0] * 32
    for i in xrange(0, len(delta), 6):
        print i/6+1, delta[i:i+6],
        if sum(delta[i:i+6]):
            print "*"
        else:
            print
        if sum(delta[i:i+6]) and i/6+1 not in (1,6,8):
            print i/6+1
            bits[i/6*4:i/6*4+4] = [1] * 4
    bits = [bits[x-1] for x in P]
    print "submask: 0x%s" % bits_to_str(bits).encode("hex")
    # quit()
|
|
|
# submasks = [] |
|
# i = mask = 0x18222828 |
|
# while i >= 0: |
|
# submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32) |
|
# # print "%08x" % i |
|
# if i == 0: |
|
# break |
|
# i = (i - 1) & mask |
|
#00800202 |
|
#00008202 |
|
#00808000 |
|
#00800002 |
|
#00008002 |
|
#00008200 |
|
|
|
# print bits_to_str(xor_bits(out1, out2)).encode("hex") |
|
# quit() |
|
|
|
|
|
|
|
# for i, sk in enumerate(SUBKEY_POS): |
|
# print i, sk, len(sk) |
|
|
|
# Stage 1: recover the last-round (index 7) subkey chunk for S-box 1, then
# rank guesses for the first-round S-box 1 chunk. Disabled -> use the chunks
# derived from the local KEY.
if FLAGS[0]:
    print("STAGE 1")
    ROUNDS = 8

    N = 2**14

    # N random plaintext pairs with post-IP difference PT_DELTA_P. IP_inv is
    # applied so that after the cipher's IP the state equals the chosen
    # value (assumes the oracle applies IP like encrypt_block).
    pts = []
    for i in xrange(N):
        pt1 = "".join(chr(randint(0, 255)) for _ in xrange(8))
        pt2 = xor(pt1, PT_DELTA_P)

        pt1 = str_perm(pt1, IP_inv)
        pt2 = str_perm(pt2, IP_inv)
        pts.append(pt1)
        pts.append(pt2)

    # One oracle round-trip for all plaintexts, then split into 8-byte blocks.
    cts = []
    ct = encrypt("".join(pts), KEY)
    for i in xrange(0, len(ct), 8):
        cts.append(ct[i:i+8])

    pairs = zip(pts[::2], pts[1::2], cts[::2], cts[1::2])

    # For each last-round S-box 1 key guess, count pairs where the XOR over
    # both ciphertexts of (LMASK parity of right ^ left[16] ^ bit 16 of the
    # partial last-round F) is 0. Presumably a differential-linear
    # distinguisher; the correct guess should stand out.
    counters = [0] * 2**6
    for i, (pt1, pt2, ct1, ct2) in enumerate(pairs):
        if i % 2**12 == 0:
            print hex(i)

        ps = [0] * 2**6
        for ct in (ct1, ct2):
            # Undo the final permutation to get the last-round halves.
            bits = str_to_bits(ct)
            bits = [bits[x-1] for x in IP_1_inv]
            left = bits[:32]
            right = bits[32:]

            parity = parity_mask(LMASK, right) ^ left[16]
            for skn, sk in enumerate(RANGE6):
                func = Fmy(right, sk, si=0)
                ps[skn] ^= func[16] ^ parity

        for j in xrange(64):
            counters[j] += 1 - ps[j]

    best = sorted(enumerate(counters), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    cands = []
    for a, b in best[-10:]:
        print a, b
        if b == best_cnt:
            cands.append(a)
    print "cands", cands
    assert len(cands) == 1
    recovered_sk_7_1 = cands[0]

    print "[+]", "recovered SK_7_1", recovered_sk_7_1
    print

    # With SK_7_1 fixed, score each first-round S-box 1 guess. Only pairs
    # whose first-round S-box 1 output difference (under that guess) equals
    # 0x00808200 are used. Ratio = fraction of such pairs with combined
    # parity 0.
    counters = [0] * 64
    total = [0] * 64
    for i, (pt1, pt2, ct1, ct2) in enumerate(pairs):
        parity = 0
        ps = 0
        for ct in (ct1, ct2):
            bits = str_to_bits(ct)
            bits = [bits[x-1] for x in IP_1_inv]
            left = bits[:32]
            right = bits[32:]

            parity = parity_mask(LMASK, right) ^ left[16]
            func = Fmy(right, RANGE6[recovered_sk_7_1], si=0)
            ps ^= func[16] ^ parity

        # Right halves of the original (post-IP) plaintexts.
        right1 = str_to_bits(str_perm(pt1, IP))[32:]
        right2 = str_to_bits(str_perm(pt2, IP))[32:]

        for skn0 in xrange(64):
            func1 = Fmy(right1, RANGE6[skn0], si=0)
            func2 = Fmy(right2, RANGE6[skn0], si=0)
            delta = xor_bits(func1, func2)
            if frombin(delta) == 0x00808200:
                counters[skn0] += 1 - ps
                total[skn0] += 1

    # NOTE(review): raises ZeroDivisionError if some guess never passes the
    # filter above.
    for i in xrange(64):
        counters[i] = float(counters[i])/total[i]

    best = sorted(enumerate(counters), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    cands = []
    for a, b in best[-10:]:
        print a, b
        if b == best_cnt:
            cands.append(a)
    print "cands", cands
    # Every top-scoring guess is kept; stage 5 tries each one.
    cands_sk_0_1 = cands

    print "[+] cands SK_0_1", cands_sk_0_1
    print
else:
    recovered_sk_7_1 = SK7_1_int
    cands_sk_0_1 = [SK0_1_int]
    print "DEBUG 1"
|
|
|
|
|
# Enumerate every submask of `mask` (standard (i - 1) & mask walk, from mask
# down to 0 inclusive). Each one becomes a 64-bit XOR pattern touching only
# the left half. Stage 2 uses them to build its plaintext structures.
# NOTE: top-level code, so `i` and `mask` remain as module globals.
submasks = []
i = mask = 0x18222828
while i >= 0:
    submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32)
    # print "%08x" % i
    if i == 0:
        break
    i = (i - 1) & mask
|
|
|
|
|
if FLAGS[1]: |
|
ROUNDS = 8 |
|
print("STAGE 2") |
|
|
|
N = 2**7 |
|
# N = 2**3 |
|
DELTA = str_to_bits("\x40\x00\x00\x00\x00\x00\x02\x02") |
|
|
|
pts = [] |
|
for i in xrange(N): |
|
if i % 2**3 == 0 and i: |
|
print hex(i) |
|
|
|
base = "".join(chr(randint(0, 255)) for _ in xrange(8)) |
|
base = str_to_bits(base) |
|
pts_cur = [] |
|
for sub in submasks: |
|
pts_cur.append(xor_bits(base, sub)) |
|
for pt in pts_cur[::]: |
|
pts_cur.append(xor_bits(pt, DELTA)) |
|
pts += map(tuple, pts_cur) |
|
|
|
|
|
cts = [] |
|
ctsx = [] |
|
ptsx = [] |
|
for pt in pts: |
|
pt = [pt[x-1] for x in IP_inv] |
|
pt = bits_to_str(pt) |
|
ptsx.append(pt) |
|
|
|
ptsx = "".join(ptsx) |
|
newct = encrypt(ptsx, KEY) |
|
assert len(ptsx) == len(newct) |
|
for i in xrange(0, len(newct), 8): |
|
# ct = encrypt_block(pt, SK) |
|
ct = newct[i:i+8] |
|
bits = str_to_bits(ct) |
|
bits = [bits[x-1] for x in IP_1_inv] |
|
left = bits[:32] |
|
right = bits[32:] |
|
|
|
res = parity_mask(LMASK, left) ^ right[16] |
|
cts.append(res) |
|
ctsx.append(bits) |
|
|
|
print len(pts), len(cts) |
|
|
|
lefts = [] |
|
cache6 = {} |
|
cache8 = {} |
|
for pt in pts: |
|
left, right = pt[:32], pt[32:] |
|
fl = frombin(left) |
|
lefts.append(fl) |
|
for pt in pts[::256]: |
|
left, right = pt[:32], pt[32:] |
|
for skn, sk in enumerate(RANGE6): |
|
blob1 = Fmy(right, sk, si=6-1) |
|
blob2 = Fmy(right, sk, si=8-1) |
|
cache6[pt,skn] = frombin(blob1) |
|
cache8[pt,skn] = frombin(blob2) |
|
|
|
counters = defaultdict(int) |
|
total = defaultdict(int) |
|
for ichunk in xrange(N): |
|
print "chunkA", ichunk+1, "/", N |
|
chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts) |
|
|
|
for skn6 in xrange(64): |
|
for skn8 in xrange(64): |
|
by_left = {} |
|
for i, (pt, ct, left) in enumerate(chunk): |
|
if i == 0 or i == 256: |
|
blob1 = cache6[pt, skn6] |
|
blob2 = cache8[pt, skn8] |
|
func = blob1 ^ blob2 ^ left |
|
if i < 256: |
|
assert func not in by_left |
|
by_left[func] = i |
|
else: |
|
func ^= 0x40000000 |
|
assert func in by_left |
|
j = by_left[func] |
|
counters[skn6,skn8] += 1 - (chunk[i][1] ^ chunk[j][1]) |
|
total[skn6,skn8] += 1 |
|
|
|
best = sorted(counters.items(), key=lambda (a, b): b) |
|
best_cnt = best[-1][1] |
|
print "best", best_cnt |
|
skn68_cands = [] |
|
for a, b in best[-10:]: |
|
print a, b |
|
if b == best_cnt: |
|
skn68_cands.append(a) |
|
print "CANDS", skn68_cands |
|
cands_sk_0_68 = skn68_cands |
|
print "[+] candscands_sk_0_68", cands_sk_0_68 |
|
print |
|
|
|
# recover up |
|
for skn6, skn8 in skn68_cands[:1]: |
|
total = defaultdict(int) |
|
counters = defaultdict(int) |
|
for ichunk in xrange(N): |
|
# print "chunk", ichunk+1, "/", N |
|
chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts) |
|
for skn5 in xrange(64): |
|
by_left = {} |
|
for i, (pt, ct, left) in enumerate(chunk): |
|
if i == 0 or i == 256: |
|
blob1 = cache6[pt, skn6] |
|
blob2 = cache8[pt, skn8] |
|
func = blob1 ^ blob2 ^ left |
|
if i < 256: |
|
assert func not in by_left |
|
by_left[func] = i |
|
else: |
|
func ^= 0x40000000 |
|
assert func in by_left |
|
j = by_left[func] |
|
|
|
parity = 0 |
|
ct1 = ctsx[ichunk*512+i] |
|
ct2 = ctsx[ichunk*512+j] |
|
for ct in (ct1, ct2): |
|
left, right = ct[:32], ct[32:] |
|
func = Fmy(right, RANGE6[skn5], 4) |
|
left = xor_bits(left, func) |
|
parity ^= parity_mask(left, LMASK) |
|
counters[skn5] += 1 ^ parity |
|
total[skn5] += 1 |
|
|
|
best = sorted(counters.items(), key=lambda (a, b): b) |
|
best_cnt = best[-1][1] |
|
print "best", best_cnt |
|
cands = [] |
|
for a, b in best[-10:]: |
|
print a, b |
|
if b == best_cnt: |
|
cands.append(a) |
|
print "cands", cands |
|
assert len(cands) == 1 |
|
recovered_sk_7_5 = cands[0] |
|
print "[+]", "recovered SK_7_5", recovered_sk_7_5 |
|
print |
|
else: |
|
recovered_sk_7_5 = SK7_5_int |
|
cands_sk_0_68 = [(SK0_6_int, SK0_8_int)] |
|
print "DEBUG 2" |
|
|
|
if FLAGS[2]: |
|
''' |
|
00800002 |
|
1 [0, 0, 0, 0, 0, 0] |
|
2 [0, 0, 0, 0, 0, 1] * |
|
2 |
|
3 [0, 1, 0, 0, 0, 0] * |
|
3 |
|
4 [0, 0, 0, 0, 0, 0] |
|
5 [0, 0, 0, 0, 0, 0] |
|
6 [0, 0, 0, 0, 0, 0] |
|
7 [0, 0, 0, 0, 0, 0] |
|
8 [0, 0, 0, 1, 0, 0] * |
|
submask: 0x44094114 |
|
''' |
|
print("STAGE 3") |
|
|
|
submasks = [] |
|
i = mask = 0x44094114 |
|
while i >= 0: |
|
submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32) |
|
if i == 0: |
|
break |
|
i = (i - 1) & mask |
|
|
|
ROUNDS = 8 |
|
|
|
N = 2**7 |
|
DELTA_L = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] |
|
DELTA = str_to_bits("\x40\x00\x00\x00" + "\x00\x80\x00\x02") |
|
DELTA = xor_bits(DELTA[:32], DELTA_L) + DELTA[32:] |
|
|
|
DELTA_L_int = int(bits_to_str(DELTA[:32]).encode("hex"), 16) |
|
print "DELTA", bits_to_str(DELTA).encode("hex") |
|
print "DELTA_L_int %08x" % DELTA_L_int |
|
pts = [] |
|
for i in xrange(N): |
|
# print("prep", i, "/", N) |
|
if i % 2**3 == 0 and i: |
|
print hex(i) |
|
|
|
while True: |
|
base = [randint(0, 1) for _ in xrange(64)] |
|
for skn6, skn8 in cands_sk_0_68: |
|
test1 = Fmy(base[32:], RANGE6[skn8], si=8-1) |
|
test2 = Fmy(xor_bits(base[32:], DELTA[32:]), RANGE6[skn8], si=8-1) |
|
difftest = xor_bits(test1, test2) |
|
# difftest = int(bits_to_str(difftest).encode("hex"), 16) |
|
# print "0x%08x" % difftest, |
|
# difftest &= 0xffffffff ^ mask |
|
# if 1: |
|
# print "0x%08x," % difftest, |
|
# print str_to_bits(("%08x" % difftest).decode("hex")) |
|
if difftest != DELTA_L: |
|
break |
|
else: |
|
break |
|
|
|
pts_cur = [] |
|
for sub in submasks: |
|
pts_cur.append(xor_bits(base, sub)) |
|
for pt in pts_cur[::]: |
|
pts_cur.append(xor_bits(pt, DELTA)) |
|
pts += map(tuple, pts_cur) |
|
|
|
# pt1 = pts_cur[0+100] |
|
# pt2 = pts_cur[256+100] |
|
# test1 = F(pt1[32:], SK[0]) |
|
# test2 = F(pt2[32:], SK[0]) |
|
# difftest = xor_bits(test1, test2) |
|
# difftest = difftest = int(bits_to_str(difftest).encode("hex"), 16) |
|
# print "diff %08x" % difftest, "by mask %08x" % (difftest & mask), "rest %08x" % (difftest & (0xffffffff ^ mask)) |
|
|
|
cts = [] |
|
ctsx = [] |
|
ptsx = [] |
|
for pt in pts: |
|
pt = [pt[x-1] for x in IP_inv] |
|
pt = bits_to_str(pt) |
|
ptsx.append(pt) |
|
|
|
ptsx = "".join(ptsx) |
|
newct = encrypt(ptsx, KEY) |
|
assert len(ptsx) == len(newct) |
|
for i in xrange(0, len(newct), 8): |
|
ct = newct[i:i+8] |
|
bits = str_to_bits(ct) |
|
bits = [bits[x-1] for x in IP_1_inv] |
|
left = bits[:32] |
|
right = bits[32:] |
|
|
|
res = parity_mask(LMASK, left) ^ right[16] |
|
cts.append(res) |
|
ctsx.append(bits) |
|
|
|
print len(pts), len(cts) |
|
|
|
# 6,8 -> 2,3 |
|
lefts = [] |
|
cache6 = {} |
|
cache8 = {} |
|
for pt in pts: |
|
left, right = pt[:32], pt[32:] |
|
fl = frombin(left) |
|
lefts.append(fl) |
|
for pt in pts[::256]: |
|
left, right = pt[:32], pt[32:] |
|
for skn, sk in enumerate(RANGE6): |
|
blob1 = Fmy(right, sk, si=2-1) |
|
blob2 = Fmy(right, sk, si=3-1) |
|
cache6[pt,skn] = frombin(blob1) |
|
cache8[pt,skn] = frombin(blob2) |
|
|
|
counters = defaultdict(int) |
|
total = defaultdict(int) |
|
for ichunk in xrange(N): |
|
print "chunkB", ichunk+1, "/", N |
|
chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts) |
|
|
|
for skn6 in xrange(64): |
|
for skn8 in xrange(64): |
|
by_left = {} |
|
for i, (pt, ct, left) in enumerate(chunk): |
|
if i == 0 or i == 256: |
|
blob1 = cache6[pt, skn6] |
|
blob2 = cache8[pt, skn8] |
|
func = blob1 ^ blob2 ^ left |
|
if i < 256: |
|
assert func not in by_left |
|
by_left[func] = i |
|
else: |
|
func ^= DELTA_L_int |
|
assert func in by_left |
|
j = by_left[func] |
|
counters[skn6,skn8] += 1 - (chunk[i][1] ^ chunk[j][1]) |
|
total[skn6,skn8] += 1 |
|
|
|
best = sorted(counters.items(), key=lambda (a, b): b) |
|
best_cnt = best[-1][1] |
|
print "best", best_cnt |
|
skn68_cands = [] |
|
for a, b in best[-20:]: |
|
print a, b |
|
if b == best_cnt: |
|
skn68_cands.append(a) |
|
print "CANDS", skn68_cands |
|
cands_sk_0_23 = skn68_cands |
|
print "[+] candscands_sk_0_23", cands_sk_0_23 |
|
print |
|
|
|
print |
|
else: |
|
cands_sk_0_23 = [(SK0_2_int, SK0_3_int)] |
|
print "DEBUG3" |
|
|
|
|
|
if FLAGS[3]: |
|
''' |
|
00008200 |
|
1 [0, 0, 0, 0, 0, 0] |
|
2 [0, 0, 0, 0, 0, 0] |
|
3 [0, 0, 0, 0, 0, 0] |
|
4 [0, 0, 0, 0, 0, 1] * |
|
4 |
|
5 [0, 1, 0, 0, 0, 0] * |
|
5 |
|
6 [0, 0, 0, 1, 0, 0] * |
|
7 [0, 0, 0, 0, 0, 0] |
|
8 [0, 0, 0, 0, 0, 0] |
|
submask: 0xa14410c0 |
|
''' |
|
print("STAGE 4") |
|
|
|
submasks = [] |
|
i = mask = 0xa14410c0 |
|
while i >= 0: |
|
submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32) |
|
if i == 0: |
|
break |
|
i = (i - 1) & mask |
|
|
|
ROUNDS = 8 |
|
|
|
N = 2**7 |
|
DELTA_L = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] |
|
DELTA = str_to_bits("\x40\x00\x00\x00" + "\x00\x00\x82\x00") |
|
DELTA = xor_bits(DELTA[:32], DELTA_L) + DELTA[32:] |
|
|
|
DELTA_L_int = int(bits_to_str(DELTA[:32]).encode("hex"), 16) |
|
print "DELTA", bits_to_str(DELTA).encode("hex") |
|
print "DELTA_L_int %08x" % DELTA_L_int |
|
pts = [] |
|
for i in xrange(N): |
|
if i % 2**3 == 0 and i: |
|
print hex(i) |
|
|
|
while True: |
|
base = [randint(0, 1) for _ in xrange(64)] |
|
for skn6, skn8 in cands_sk_0_68: |
|
test1 = Fmy(base[32:], RANGE6[skn6], si=6-1) |
|
test2 = Fmy(xor_bits(base[32:], DELTA[32:]), RANGE6[skn6], si=6-1) |
|
difftest = xor_bits(test1, test2) |
|
# difftest = int(bits_to_str(difftest).encode("hex"), 16) |
|
# print "0x%08x" % difftest, |
|
# difftest &= 0xffffffff ^ mask |
|
# if 1: |
|
# print "0x%08x," % difftest, |
|
# print str_to_bits(("%08x" % difftest).decode("hex")) |
|
# print difftest |
|
if difftest != DELTA_L: |
|
break |
|
else: |
|
break |
|
|
|
pts_cur = [] |
|
for sub in submasks: |
|
pts_cur.append(xor_bits(base, sub)) |
|
for pt in pts_cur[::]: |
|
pts_cur.append(xor_bits(pt, DELTA)) |
|
pts += map(tuple, pts_cur) |
|
|
|
# pt1 = pts_cur[0+100] |
|
# pt2 = pts_cur[256+100] |
|
# test1 = F(pt1[32:], SK[0]) |
|
# test2 = F(pt2[32:], SK[0]) |
|
# difftest = xor_bits(test1, test2) |
|
# difftest = difftest = int(bits_to_str(difftest).encode("hex"), 16) |
|
# print "diff %08x" % difftest, "by mask %08x" % (difftest & mask), "rest %08x" % (difftest & (0xffffffff ^ mask)) |
|
|
|
cts = [] |
|
ctsx = [] |
|
ptsx = [] |
|
for pt in pts: |
|
pt = [pt[x-1] for x in IP_inv] |
|
pt = bits_to_str(pt) |
|
ptsx.append(pt) |
|
|
|
ptsx = "".join(ptsx) |
|
newct = encrypt(ptsx, KEY) |
|
assert len(ptsx) == len(newct) |
|
for i in xrange(0, len(newct), 8): |
|
ct = newct[i:i+8] |
|
bits = str_to_bits(ct) |
|
bits = [bits[x-1] for x in IP_1_inv] |
|
left = bits[:32] |
|
right = bits[32:] |
|
|
|
res = parity_mask(LMASK, left) ^ right[16] |
|
cts.append(res) |
|
ctsx.append(bits) |
|
|
|
print len(pts), len(cts) |
|
|
|
# 6,8 -> 4,5 |
|
lefts = [] |
|
cache6 = {} |
|
cache8 = {} |
|
for pt in pts: |
|
left, right = pt[:32], pt[32:] |
|
fl = frombin(left) |
|
lefts.append(fl) |
|
for pt in pts[::256]: |
|
left, right = pt[:32], pt[32:] |
|
for skn, sk in enumerate(RANGE6): |
|
blob1 = Fmy(right, sk, si=4-1) |
|
blob2 = Fmy(right, sk, si=5-1) |
|
cache6[pt,skn] = frombin(blob1) |
|
cache8[pt,skn] = frombin(blob2) |
|
|
|
counters = defaultdict(int) |
|
total = defaultdict(int) |
|
for ichunk in xrange(N): |
|
print "chunkC", ichunk+1, "/", N |
|
chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts) |
|
|
|
for skn6 in xrange(64): |
|
for skn8 in xrange(64): |
|
by_left = {} |
|
for i, (pt, ct, left) in enumerate(chunk): |
|
if i == 0 or i == 256: |
|
blob1 = cache6[pt, skn6] |
|
blob2 = cache8[pt, skn8] |
|
func = blob1 ^ blob2 ^ left |
|
if i < 256: |
|
assert func not in by_left |
|
by_left[func] = i |
|
else: |
|
func ^= DELTA_L_int |
|
assert func in by_left |
|
j = by_left[func] |
|
counters[skn6,skn8] += 1 - (chunk[i][1] ^ chunk[j][1]) |
|
total[skn6,skn8] += 1 |
|
|
|
best = sorted(counters.items(), key=lambda (a, b): b) |
|
best_cnt = best[-1][1] |
|
print "best", best_cnt |
|
skn68_cands = [] |
|
for a, b in best[-20:]: |
|
print a, b |
|
if b == best_cnt: |
|
skn68_cands.append(a) |
|
print "CANDS", skn68_cands |
|
cands_sk_0_45 = skn68_cands |
|
print "[+] candscands_sk_0_45", cands_sk_0_45 |
|
print |
|
|
|
print |
|
else: |
|
cands_sk_0_45 = [(SK0_4_int, SK0_5_int)] |
|
print "DEBUG3" |
|
|
|
|
|
|
|
# recovered_sk_7_1 |
|
# recovered_sk_7_5 |
|
# cands_sk_0_1 = cands |
|
# cands_sk_0_68 |
|
|
|
def setkey(i, j, val):
    """Write a recovered 6-bit subkey chunk back into MASTER_KEY.

    i: round index; j: 1-based S-box number; val: chunk as an int 0..63,
    MSB first. SUBKEY_POS supplies which master-key bit each subkey bit
    comes from.
    """
    global MASTER_KEY
    positions = SUBKEY_POS[i][(j - 1) * 6:j * 6]
    for shift, pos in zip(range(5, -1, -1), positions):
        MASTER_KEY[pos] = (val >> shift) & 1
|
|
|
print("STAGE 5 - final bruteforce") |
|
pt = "omgwtfxz" |
|
ct = encrypt(pt, KEY) |
|
print "pt = %s" % pt.encode("hex") |
|
print "ct = %s" % ct.encode("hex") |
|
|
|
print KEY.encode("hex") |
|
bits = str_to_bits(KEY) |
|
key = frombin(bits) |
|
print "SECRET KEY %016x" % key |
|
for recovered_sk_0_1 in cands_sk_0_1: |
|
for recovered_sk_0_6, recovered_sk_0_8 in cands_sk_0_68: |
|
for recovered_sk_0_2, recovered_sk_0_3 in cands_sk_0_23: |
|
for recovered_sk_0_4, recovered_sk_0_5 in cands_sk_0_45: |
|
MASTER_KEY = [None] * 64 |
|
for i in xrange(7, 64, 8): |
|
MASTER_KEY[i] = 0 |
|
|
|
setkey(0, 1, recovered_sk_0_1) |
|
setkey(0, 2, recovered_sk_0_2) |
|
setkey(0, 3, recovered_sk_0_3) |
|
setkey(0, 4, recovered_sk_0_4) |
|
setkey(0, 5, recovered_sk_0_5) |
|
setkey(0, 6, recovered_sk_0_6) |
|
setkey(0, 8, recovered_sk_0_8) |
|
setkey(7, 1, recovered_sk_7_1) |
|
setkey(7, 5, recovered_sk_7_5) |
|
|
|
# print Counter(MASTER_KEY) |
|
mask_unknown = frombin([int(v == None) for v in MASTER_KEY]) |
|
mask_known = frombin([int(v != None) for v in MASTER_KEY]) |
|
val_known = frombin([int(v) if v != None else 0 for v in MASTER_KEY ]) |
|
print Counter(MASTER_KEY) |
|
print "known = %016x & %016x" % (val_known, mask_known) |
|
print "unknown = %016x" % mask_unknown |
|
print "secret check %016x (if local)" % (key & mask_known), (key & mask_known) == val_known |
|
|
|
print "brute..." |
|
test_key = mask_unknown |
|
while test_key > 0: |
|
kkkey = test_key | val_known |
|
kkkey = ("%016x" % kkkey).decode("hex") |
|
test_ct = encrypt_local(pt, kkkey) |
|
if test_ct == ct: |
|
print "PWND", kkkey.encode("hex") |
|
print "pt", pt.encode("hex") |
|
print "ct", ct.encode("hex") |
|
print "test_ct", test_ct.encode("hex") |
|
f1.write("flagflag" + kkkey) |
|
f1.flush() |
|
quit() |
|
test_key = (test_key - 1) & mask_unknown |
|
print |
|
|
|
print "no luck" |
|
# print `f.buf` |
|
# f.interact() |
|
|
|
#flag{but_th3_litt1e_sticky_leave5_and_tHe_pRec1ous_t0mbs_and_the_b1Ue_sky_ANd_The_woman_you_loVe_How_will_you_lIVe_h0w_wi1l_yoU_love_them_wiTh_5uch_a_h3ll_in_YouR_h34rt_and_y0ur_he4d__How__Can__You__} |
|
# 5min locally |