Skip to content

Instantly share code, notes, and snippets.

@hellman
Last active December 13, 2020 17:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hellman/c0245e67abd9b70dc652b813ddb45db0 to your computer and use it in GitHub Desktop.
Save hellman/c0245e67abd9b70dc652b813ddb45db0 to your computer and use it in GitHub Desktop.
ASIS CTF 2020 Finals - Trio Color (3DES)

ASIS CTF 2020 Finals - TriColor (3DES)

The challenge leaks the intermediate value after the first two DES encryptions (2DES), so we can run a meet-in-the-middle (MitM) attack on the first two encryptions (each has a 32-bit key), and then run a differential-linear attack on the third DES (8 rounds).

The current configuration requires 16 GiB of working RAM, but it can be modified to use less. The MitM stage2 takes at most ~25 minutes on my laptop (1 core). The differential-linear attack takes ~15 minutes, including communication.

How to prepare:

$ make compile  # generate tables and compile C++ MitM part
$ make precomp  # ~1 hour to precompute 65 GiB data for MitM; fast SSD is recommended ;)
$ make prepare  # create some pipes

Now, run the attack:

$ python3 solve.py
...
please run:
stage2.cpp 0fd7fd5ffdf4 1faa03c5a9fc
input?

Copy the printed arguments and run stage2. A fast SSD is recommended ;)

$ ./stage2 0fd7fd5ffdf4 1faa03c5a9fc
sieving from dump3
++++++++++++++FOUND fa49aaf7 baac4770
+++++FOUND f3bd18f0 d8ecd729

It may print several key candidates; paste each one into the solve.py prompt, which checks it and rejects false positives:

input?
fa49aaf7 baac4770          
enc 0 -> ac57dedfd797301d vs ac57dedfd797301d
enc 1 -> 30a403dad1ce199b vs dda7cfceb1cf2ead
ouch
input?
f3bd18f0 d8ecd729
enc 0 -> ac57dedfd797301d vs ac57dedfd797301d
enc 1 -> dda7cfceb1cf2ead vs dda7cfceb1cf2ead
phew match
check? b'\x01\x02\x03\x04\x05\x06\x07\x08' b'\x01\x02\x03\x04\x05\x06\x07\x08' True
z b'\x99\x9fP\x89\x95/E\xac'

Once a candidate is accepted, run des.py in parallel to mount the differential-linear attack. This is old, uncleaned code from the zer0des task (0CTF Quals 2019)...

pypy2 des.py

Sometimes the attack fails; if so, run the whole thing again.

Wait for the flag in solve.py session.

GOT KEY b'<\x12Lx\xbc\x16N\xfc' 3c124c78bc164efc
...
| Send third round key in hex format.
| Congratz! You got the flag: ASIS{T0_ki55_iN_car5_4nd_d0wnTown_baRs___W4s_aLL_w3_need3d____Y0u_dr3w_s7ars_ar0Und_mY_sc4rs___Bu7_nOw_I_am_bl33ding!}
#include <bits/stdc++.h>
#include <unistd.h>
#include <string.h>
#include <assert.h>
using namespace std;
// some competitive prog. macros
// whole(x): expands to the begin/end iterator pair of container x.
#define whole(x) (x).begin(),(x).end()
// Token-pasting helpers that build a hidden per-line variable name (__tmpvar__<name><line>).
// NOTE(review): identifiers beginning with "__" are reserved for the implementation in C++;
// rename them if they ever collide with a library macro.
#define __CONCAT3_NX(x, y, z) x ## y ## z
#define __CONCAT3(x, y, z) __CONCAT3_NX(x, y, z)
#define __VAR(name) __CONCAT3(__tmpvar__, name, __LINE__)
// __typeof is a GNU extension giving the type of an expression.
#define __TYPE(x) __typeof(x)
// FOR(i, s, n): i runs s, s+1, ..., n-1; i has n's type and n is evaluated only once.
#define FOR(i, s, n) for (__TYPE(n) i=(s), __VAR(end)=(n); i < __VAR(end); i++)
// RFOR(i, s, n): i runs n-1 down to s inclusive.
// Caution: if n has an unsigned type and s == 0, `i >= 0` is always true (infinite loop).
#define RFOR(i, s, n) for (__TYPE(n) i=(n)-1, __VAR(end)=(s); i >= __VAR(end); i--)
// Zero-based shorthands.
#define FORN(i, n) FOR(i, 0, n)
#define RFORN(i, n) RFOR(i, 0, n)
#include "tab.h"
// Global 64-bit values; no function in this header reads or writes them.
// Presumably they are scratch state for the including programs (stage1_precomp / stage2) — TODO confirm.
uint64_t l1, r1, l2, r2;
inline uint64_t f(uint64_t x) {
uint64_t y = 0;
FORN(i, 8) {
y |= TABF[i][x & 0x3f];
x >>= 6;
}
assert(x == 0);
return y;
}
// Eight Feistel rounds in place on (l, r) using keys[0..7]; the halves are
// swapped at the end. Rounds alternate: even keys update l, odd keys update r.
void encrypt(uint64_t &l, uint64_t &r, uint64_t *keys) {
    for (int rnd = 0; rnd < 8; rnd += 2) {
        l ^= f(r ^ keys[rnd]);
        r ^= f(l ^ keys[rnd + 1]);
    }
    std::swap(l, r);
}
// Inverse of encrypt(): undo the final swap, then replay the eight rounds with
// keys[7] down to keys[0].
void decrypt(uint64_t &l, uint64_t &r, uint64_t *keys) {
    std::swap(l, r);
    for (int rnd = 6; rnd >= 0; rnd -= 2) {
        r ^= f(l ^ keys[rnd + 1]);
        l ^= f(r ^ keys[rnd]);
    }
}
inline void expand(uint32_t seed, uint64_t *curkeys) {
memcpy(curkeys, TABKS_ZERO, sizeof(TABKS_ZERO));
FORN(j, 4) {
int x = seed & 0xff;
seed >>= 8;
FORN(i, 8) {
curkeys[i] ^= TABKS[j][x][i];
}
}
}
// Encrypt (l, r) in place under the round keys derived from a 32-bit seed.
inline void enc(uint64_t &l, uint64_t &r, uint32_t seed) {
    uint64_t roundKeys[16] = {};
    expand(seed, roundKeys);
    encrypt(l, r, roundKeys);
}
// Decrypt (l, r) in place under the round keys derived from a 32-bit seed.
inline void dec(uint64_t &l, uint64_t &r, uint32_t seed) {
    uint64_t roundKeys[16] = {};
    expand(seed, roundKeys);
    decrypt(l, r, roundKeys);
}
// Parity (0 or 1) of the total number of set bits in l and r combined.
// parity(l) ^ parity(r) == parity(l ^ r), so a single builtin suffices.
inline int getflag(uint64_t l, uint64_t r) {
    return __builtin_parityll(l ^ r);
}
// Index (0-31) of the single set bit of x. Asserts that x is non-zero and is a
// power of two.
inline int bitpos(uint32_t x) {
    int res = __builtin_ffs(x) - 1;
    assert(res != -1);
    // Shift an unsigned one: `1 << 31` on a signed int overflows into the sign bit
    // (UB before C++14 / in C) and also forces a signed/unsigned comparison.
    assert(x == (1u << res));
    return res;
}
from collections import defaultdict, Counter
from subprocess import Popen, PIPE
# Run configuration.
# LOCAL selects the hard-coded test key (see the KEY setup below).
# FLAGS[k] enables attack stage k+1; a disabled stage falls back to subkey values
# derived from KEY, which is only meaningful when KEY is the real key.
LOCAL = False
# FLAGS = 0, 0, 0, 1
# LOCAL = False
FLAGS = 1, 1, 1, 1
# Number of Feistel rounds used by the local DES implementation (the attacked cipher is 8-round DES).
ROUNDS = 8
# Standard DES tables with 1-indexed bit positions; *_inv are the inverse permutations.
# Initial permutation and its inverse (final permutation).
IP = [58, 50, 42, 34, 26, 18, 10, 2, 60, 52, 44, 36, 28, 20, 12, 4, 62, 54, 46, 38, 30, 22, 14, 6, 64, 56, 48, 40, 32, 24, 16, 8, 57, 49, 41, 33, 25, 17, 9, 1, 59, 51, 43, 35, 27, 19, 11, 3, 61, 53, 45, 37, 29, 21, 13, 5, 63, 55, 47, 39, 31, 23, 15, 7]
IP_1 = [40, 8, 48, 16, 56, 24, 64, 32, 39, 7, 47, 15, 55, 23, 63, 31, 38, 6, 46, 14, 54, 22, 62, 30, 37, 5, 45, 13, 53, 21, 61, 29, 36, 4, 44, 12, 52, 20, 60, 28, 35, 3, 43, 11, 51, 19, 59, 27, 34, 2, 42, 10, 50, 18, 58, 26, 33, 1, 41, 9, 49, 17, 57, 25]
IP_inv = [IP.index(i)+1 for i in xrange(1, 65)]
IP_1_inv = [IP_1.index(i)+1 for i in xrange(1, 65)]
# Expansion E: 32 -> 48 bits.
E = [32, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 8, 9, 10, 11, 12, 13, 12, 13, 14, 15, 16, 17, 16, 17, 18, 19, 20, 21, 20, 21, 22, 23, 24, 25, 24, 25, 26, 27, 28, 29, 28, 29, 30, 31, 32, 1]
# S-boxes indexed SBOX[box][row][column].
SBOX = [[[14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7], [0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8], [4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0], [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13]], [[15, 1, 8, 14, 6, 11, 3, 4, 9, 7, 2, 13, 12, 0, 5, 10], [3, 13, 4, 7, 15, 2, 8, 14, 12, 0, 1, 10, 6, 9, 11, 5], [0, 14, 7, 11, 10, 4, 13, 1, 5, 8, 12, 6, 9, 3, 2, 15], [13, 8, 10, 1, 3, 15, 4, 2, 11, 6, 7, 12, 0, 5, 14, 9]], [[10, 0, 9, 14, 6, 3, 15, 5, 1, 13, 12, 7, 11, 4, 2, 8], [13, 7, 0, 9, 3, 4, 6, 10, 2, 8, 5, 14, 12, 11, 15, 1], [13, 6, 4, 9, 8, 15, 3, 0, 11, 1, 2, 12, 5, 10, 14, 7], [1, 10, 13, 0, 6, 9, 8, 7, 4, 15, 14, 3, 11, 5, 2, 12]], [[7, 13, 14, 3, 0, 6, 9, 10, 1, 2, 8, 5, 11, 12, 4, 15], [13, 8, 11, 5, 6, 15, 0, 3, 4, 7, 2, 12, 1, 10, 14, 9], [10, 6, 9, 0, 12, 11, 7, 13, 15, 1, 3, 14, 5, 2, 8, 4], [3, 15, 0, 6, 10, 1, 13, 8, 9, 4, 5, 11, 12, 7, 2, 14]], [[2, 12, 4, 1, 7, 10, 11, 6, 8, 5, 3, 15, 13, 0, 14, 9], [14, 11, 2, 12, 4, 7, 13, 1, 5, 0, 15, 10, 3, 9, 8, 6], [4, 2, 1, 11, 10, 13, 7, 8, 15, 9, 12, 5, 6, 3, 0, 14], [11, 8, 12, 7, 1, 14, 2, 13, 6, 15, 0, 9, 10, 4, 5, 3]], [[12, 1, 10, 15, 9, 2, 6, 8, 0, 13, 3, 4, 14, 7, 5, 11], [10, 15, 4, 2, 7, 12, 9, 5, 6, 1, 13, 14, 0, 11, 3, 8], [9, 14, 15, 5, 2, 8, 12, 3, 7, 0, 4, 10, 1, 13, 11, 6], [4, 3, 2, 12, 9, 5, 15, 10, 11, 14, 1, 7, 6, 0, 8, 13]], [[4, 11, 2, 14, 15, 0, 8, 13, 3, 12, 9, 7, 5, 10, 6, 1], [13, 0, 11, 7, 4, 9, 1, 10, 14, 3, 5, 12, 2, 15, 8, 6], [1, 4, 11, 13, 12, 3, 7, 14, 10, 15, 6, 8, 0, 5, 9, 2], [6, 11, 13, 8, 1, 4, 10, 7, 9, 5, 0, 15, 14, 2, 3, 12]], [[13, 2, 8, 4, 6, 15, 11, 1, 10, 9, 3, 14, 5, 0, 12, 7], [1, 15, 13, 8, 10, 3, 7, 4, 12, 5, 6, 11, 0, 14, 9, 2], [7, 11, 4, 1, 9, 12, 14, 2, 0, 6, 10, 13, 15, 3, 5, 8], [2, 1, 14, 7, 4, 10, 8, 13, 15, 12, 9, 0, 3, 5, 6, 11]]]
# Round-function output permutation P and its inverse.
P = [16, 7, 20, 21, 29, 12, 28, 17, 1, 15, 23, 26, 5, 18, 31, 10, 2, 8, 24, 14, 32, 27, 3, 9, 19, 13, 30, 6, 22, 11, 4, 25]
P_inv = [P.index(i)+1 for i in xrange(1, 33)]
# Key schedule: permuted choices 1 and 2, and per-round left-rotation amounts.
PC_1 = [57, 49, 41, 33, 25, 17, 9, 1, 58, 50, 42, 34, 26, 18, 10, 2, 59, 51, 43, 35, 27, 19, 11, 3, 60, 52, 44, 36, 63, 55, 47, 39, 31, 23, 15, 7, 62, 54, 46, 38, 30, 22, 14, 6, 61, 53, 45, 37, 29, 21, 13, 5, 28, 20, 12, 4]
PC_2 = [14, 17, 11, 24, 1, 5, 3, 28, 15, 6, 21, 10, 23, 19, 12, 4, 26, 8, 16, 7, 27, 20, 13, 2, 41, 52, 31, 37, 47, 55, 30, 40, 51, 45, 33, 48, 44, 49, 39, 56, 34, 53, 46, 42, 50, 36, 29, 32]
R = [1,1,2,2,2,2,2,2,1,2,2,2,2,2,2,1]
def chr_to_bits(c):
    """Return the bits of character *c* (at least 8, most significant first)."""
    return [int(digit) for digit in format(ord(c), "08b")]
def str_to_bits(s):
    """Concatenate the MSB-first bits of every character of *s* into one list."""
    return [bit for ch in s for bit in chr_to_bits(ch)]
def bits_to_chr(bits):
    """Pack a sequence of 0/1 integers (MSB first) into a single character."""
    digits = "".join(str(b) for b in bits)
    return chr(int(digits, 2))
def bits_to_str(bits):
    """Convert a bit list (MSB first) into a string, eight bits per character."""
    return "".join(bits_to_chr(bits[pos:pos + 8]) for pos in range(0, len(bits), 8))
def xor_bits(l, r):
    """Element-wise XOR of two bit sequences, truncated to the shorter one."""
    return [lb ^ rb for (lb, rb) in zip(l, r)]
def parity_mask(word, mask):
    """Parity (0/1) of the bits of *word* selected by the bit list *mask*."""
    selected = sum(w & m for w, m in zip(word, mask))
    return selected % 2
def F(hblk, subkey):
    """DES round function on bit lists.

    Expand the 32-bit half block *hblk* with E, XOR in the 48-bit *subkey*,
    substitute each 6-bit chunk through its S-box (via MY_SBOX_BITS), then
    apply P. Returns the 32 output bits.
    """
    mixed = xor_bits([hblk[pos - 1] for pos in E], subkey)
    sbox_out = []
    for start in range(0, len(mixed), 6):
        sbox_out.extend(MY_SBOX_BITS[start // 6][tuple(mixed[start:start + 6])])
    return [sbox_out[pos - 1] for pos in P]
def encrypt_block(blk, subkeys):
    """Encrypt one 8-character block with ROUNDS DES rounds and return 8 characters."""
    assert len(blk)==8
    state = str_to_bits(blk)
    state = [state[pos - 1] for pos in IP]
    for rnd in range(ROUNDS):
        left, right = state[:32], state[32:]
        left = xor_bits(left, F(right, subkeys[rnd]))
        state = right + left
    # Cancel the half swap performed by the last round.
    state = left + right
    return bits_to_str([state[pos - 1] for pos in IP_1])
def gen_subkey(key):
    """Derive the first ROUNDS 48-bit DES round keys (bit lists) from an 8-character *key*."""
    permuted = str_to_bits(key)
    permuted = [permuted[pos - 1] for pos in PC_1]
    c_half, d_half = permuted[:28], permuted[28:]
    round_keys = []
    for rnd in range(ROUNDS):
        shift = R[rnd]
        c_half = c_half[shift:] + c_half[:shift]
        d_half = d_half[shift:] + d_half[:shift]
        merged = c_half + d_half
        round_keys.append([merged[pos - 1] for pos in PC_2])
    return round_keys
def gen_subkey_dbg():
    """Key schedule run on bit *positions* instead of bit values.

    Entry [r][b] is the 0-based master-key bit position feeding bit b of
    round key r, for the first ROUNDS rounds.
    """
    c_and_d = [pos - 1 for pos in PC_1]
    c_half, d_half = c_and_d[:28], c_and_d[28:]
    schedule = []
    for rnd in range(ROUNDS):
        shift = R[rnd]
        c_half = c_half[shift:] + c_half[:shift]
        d_half = d_half[shift:] + d_half[:shift]
        merged = c_half + d_half
        schedule.append([merged[pos - 1] for pos in PC_2])
    return schedule
def encrypt(pt, key):
    """ECB-encrypt *pt* (length a multiple of 8) under *key* with the local DES."""
    assert len(pt)%8==0
    subkeys = gen_subkey(key)
    return "".join(encrypt_block(pt[off:off + 8], subkeys) for off in range(0, len(pt), 8))
# Keep a handle on the local implementation; `encrypt` is redefined further down
# to query the remote oracle instead.
encrypt_local = encrypt
def decrypt_block(blk, subkeys):
    """Decrypt one 8-character block: same network as encryption, keys in reverse order."""
    assert len(blk)==8
    state = str_to_bits(blk)
    state = [state[pos - 1] for pos in IP]
    for rnd in range(ROUNDS):
        left, right = state[:32], state[32:]
        left = xor_bits(left, F(right, subkeys[ROUNDS - 1 - rnd]))
        state = right + left
    # Cancel the half swap performed by the last round.
    state = left + right
    return bits_to_str([state[pos - 1] for pos in IP_1])
def decrypt(ct, key):
    """ECB-decrypt *ct* (length a multiple of 8) under *key* with the local DES."""
    assert len(ct)%8==0
    subkeys = gen_subkey(key)
    return "".join(decrypt_block(ct[off:off + 8], subkeys) for off in range(0, len(ct), 8))
from itertools import product
# MY_SBOX_BITS[k] maps every 6-bit input tuple of S-box k to that S-box's 4-bit
# output as a list of bits (MSB first). Standard DES addressing is used: the row
# comes from the outer bits (b0, b5) and the column from the inner bits b1..b4.
MY_SBOX_BITS = []
for sbox in SBOX:
    table = {}
    for inp in product(range(2), repeat=6):
        row = inp[0] * 2 + inp[5]
        col = inp[1] * 8 + inp[2] * 4 + inp[3] * 2 + inp[4]
        # Store a real list (not a lazy map) so the value can be reused on any Python version.
        table[inp] = [int(d) for d in format(sbox[row][col], "04b")]
    MY_SBOX_BITS.append(table)
# ROUNDS = 16
try:
from Crypto.Cipher import DES
c = DES.new("omgwtfxz")
except:
pass
else:
ROUNDS = 16
assert encrypt("abcdfuck", "omgwtfxz").encode("hex") == c.encrypt("abcdfuck").encode("hex")
print "IMPL OK!"
# quit()
ROUNDS = 8
ROUNDS = 8
from random import *
def xor(a, b):
    """XOR string *a* with *b* repeated cyclically; the result has len(a) characters."""
    return "".join(chr(ord(ch) ^ ord(b[idx % len(b)])) for idx, ch in enumerate(a))
def Fmy(hblk, sk, si):
    """Contribution of a single S-box to the DES round function.

    Takes the six expanded bits of half block *hblk* that feed S-box *si*
    (0-based), XORs them with the 6-bit key chunk *sk*, and places the 4-bit
    S-box output in an otherwise-zero 32-bit word permuted by P.
    """
    six = xor_bits([hblk[pos - 1] for pos in E[si*6:si*6+6]], sk)
    word = [0] * 32
    word[si*4:si*4+4] = MY_SBOX_BITS[si][tuple(six)]
    return [word[pos - 1] for pos in P]
def str_perm(s, p):
    """Permute the bits of string *s* by the 1-indexed position table *p*."""
    src = str_to_bits(s)
    return bits_to_str([src[pos - 1] for pos in p])
# d = str_perm("\x00\x80\x82\x00", E)
# d = str_perm("\x00\x80\x82\x00", E)
# d = str_perm("\x00\x00\x02\x02", E)
# d = str_perm("\x00\x00\x80\x00", E)
# d = str_perm("\x60\x00\x00\x00", E)
# d = str_to_bits(d)
# for i in xrange(0, len(d), 6):
# print d[i:i+6]
# quit()
# d = [0] * 32
# d[5*4:5*4+4] = [1] * 4
# d = [d[x-1] for x in P]
# print bits_to_str(d).encode("hex")
# d = [0] * 32
# d[7*4:7*4+4] = [1] * 4
# d = [d[x-1] for x in P]
# print bits_to_str(d).encode("hex")
# quit()
# 10202008
# 08020820
if LOCAL:
    # Local test key with the least-significant (parity) bit of every byte cleared.
    KEY = "".join(chr(ord(c) & 0xfe) for c in "\x97\xba\x95\xbd\x1b\xc5\xac\xd1")
else:
    # Placeholder: the oracle `encrypt` defined below ignores its key argument,
    # so KEY only feeds the SK_* debug values and the local "secret check".
    KEY = "omgwtfxz"
SK = gen_subkey(KEY)
# pipes...
# Named pipes created by `make prepare`: requests are written to p1 and
# ciphertexts read from p2. Opening a FIFO blocks until the other end is opened,
# so this open order presumably has to match the peer process (solve.py) — TODO confirm.
f1 = open("p1", "wb")
f2 = open("p2", "rb")
# Maximum number of plaintext bytes sent to the oracle per request.
BLOCK = 2**16
from struct import pack
def encrypt(pt, k):
print "QUERY", len(pt)
fullct = ""
for i in xrange(0, len(pt), BLOCK):
# f.read_until("plaintext(hex): ")
# f.send_line(pt[i:i+BLOCK].encode("hex"))
block = pt[i:i+BLOCK]
print "query chunk", i, len(block)
f1.write(pack("<Q", len(block)) + block)
f1.flush()
ct = b""
while len(ct) < len(block):
ct += f2.read(len(block) - len(ct))
# ct = f.read_line().strip()
# ct = ct.decode("hex")
# assert len(pt[i:i+BLOCK]) == len(ct)
fullct += ct
print "GOT"
return fullct
# pt = "6f6d67777466787a".decode("hex")
# ct = "d65100cf23f0106d".decode("hex")
# key = "d6fafeba9e9cb2b2".decode("hex")
# print encrypt(pt, key) == ct
# print encrypt(pt, key).encode("hex")
# quit
# pt = "omgwtfxz"
# # KEY = "\x00" * 8
# ROUNDS = 8
# ct = encrypt(pt, KEY)
# print "pt = %s" % pt.encode("hex")
# print "ct = %s" % ct.encode("hex")
# print "key= %s" % KEY.encode("hex")
# quit()
pairs = []
# Stage-1 input difference in the post-IP domain (halves 00808200 | 60000000).
# PT_DELTA_Pinv (the same difference mapped through IP_inv) is not referenced
# elsewhere in this script.
PT_DELTA_P = "0080820060000000".decode("hex")
PT_DELTA_Pinv = str_perm(PT_DELTA_P, IP_inv)
# All 64 six-bit key-chunk guesses as bit tuples; index n is n written MSB first.
RANGE6 = list(product(range(2), repeat=6))
# 32-bit linear mask whose parity the attack measures.
LMASK = str_to_bits("21040080".decode("hex"))
def get_SK(i, j):
    """Return the 6-bit chunk (bit list) of round key *i* feeding S-box *j* (both 0-based)."""
    start = 6 * j
    return SK[i][start:start + 6]
def frombin(v):
    """Interpret a sequence of 0/1 integers (MSB first) as a non-negative integer."""
    digits = "".join(str(bit) for bit in v)
    return int(digits, 2)
# For every round-key bit, the master-key bit position it comes from (used by setkey()).
SUBKEY_POS = gen_subkey_dbg()
# Reference subkey chunks derived from KEY. SK<r>_<s> is the bit list of S-box s's
# chunk in round key r (0-based round), and SK<r>_<s>_int is its integer value.
# They are printed for comparison and used as fallbacks for disabled stages; they
# only match the real key when KEY is the real key (LOCAL mode).
SK0_1 = get_SK(0, 0)
SK0_1_int = ord(bits_to_chr(SK0_1))
SK7_1 = get_SK(7, 0)
SK7_1_int = ord(bits_to_chr(SK7_1))
SK7_5 = get_SK(7, 4)
SK7_5_int = ord(bits_to_chr(SK7_5))
SK0_2 = get_SK(0, 2-1)
SK0_2_int = ord(bits_to_chr(SK0_2))
SK0_3 = get_SK(0, 3-1)
SK0_3_int = ord(bits_to_chr(SK0_3))
SK0_4 = get_SK(0, 4-1)
SK0_4_int = ord(bits_to_chr(SK0_4))
SK0_5 = get_SK(0, 5-1)
SK0_5_int = ord(bits_to_chr(SK0_5))
SK0_6 = get_SK(0, 6-1)
SK0_6_int = ord(bits_to_chr(SK0_6))
SK0_8 = get_SK(0, 8-1)
SK0_8_int = ord(bits_to_chr(SK0_8))
print "sk0_1:", SK0_1_int, SK0_1
print "sk7_1:", SK7_1_int, SK7_1
print "sk7_5:", SK7_5_int, SK7_5
print "sk0_2:", SK0_2_int, SK0_2
print "sk0_3:", SK0_3_int, SK0_3
print "sk0_4:", SK0_4_int, SK0_4
print "sk0_5:", SK0_5_int, SK0_5
print "sk0_6:", SK0_6_int, SK0_6
print "sk0_8:", SK0_8_int, SK0_8
print
# Disabled helper (flip to `if 1:` to use). For a random key and input, it
# prints the round-function output difference caused by `diff`, shows which
# S-box inputs that difference touches, and builds the P-permuted "submask" of
# the output bits of the touched S-boxes (excluding S-boxes 1, 6, 8).
# Its printed format matches the quoted notes in stages 3 and 4.
# NOTE(review): indentation reconstructed from a whitespace-stripped paste —
# verify the nesting against the original.
if 0:
    diff = str_to_bits("\x40\x00\x00\x00")
    k = [randint(0, 1) for _ in xrange(48)]
    bits1 = [randint(0, 1) for _ in xrange(32)]
    bits2 = xor_bits(bits1, diff)
    out1 = F(bits1, k)
    out2 = F(bits2, k)
    delta = xor_bits(out1, out2)
    print bits_to_str(delta).encode("hex")
    # Route the output difference through E to see which S-box inputs it reaches next round.
    delta = [delta[x-1] for x in E]
    bits = [0] * 32
    for i in xrange(0, len(delta), 6):
        print i/6+1, delta[i:i+6],
        if sum(delta[i:i+6]):
            print "*"
        else:
            print
        if sum(delta[i:i+6]) and i/6+1 not in (1,6,8):
            print i/6+1
            bits[i/6*4:i/6*4+4] = [1] * 4
    bits = [bits[x-1] for x in P]
    print "submask: 0x%s" % bits_to_str(bits).encode("hex")
    # quit()
# submasks = []
# i = mask = 0x18222828
# while i >= 0:
# submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32)
# # print "%08x" % i
# if i == 0:
# break
# i = (i - 1) & mask
#00800202
#00008202
#00808000
#00800002
#00008002
#00008200
# print bits_to_str(xor_bits(out1, out2)).encode("hex")
# quit()
# for i, sk in enumerate(SUBKEY_POS):
# print i, sk, len(sk)
# STAGE 1: recover the last-round (round 8) key chunk of S-box 1 (SK_7_1), then
# rank candidates for the round-1 key chunk of S-box 1 (SK_0_1), using 2**14
# chosen-plaintext pairs with the fixed input difference PT_DELTA_P.
# NOTE(review): indentation reconstructed from a whitespace-stripped paste —
# verify the nesting against the original.
if FLAGS[0]:
    print("STAGE 1")
    ROUNDS = 8
    N = 2**14
    pts = []
    for i in xrange(N):
        # Choose the pair in the post-IP domain; permuting by IP_inv makes the
        # cipher's IP cancel out, so PT_DELTA_P is the difference after IP.
        pt1 = "".join(chr(randint(0, 255)) for _ in xrange(8))
        pt2 = xor(pt1, PT_DELTA_P)
        pt1 = str_perm(pt1, IP_inv)
        pt2 = str_perm(pt2, IP_inv)
        pts.append(pt1)
        pts.append(pt2)
    cts = []
    # One oracle query for all 2*N plaintexts, then split into 8-byte blocks.
    ct = encrypt("".join(pts), KEY)
    for i in xrange(0, len(ct), 8):
        cts.append(ct[i:i+8])
    pairs = zip(pts[::2], pts[1::2], cts[::2], cts[1::2])
    # counters[k]: number of pairs whose combined parity is 0 for SK_7_1 guess k.
    counters = [0] * 2**6
    for i, (pt1, pt2, ct1, ct2) in enumerate(pairs):
        if i % 2**12 == 0:
            print hex(i)
        ps = [0] * 2**6
        for ct in (ct1, ct2):
            # Undo the final permutation; then, for every guess, peel S-box 1
            # of the last round and fold the LMASK parity into ps[guess].
            bits = str_to_bits(ct)
            bits = [bits[x-1] for x in IP_1_inv]
            left = bits[:32]
            right = bits[32:]
            parity = parity_mask(LMASK, right) ^ left[16]
            for skn, sk in enumerate(RANGE6):
                func = Fmy(right, sk, si=0)
                ps[skn] ^= func[16] ^ parity
        for j in xrange(64):
            counters[j] += 1 - ps[j]
    best = sorted(enumerate(counters), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    cands = []
    for a, b in best[-10:]:
        print a, b
        if b == best_cnt:
            cands.append(a)
    print "cands", cands
    # A tie for the top count aborts the run (README: "Sometimes the attack fails").
    assert len(cands) == 1
    recovered_sk_7_1 = cands[0]
    print "[+]", "recovered SK_7_1", recovered_sk_7_1
    print
    # Second pass: with SK_7_1 fixed, score each SK_0_1 guess by the parity bias
    # over the pairs whose predicted round-1 S-box 1 output difference is 0x00808200.
    counters = [0] * 64
    total = [0] * 64
    for i, (pt1, pt2, ct1, ct2) in enumerate(pairs):
        parity = 0
        ps = 0
        for ct in (ct1, ct2):
            bits = str_to_bits(ct)
            bits = [bits[x-1] for x in IP_1_inv]
            left = bits[:32]
            right = bits[32:]
            parity = parity_mask(LMASK, right) ^ left[16]
            func = Fmy(right, RANGE6[recovered_sk_7_1], si=0)
            ps ^= func[16] ^ parity
        right1 = str_to_bits(str_perm(pt1, IP))[32:]
        right2 = str_to_bits(str_perm(pt2, IP))[32:]
        for skn0 in xrange(64):
            func1 = Fmy(right1, RANGE6[skn0], si=0)
            func2 = Fmy(right2, RANGE6[skn0], si=0)
            delta = xor_bits(func1, func2)
            if frombin(delta) == 0x00808200:
                counters[skn0] += 1 - ps
                total[skn0] += 1
    # NOTE(review): raises ZeroDivisionError if some guess never matched the difference.
    for i in xrange(64):
        counters[i] = float(counters[i])/total[i]
    best = sorted(enumerate(counters), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    cands = []
    for a, b in best[-10:]:
        print a, b
        if b == best_cnt:
            cands.append(a)
    print "cands", cands
    # Ties are kept; the final brute force tries every SK_0_1 candidate.
    cands_sk_0_1 = cands
    print "[+] cands SK_0_1", cands_sk_0_1
    print
else:
    # Stage disabled: use the values derived from KEY (only correct in LOCAL mode).
    recovered_sk_7_1 = SK7_1_int
    cands_sk_0_1 = [SK0_1_int]
    print "DEBUG 1"
# Every submask of the 32-bit mask 0x18222828 (256 values, from the mask down
# to 0), each as a 64-bit difference touching only the left half.
submasks = []
mask = 0x18222828
i = mask
while True:
    submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32)
    if i == 0:
        break
    # Standard "next smaller submask" step.
    i = (i - 1) & mask
# STAGE 2: recover the round-1 key chunks of S-boxes 6 and 8 (SK_0_6, SK_0_8),
# then the round-8 key chunk of S-box 5 (SK_7_5).
# Each chunk of 512 plaintexts is a random base XORed with all 256 left-half
# submasks (indices 0-255), followed by the same 256 texts XORed with DELTA
# (indices 256-511).
# NOTE(review): indentation reconstructed from a whitespace-stripped paste —
# verify the nesting against the original.
if FLAGS[1]:
    ROUNDS = 8
    print("STAGE 2")
    N = 2**7
    DELTA = str_to_bits("\x40\x00\x00\x00\x00\x00\x02\x02")
    pts = []
    for i in xrange(N):
        if i % 2**3 == 0 and i:
            print hex(i)
        base = "".join(chr(randint(0, 255)) for _ in xrange(8))
        base = str_to_bits(base)
        pts_cur = []
        for sub in submasks:
            pts_cur.append(xor_bits(base, sub))
        # Iterate over a copy because the loop appends to pts_cur.
        for pt in pts_cur[::]:
            pts_cur.append(xor_bits(pt, DELTA))
        pts += map(tuple, pts_cur)
    cts = []
    ctsx = []
    ptsx = []
    for pt in pts:
        # Pre-apply IP_inv so the cipher's IP cancels.
        pt = [pt[x-1] for x in IP_inv]
        pt = bits_to_str(pt)
        ptsx.append(pt)
    ptsx = "".join(ptsx)
    newct = encrypt(ptsx, KEY)
    assert len(ptsx) == len(newct)
    for i in xrange(0, len(newct), 8):
        ct = newct[i:i+8]
        bits = str_to_bits(ct)
        bits = [bits[x-1] for x in IP_1_inv]
        left = bits[:32]
        right = bits[32:]
        # cts: per-ciphertext LMASK parity bit; ctsx: full post-IP_1_inv bits.
        res = parity_mask(LMASK, left) ^ right[16]
        cts.append(res)
        ctsx.append(bits)
    print len(pts), len(cts)
    lefts = []
    cache6 = {}
    cache8 = {}
    for pt in pts:
        left, right = pt[:32], pt[32:]
        fl = frombin(left)
        lefts.append(fl)
    # Submasks only touch the left half, so the right half is constant within each
    # run of 256 texts; S-box outputs are cached once per run and per key guess.
    for pt in pts[::256]:
        left, right = pt[:32], pt[32:]
        for skn, sk in enumerate(RANGE6):
            blob1 = Fmy(right, sk, si=6-1)
            blob2 = Fmy(right, sk, si=8-1)
            cache6[pt,skn] = frombin(blob1)
            cache8[pt,skn] = frombin(blob2)
    counters = defaultdict(int)
    total = defaultdict(int)
    for ichunk in xrange(N):
        print "chunkA", ichunk+1, "/", N
        # NOTE(review): `lefts` is not sliced per chunk, so every chunk reuses chunk 0's
        # left halves. This looks harmless: within a half-chunk they differ from the true
        # values by one constant XOR, which cancels in the match below.
        chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts)
        for skn6 in xrange(64):
            for skn8 in xrange(64):
                # Pair each DELTA-shifted text with the unshifted text whose predicted
                # (S-box 6 ^ S-box 8 ^ left) value differs by 0x40000000, then count
                # agreements of their parity bits.
                by_left = {}
                for i, (pt, ct, left) in enumerate(chunk):
                    if i == 0 or i == 256:
                        blob1 = cache6[pt, skn6]
                        blob2 = cache8[pt, skn8]
                    func = blob1 ^ blob2 ^ left
                    if i < 256:
                        assert func not in by_left
                        by_left[func] = i
                    else:
                        func ^= 0x40000000
                        assert func in by_left
                        j = by_left[func]
                        counters[skn6,skn8] += 1 - (chunk[i][1] ^ chunk[j][1])
                        total[skn6,skn8] += 1
    best = sorted(counters.items(), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    skn68_cands = []
    for a, b in best[-10:]:
        print a, b
        if b == best_cnt:
            skn68_cands.append(a)
    print "CANDS", skn68_cands
    cands_sk_0_68 = skn68_cands
    print "[+] candscands_sk_0_68", cands_sk_0_68
    print
    # recover up: with the top (SK_0_6, SK_0_8) candidate, guess SK_7_5 by peeling
    # S-box 5 of the last round from both ciphertexts of every matched pair.
    for skn6, skn8 in skn68_cands[:1]:
        total = defaultdict(int)
        counters = defaultdict(int)
        for ichunk in xrange(N):
            chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts)
            for skn5 in xrange(64):
                by_left = {}
                for i, (pt, ct, left) in enumerate(chunk):
                    if i == 0 or i == 256:
                        blob1 = cache6[pt, skn6]
                        blob2 = cache8[pt, skn8]
                    func = blob1 ^ blob2 ^ left
                    if i < 256:
                        assert func not in by_left
                        by_left[func] = i
                    else:
                        func ^= 0x40000000
                        assert func in by_left
                        j = by_left[func]
                        parity = 0
                        ct1 = ctsx[ichunk*512+i]
                        ct2 = ctsx[ichunk*512+j]
                        for ct in (ct1, ct2):
                            left, right = ct[:32], ct[32:]
                            func = Fmy(right, RANGE6[skn5], 4)
                            left = xor_bits(left, func)
                            parity ^= parity_mask(left, LMASK)
                        counters[skn5] += 1 ^ parity
                        total[skn5] += 1
        best = sorted(counters.items(), key=lambda (a, b): b)
        best_cnt = best[-1][1]
        print "best", best_cnt
        cands = []
        for a, b in best[-10:]:
            print a, b
            if b == best_cnt:
                cands.append(a)
        print "cands", cands
        assert len(cands) == 1
        recovered_sk_7_5 = cands[0]
        print "[+]", "recovered SK_7_5", recovered_sk_7_5
        print
else:
    # Stage disabled: use the values derived from KEY (only correct in LOCAL mode).
    recovered_sk_7_5 = SK7_5_int
    cands_sk_0_68 = [(SK0_6_int, SK0_8_int)]
    print "DEBUG 2"
# STAGE 3: recover the round-1 key chunks of S-boxes 2 and 3 (SK_0_2, SK_0_3)
# with the stage-2 matching, using a new input difference whose S-box 8
# contribution is cancelled in the left half.
# NOTE(review): indentation reconstructed from a whitespace-stripped paste —
# verify the nesting against the original.
if FLAGS[2]:
    # Apparently the output of the disabled `if 0:` helper for this difference.
    '''
00800002
1 [0, 0, 0, 0, 0, 0]
2 [0, 0, 0, 0, 0, 1] *
2
3 [0, 1, 0, 0, 0, 0] *
3
4 [0, 0, 0, 0, 0, 0]
5 [0, 0, 0, 0, 0, 0]
6 [0, 0, 0, 0, 0, 0]
7 [0, 0, 0, 0, 0, 0]
8 [0, 0, 0, 1, 0, 0] *
submask: 0x44094114
'''
    print("STAGE 3")
    # All 256 submasks of 0x44094114 as left-half differences.
    submasks = []
    i = mask = 0x44094114
    while i >= 0:
        submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32)
        if i == 0:
            break
        i = (i - 1) & mask
    ROUNDS = 8
    N = 2**7
    # Required round-1 S-box 8 output difference for the right-half difference
    # below. XORing it into the left half makes S-box 8's contribution cancel.
    DELTA_L = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
    DELTA = str_to_bits("\x40\x00\x00\x00" + "\x00\x80\x00\x02")
    DELTA = xor_bits(DELTA[:32], DELTA_L) + DELTA[32:]
    DELTA_L_int = int(bits_to_str(DELTA[:32]).encode("hex"), 16)
    print "DELTA", bits_to_str(DELTA).encode("hex")
    print "DELTA_L_int %08x" % DELTA_L_int
    pts = []
    for i in xrange(N):
        if i % 2**3 == 0 and i:
            print hex(i)
        # Resample the base until the S-box 8 output difference equals DELTA_L
        # for every (SK_0_6, SK_0_8) candidate. In this for/else, the `else: break`
        # leaves the while loop only when no candidate triggered the inner break.
        while True:
            base = [randint(0, 1) for _ in xrange(64)]
            for skn6, skn8 in cands_sk_0_68:
                test1 = Fmy(base[32:], RANGE6[skn8], si=8-1)
                test2 = Fmy(xor_bits(base[32:], DELTA[32:]), RANGE6[skn8], si=8-1)
                difftest = xor_bits(test1, test2)
                if difftest != DELTA_L:
                    break
            else:
                break
        pts_cur = []
        for sub in submasks:
            pts_cur.append(xor_bits(base, sub))
        for pt in pts_cur[::]:
            pts_cur.append(xor_bits(pt, DELTA))
        pts += map(tuple, pts_cur)
    cts = []
    ctsx = []
    ptsx = []
    for pt in pts:
        pt = [pt[x-1] for x in IP_inv]
        pt = bits_to_str(pt)
        ptsx.append(pt)
    ptsx = "".join(ptsx)
    newct = encrypt(ptsx, KEY)
    assert len(ptsx) == len(newct)
    for i in xrange(0, len(newct), 8):
        ct = newct[i:i+8]
        bits = str_to_bits(ct)
        bits = [bits[x-1] for x in IP_1_inv]
        left = bits[:32]
        right = bits[32:]
        res = parity_mask(LMASK, left) ^ right[16]
        cts.append(res)
        ctsx.append(bits)
    print len(pts), len(cts)
    # Same machinery as stage 2; the "6"/"8" names now hold S-boxes 2 and 3.
    lefts = []
    cache6 = {}
    cache8 = {}
    for pt in pts:
        left, right = pt[:32], pt[32:]
        fl = frombin(left)
        lefts.append(fl)
    for pt in pts[::256]:
        left, right = pt[:32], pt[32:]
        for skn, sk in enumerate(RANGE6):
            blob1 = Fmy(right, sk, si=2-1)
            blob2 = Fmy(right, sk, si=3-1)
            cache6[pt,skn] = frombin(blob1)
            cache8[pt,skn] = frombin(blob2)
    counters = defaultdict(int)
    total = defaultdict(int)
    for ichunk in xrange(N):
        print "chunkB", ichunk+1, "/", N
        chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts)
        for skn6 in xrange(64):
            for skn8 in xrange(64):
                by_left = {}
                for i, (pt, ct, left) in enumerate(chunk):
                    if i == 0 or i == 256:
                        blob1 = cache6[pt, skn6]
                        blob2 = cache8[pt, skn8]
                    func = blob1 ^ blob2 ^ left
                    if i < 256:
                        assert func not in by_left
                        by_left[func] = i
                    else:
                        func ^= DELTA_L_int
                        assert func in by_left
                        j = by_left[func]
                        counters[skn6,skn8] += 1 - (chunk[i][1] ^ chunk[j][1])
                        total[skn6,skn8] += 1
    best = sorted(counters.items(), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    skn68_cands = []
    for a, b in best[-20:]:
        print a, b
        if b == best_cnt:
            skn68_cands.append(a)
    print "CANDS", skn68_cands
    cands_sk_0_23 = skn68_cands
    print "[+] candscands_sk_0_23", cands_sk_0_23
    print
    print
else:
    # Stage disabled: use the values derived from KEY (only correct in LOCAL mode).
    cands_sk_0_23 = [(SK0_2_int, SK0_3_int)]
    print "DEBUG3"
# STAGE 4: recover the round-1 key chunks of S-boxes 4 and 5 (SK_0_4, SK_0_5);
# same approach as stage 3, with S-box 6 as the contribution cancelled in the
# left half.
# NOTE(review): indentation reconstructed from a whitespace-stripped paste —
# verify the nesting against the original.
if FLAGS[3]:
    # Apparently the output of the disabled `if 0:` helper for this difference.
    '''
00008200
1 [0, 0, 0, 0, 0, 0]
2 [0, 0, 0, 0, 0, 0]
3 [0, 0, 0, 0, 0, 0]
4 [0, 0, 0, 0, 0, 1] *
4
5 [0, 1, 0, 0, 0, 0] *
5
6 [0, 0, 0, 1, 0, 0] *
7 [0, 0, 0, 0, 0, 0]
8 [0, 0, 0, 0, 0, 0]
submask: 0xa14410c0
'''
    print("STAGE 4")
    submasks = []
    i = mask = 0xa14410c0
    while i >= 0:
        submasks.append(str_to_bits(("%08x" % i).decode("hex")) + [0] * 32)
        if i == 0:
            break
        i = (i - 1) & mask
    ROUNDS = 8
    N = 2**7
    # Required round-1 S-box 6 output difference for the right-half difference below.
    DELTA_L = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
    DELTA = str_to_bits("\x40\x00\x00\x00" + "\x00\x00\x82\x00")
    DELTA = xor_bits(DELTA[:32], DELTA_L) + DELTA[32:]
    DELTA_L_int = int(bits_to_str(DELTA[:32]).encode("hex"), 16)
    print "DELTA", bits_to_str(DELTA).encode("hex")
    print "DELTA_L_int %08x" % DELTA_L_int
    pts = []
    for i in xrange(N):
        if i % 2**3 == 0 and i:
            print hex(i)
        # Resample until every SK_0_6 candidate yields exactly DELTA_L (for/else as in stage 3).
        while True:
            base = [randint(0, 1) for _ in xrange(64)]
            for skn6, skn8 in cands_sk_0_68:
                test1 = Fmy(base[32:], RANGE6[skn6], si=6-1)
                test2 = Fmy(xor_bits(base[32:], DELTA[32:]), RANGE6[skn6], si=6-1)
                difftest = xor_bits(test1, test2)
                if difftest != DELTA_L:
                    break
            else:
                break
        pts_cur = []
        for sub in submasks:
            pts_cur.append(xor_bits(base, sub))
        for pt in pts_cur[::]:
            pts_cur.append(xor_bits(pt, DELTA))
        pts += map(tuple, pts_cur)
    cts = []
    ctsx = []
    ptsx = []
    for pt in pts:
        pt = [pt[x-1] for x in IP_inv]
        pt = bits_to_str(pt)
        ptsx.append(pt)
    ptsx = "".join(ptsx)
    newct = encrypt(ptsx, KEY)
    assert len(ptsx) == len(newct)
    for i in xrange(0, len(newct), 8):
        ct = newct[i:i+8]
        bits = str_to_bits(ct)
        bits = [bits[x-1] for x in IP_1_inv]
        left = bits[:32]
        right = bits[32:]
        res = parity_mask(LMASK, left) ^ right[16]
        cts.append(res)
        ctsx.append(bits)
    print len(pts), len(cts)
    # The "6"/"8" names now hold S-boxes 4 and 5.
    lefts = []
    cache6 = {}
    cache8 = {}
    for pt in pts:
        left, right = pt[:32], pt[32:]
        fl = frombin(left)
        lefts.append(fl)
    for pt in pts[::256]:
        left, right = pt[:32], pt[32:]
        for skn, sk in enumerate(RANGE6):
            blob1 = Fmy(right, sk, si=4-1)
            blob2 = Fmy(right, sk, si=5-1)
            cache6[pt,skn] = frombin(blob1)
            cache8[pt,skn] = frombin(blob2)
    counters = defaultdict(int)
    total = defaultdict(int)
    for ichunk in xrange(N):
        print "chunkC", ichunk+1, "/", N
        chunk = zip(pts[ichunk*512:ichunk*512+512], cts[ichunk*512:ichunk*512+512], lefts)
        for skn6 in xrange(64):
            for skn8 in xrange(64):
                by_left = {}
                for i, (pt, ct, left) in enumerate(chunk):
                    if i == 0 or i == 256:
                        blob1 = cache6[pt, skn6]
                        blob2 = cache8[pt, skn8]
                    func = blob1 ^ blob2 ^ left
                    if i < 256:
                        assert func not in by_left
                        by_left[func] = i
                    else:
                        func ^= DELTA_L_int
                        assert func in by_left
                        j = by_left[func]
                        counters[skn6,skn8] += 1 - (chunk[i][1] ^ chunk[j][1])
                        total[skn6,skn8] += 1
    best = sorted(counters.items(), key=lambda (a, b): b)
    best_cnt = best[-1][1]
    print "best", best_cnt
    skn68_cands = []
    for a, b in best[-20:]:
        print a, b
        if b == best_cnt:
            skn68_cands.append(a)
    print "CANDS", skn68_cands
    cands_sk_0_45 = skn68_cands
    print "[+] candscands_sk_0_45", cands_sk_0_45
    print
    print
else:
    # Stage disabled: use the values derived from KEY (only correct in LOCAL mode).
    # NOTE(review): the label below says "DEBUG3" although this is stage 4 (copy-paste).
    cands_sk_0_45 = [(SK0_4_int, SK0_5_int)]
    print "DEBUG3"
# recovered_sk_7_1
# recovered_sk_7_5
# cands_sk_0_1 = cands
# cands_sk_0_68
def setkey(i, j, val):
    """Write the 6-bit value *val* (MSB first) into MASTER_KEY.

    The bits go to the master-key positions that feed S-box *j* (1-based) of
    round key *i* (0-based), as listed in SUBKEY_POS. MASTER_KEY is modified in place.
    """
    start = (j - 1) * 6
    positions = SUBKEY_POS[i][start:start + 6]
    for shift, pos in zip(range(5, -1, -1), positions):
        MASTER_KEY[pos] = (val >> shift) & 1
print("STAGE 5 - final bruteforce")
pt = "omgwtfxz"
ct = encrypt(pt, KEY)
print "pt = %s" % pt.encode("hex")
print "ct = %s" % ct.encode("hex")
print KEY.encode("hex")
bits = str_to_bits(KEY)
key = frombin(bits)
print "SECRET KEY %016x" % key
for recovered_sk_0_1 in cands_sk_0_1:
for recovered_sk_0_6, recovered_sk_0_8 in cands_sk_0_68:
for recovered_sk_0_2, recovered_sk_0_3 in cands_sk_0_23:
for recovered_sk_0_4, recovered_sk_0_5 in cands_sk_0_45:
MASTER_KEY = [None] * 64
for i in xrange(7, 64, 8):
MASTER_KEY[i] = 0
setkey(0, 1, recovered_sk_0_1)
setkey(0, 2, recovered_sk_0_2)
setkey(0, 3, recovered_sk_0_3)
setkey(0, 4, recovered_sk_0_4)
setkey(0, 5, recovered_sk_0_5)
setkey(0, 6, recovered_sk_0_6)
setkey(0, 8, recovered_sk_0_8)
setkey(7, 1, recovered_sk_7_1)
setkey(7, 5, recovered_sk_7_5)
# print Counter(MASTER_KEY)
mask_unknown = frombin([int(v == None) for v in MASTER_KEY])
mask_known = frombin([int(v != None) for v in MASTER_KEY])
val_known = frombin([int(v) if v != None else 0 for v in MASTER_KEY ])
print Counter(MASTER_KEY)
print "known = %016x & %016x" % (val_known, mask_known)
print "unknown = %016x" % mask_unknown
print "secret check %016x (if local)" % (key & mask_known), (key & mask_known) == val_known
print "brute..."
test_key = mask_unknown
while test_key > 0:
kkkey = test_key | val_known
kkkey = ("%016x" % kkkey).decode("hex")
test_ct = encrypt_local(pt, kkkey)
if test_ct == ct:
print "PWND", kkkey.encode("hex")
print "pt", pt.encode("hex")
print "ct", ct.encode("hex")
print "test_ct", test_ct.encode("hex")
f1.write("flagflag" + kkkey)
f1.flush()
quit()
test_key = (test_key - 1) & mask_unknown
print
print "no luck"
# print `f.buf`
# f.interact()
#flag{but_th3_litt1e_sticky_leave5_and_tHe_pRec1ous_t0mbs_and_the_b1Ue_sky_ANd_The_woman_you_loVe_How_will_you_lIVe_h0w_wi1l_yoU_love_them_wiTh_5uch_a_h3ll_in_YouR_h34rt_and_y0ur_he4d__How__Can__You__}
# 5min locally
# Create the named pipes p1/p2 that des.py opens (presumably shared with solve.py).
# Pipes that already exist are skipped, so re-running `make prepare` no longer fails.
.PHONY: prepare
prepare:
	for p in p1 p2; do [ -p $$p ] || mkfifo $$p; done
# Run the MitM precomputation and report the size of the generated dumps.
# Expected output of `du`:
# 17G dump0
# 17G dump1
# 17G dump2
# 17G dump3
.PHONY: precomp
precomp:
	./stage1_precomp
	du -hs dump*
# Install the `bint` dependency, generate the lookup tables, and build the C++ MitM tools.
.PHONY: compile
compile:
	python3 -m pip install -U bint
# generate tables
	python3 mydes.py
	g++ stage1_precomp.cpp -std=c++14 -O3 -o stage1_precomp
	g++ stage2.cpp -std=c++14 -O3 -o stage2
import struct
from bint import Bin
# Bit-position tables for the integer DES model in this module. Entries are
# 0-based positions counted from the most significant bit (the convention
# permute() uses); they are the 1-based standard DES tables (as written in
# TColors.py) minus one.

# Initial permutation IP.
INITIAL_PERMUTATION = (
    57, 49, 41, 33, 25, 17, 9, 1,
    59, 51, 43, 35, 27, 19, 11, 3,
    61, 53, 45, 37, 29, 21, 13, 5,
    63, 55, 47, 39, 31, 23, 15, 7,
    56, 48, 40, 32, 24, 16, 8, 0,
    58, 50, 42, 34, 26, 18, 10, 2,
    60, 52, 44, 36, 28, 20, 12, 4,
    62, 54, 46, 38, 30, 22, 14, 6,
)
# Final permutation IP^-1.
FINAL_PERMUTATION = (
    39, 7, 47, 15, 55, 23, 63, 31,
    38, 6, 46, 14, 54, 22, 62, 30,
    37, 5, 45, 13, 53, 21, 61, 29,
    36, 4, 44, 12, 52, 20, 60, 28,
    35, 3, 43, 11, 51, 19, 59, 27,
    34, 2, 42, 10, 50, 18, 58, 26,
    33, 1, 41, 9, 49, 17, 57, 25,
    32, 0, 40, 8, 48, 16, 56, 24,
)
# Expansion E: 32-bit half block -> 48-bit S-box input layout.
EXPANSION = (
    31, 0, 1, 2, 3, 4,
    3, 4, 5, 6, 7, 8,
    7, 8, 9, 10, 11, 12,
    11, 12, 13, 14, 15, 16,
    15, 16, 17, 18, 19, 20,
    19, 20, 21, 22, 23, 24,
    23, 24, 25, 26, 27, 28,
    27, 28, 29, 30, 31, 0,
)
# For each of the 32 input bits, its first position in the expanded layout;
# used to collapse a 48-bit expanded value back to 32 bits.
iEXPANSION = [EXPANSION.index(i) for i in range(32)]
# Round-function output permutation P ...
PERMUTATION = (
    15, 6, 19, 20, 28, 11, 27, 16,
    0, 14, 22, 25, 4, 17, 30, 9,
    1, 7, 23, 13, 31, 26, 2, 8,
    18, 12, 29, 5, 21, 10, 3, 24,
)
# ... and its inverse as a position map.
iPERMUTATION = tuple(PERMUTATION.index(i) for i in range(32))
# Key schedule. PC1 selects 56 of the 64 key bits; positions 7, 15, ..., 63
# (the low bit of every byte) never appear. PC2 then picks the 48 round-key
# bits out of the rotated 56-bit C||D value.
PERMUTED_CHOICE1 = (
    56, 48, 40, 32, 24, 16, 8,
    0, 57, 49, 41, 33, 25, 17,
    9, 1, 58, 50, 42, 34, 26,
    18, 10, 2, 59, 51, 43, 35,
    62, 54, 46, 38, 30, 22, 14,
    6, 61, 53, 45, 37, 29, 21,
    13, 5, 60, 52, 44, 36, 28,
    20, 12, 4, 27, 19, 11, 3,
)
PERMUTED_CHOICE2 = (
    13, 16, 10, 23, 0, 4,
    2, 27, 14, 5, 20, 9,
    22, 18, 11, 3, 25, 7,
    15, 6, 26, 19, 12, 1,
    40, 51, 30, 36, 46, 54,
    29, 39, 50, 44, 32, 47,
    43, 48, 38, 55, 33, 52,
    45, 41, 49, 35, 28, 31,
)
# The eight DES S-boxes, each flattened row-major: 4 rows x 16 columns,
# flat index = row * 16 + column.
SUBSTITUTION_BOX = (
    (
        14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7,
        0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8,
        4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0,
        15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13,
    ),
    (
        15, 1, 8, 14, 6, 11, 3, 4, 9, 7, 2, 13, 12, 0, 5, 10,
        3, 13, 4, 7, 15, 2, 8, 14, 12, 0, 1, 10, 6, 9, 11, 5,
        0, 14, 7, 11, 10, 4, 13, 1, 5, 8, 12, 6, 9, 3, 2, 15,
        13, 8, 10, 1, 3, 15, 4, 2, 11, 6, 7, 12, 0, 5, 14, 9,
    ),
    (
        10, 0, 9, 14, 6, 3, 15, 5, 1, 13, 12, 7, 11, 4, 2, 8,
        13, 7, 0, 9, 3, 4, 6, 10, 2, 8, 5, 14, 12, 11, 15, 1,
        13, 6, 4, 9, 8, 15, 3, 0, 11, 1, 2, 12, 5, 10, 14, 7,
        1, 10, 13, 0, 6, 9, 8, 7, 4, 15, 14, 3, 11, 5, 2, 12,
    ),
    (
        7, 13, 14, 3, 0, 6, 9, 10, 1, 2, 8, 5, 11, 12, 4, 15,
        13, 8, 11, 5, 6, 15, 0, 3, 4, 7, 2, 12, 1, 10, 14, 9,
        10, 6, 9, 0, 12, 11, 7, 13, 15, 1, 3, 14, 5, 2, 8, 4,
        3, 15, 0, 6, 10, 1, 13, 8, 9, 4, 5, 11, 12, 7, 2, 14,
    ),
    (
        2, 12, 4, 1, 7, 10, 11, 6, 8, 5, 3, 15, 13, 0, 14, 9,
        14, 11, 2, 12, 4, 7, 13, 1, 5, 0, 15, 10, 3, 9, 8, 6,
        4, 2, 1, 11, 10, 13, 7, 8, 15, 9, 12, 5, 6, 3, 0, 14,
        11, 8, 12, 7, 1, 14, 2, 13, 6, 15, 0, 9, 10, 4, 5, 3,
    ),
    (
        12, 1, 10, 15, 9, 2, 6, 8, 0, 13, 3, 4, 14, 7, 5, 11,
        10, 15, 4, 2, 7, 12, 9, 5, 6, 1, 13, 14, 0, 11, 3, 8,
        9, 14, 15, 5, 2, 8, 12, 3, 7, 0, 4, 10, 1, 13, 11, 6,
        4, 3, 2, 12, 9, 5, 15, 10, 11, 14, 1, 7, 6, 0, 8, 13,
    ),
    (
        4, 11, 2, 14, 15, 0, 8, 13, 3, 12, 9, 7, 5, 10, 6, 1,
        13, 0, 11, 7, 4, 9, 1, 10, 14, 3, 5, 12, 2, 15, 8, 6,
        1, 4, 11, 13, 12, 3, 7, 14, 10, 15, 6, 8, 0, 5, 9, 2,
        6, 11, 13, 8, 1, 4, 10, 7, 9, 5, 0, 15, 14, 2, 3, 12,
    ),
    (
        13, 2, 8, 4, 6, 15, 11, 1, 10, 9, 3, 14, 5, 0, 12, 7,
        1, 15, 13, 8, 10, 3, 7, 4, 12, 5, 6, 11, 0, 14, 9, 2,
        7, 11, 4, 1, 9, 12, 14, 2, 0, 6, 10, 13, 15, 3, 5, 8,
        2, 1, 14, 7, 4, 10, 8, 13, 15, 12, 9, 0, 3, 5, 6, 11,
    ),
)
# SBOXES[k][i6]: S-box k looked up directly by its raw 6-bit input i6.
# Bits 0x20 and 0x01 (outer bits) form the row, bits 0x1e the column.
SBOXES = [
    [sbox[i6 & 0x20 | (i6 & 0x01) << 4 | (i6 & 0x1e) >> 1] for i6 in range(64)]
    for sbox in SUBSTITUTION_BOX
]
# Left rotation of both 28-bit key halves before each of the 16 rounds.
ROTATES = (
    1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1,
)
def rotate_left(i28, k):
    """Rotate the 28-bit value i28 left by k bit positions."""
    shifted = (i28 << k) & 0x0fffffff
    wrapped = i28 >> (28 - k)
    return shifted | wrapped
def permute(data, bits, mapper):
    """Rearrange the bits of the `bits`-wide integer `data`.

    Output bit i (counted from the most significant end of a
    len(mapper)-bit result) is input bit mapper[i] (also counted from the
    most significant end of the input).
    """
    width = len(mapper)
    out = 0
    for pos, src in enumerate(mapper):
        bit = (data >> (bits - 1 - src)) & 1
        out |= bit << (width - 1 - pos)
    return out
def derive_keys(key):
    """Yield the 16 DES round keys (48-bit ints) for an 8-byte key.

    The key is read big-endian, reduced to 56 bits with PERMUTED_CHOICE1,
    split into 28-bit halves C and D, and before every round both halves
    are rotated by the ROTATES amount and PERMUTED_CHOICE2 is applied.
    """
    (key_int,) = struct.unpack(">Q", key)
    cd = permute(key_int, 64, PERMUTED_CHOICE1)
    c_half, d_half = cd >> 28, cd & 0x0fffffff
    for shift in ROTATES:
        c_half = rotate_left(c_half, shift)
        d_half = rotate_left(d_half, shift)
        yield permute(c_half << 28 | d_half, 56, PERMUTED_CHOICE2)
# INIT: INITIAL_PERMUTATION with the inverse round permutation folded into
# each 32-bit half (the second half uses the same offsets shifted by 32).
INIT = [INITIAL_PERMUTATION[src] for src in (*iPERMUTATION, *(32 + j for j in iPERMUTATION))]
# FINI: inverse of INIT as a position map.
FINI = list(map(INIT.index, range(64)))
# The two compositions of PERMUTATION and EXPANSION.
EXPERM = list(map(PERMUTATION.__getitem__, EXPANSION))
PERMEXP = list(map(EXPANSION.__getitem__, PERMUTATION))
def f_preexp(block, fix=None):
    """S-box layer plus P for a 48-bit S-box input, with the result already
    spread to the 48-bit expanded layout (via EXPERM).

    block -- 48-bit int; S-box i (i = 0 is the most significant) reads the
             6 bits at shift 42 - 6*i.
    fix   -- optional container of chunk indices counted from the least
             significant end (S-box i has index 7 - i). When given, S-boxes
             whose index is not listed contribute zero.
    """
    nibbles = 0
    for box_no, box in enumerate(SBOXES):
        active = fix is None or (7 - box_no) in fix
        nibble = box[(block >> (42 - 6 * box_no)) & 0x3f] if active else 0
        nibbles = nibbles << 4 | nibble
    return permute(nibbles, 32, EXPERM)
def f_preexp_tab(block):
    """Table-driven version of f_preexp() for a 48-bit S-box input.

    The input is consumed B bits at a time from the least significant end,
    and the per-chunk entries TAB[i][chunk] are OR-ed together.
    TAB and B are only created in this file's __main__ section, so calling
    this from a module that merely imports mydes raises NameError.
    """
    chunk_mask = (1 << B) - 1
    assert 0 <= block < 2**48
    acc = 0
    for idx in range(48 // B):
        acc |= TAB[idx][block & chunk_mask]
        block >>= B
    # every one of the 48 input bits must have been consumed
    assert block == 0
    return acc
def block_to_lr(block):
    """Return the two halves of a 64-bit block after the initial permutation,
    each spread to 48 bits by EXPANSION, as integers (l, r).

    `block` is anything Bin(block, n=64) accepts (solve.py passes 8 bytes).
    """
    bits = Bin(block, n=64).tuple
    permuted = Bin([bits[pos] for pos in INITIAL_PERMUTATION])
    left_half, right_half = permuted.split(2)
    left = Bin([left_half[pos] for pos in EXPANSION]).int
    right = Bin([right_half[pos] for pos in EXPANSION]).int
    return left, right
def encrypt_block(block, subkeys, inverse=False):
    """Run real DES rounds on one block, working in the 48-bit expanded layout.

    block   -- 64-bit block, anything Bin(block, n=64) accepts.
    subkeys -- iterable of 48-bit round keys, applied in order; any number of
               rounds works (e.g. a prefix of derive_keys()).
    inverse -- apply the round keys in reverse order (decryption).
    Returns the output block as bytes.
    """
    if inverse:
        # reversed() needs a sequence, but derive_keys() returns a generator.
        subkeys = reversed(list(subkeys))
    l, r = block_to_lr(block)
    for key in subkeys:
        # f_preexp_tab() is the table-driven equivalent of f_preexp().
        l, r = r, l ^ f_preexp_tab(r ^ key)
    # Undo the swap of the last Feistel round.
    l, r = r, l
    # Collapse both expanded halves back to 32 bits, then apply IP^-1.
    l = Bin(l, 48)
    r = Bin(r, 48)
    l = Bin([l[i] for i in iEXPANSION])
    r = Bin([r[i] for i in iEXPANSION])
    block = Bin.concat(l, r, n=32)
    return Bin([block[i] for i in FINAL_PERMUTATION]).bytes
def derive_keys_from_seed(seed):
    """Yield the 16 round keys of the 64-bit key DoubleExtend(seed).

    DoubleExtend is imported from TColors in this file's __main__ section.
    """
    yield from derive_keys(DoubleExtend(seed))
if __name__ == '__main__':
    # Regenerate tab.h for the C++ meet-in-the-middle code. Names assigned
    # here are module globals on purpose: f_preexp_tab() reads TAB and B,
    # and derive_keys_from_seed() reads DoubleExtend.

    def _c_row(values):
        """Format one C initializer row as "{ 0x..,0x..,},"."""
        return "{ " + "".join("0x%x," % v for v in values) + "},"

    # TABF[i][x]: contribution of B-bit chunk i (counted from the least
    # significant end) with value x; the other S-boxes are masked out by
    # fix=(i,).
    B = 6
    TAB = []
    with open("tab.h", "w") as f:
        print(f"uint64_t TABF[{48//B}][1 << {B}] = {{", file=f)
        for i in range(48 // B):
            row = [f_preexp(x << (i * B), fix=(i,)) for x in range(2**B)]
            print(_c_row(row), file=f)
            TAB.append(row)
        print("};", file=f)

    from TColors import DoubleExtend, DES, os

    # Key-schedule tables: the first 8 round keys of a seed, stored as XOR
    # differences from those of the all-zero seed.
    ZERO = list(derive_keys_from_seed(b"\x00\x00\x00\x00"))[:8]
    BB = 8

    def _ks_delta(seed):
        """First 8 round keys of `seed`, XOR-ed with those of the zero seed."""
        keys = list(derive_keys_from_seed(seed))[:8]
        return [a ^ b for a, b in zip(keys, ZERO)]

    TABKS = []
    TABKS_BIT = []
    # Appended with a second, explicitly closed handle; the first one is
    # closed (and flushed) by its `with` block above.
    with open("tab.h", "a") as f:
        print("uint64_t TABKS_ZERO[8] = {", file=f)
        print("".join("0x%x," % v for v in ZERO), file=f, end="")
        print("\n};", file=f)
        # TABKS[byte index][byte value][round]: bytewise KS (from scratch)
        print(f"uint64_t TABKS[{32 // BB}][{1 << BB}][8] = {{", file=f)
        for i in range(32 // BB):
            print("{", file=f)
            row = []
            for x in range(2**BB):
                ret = _ks_delta(Bin(x << (i * BB), 32).bytes)
                print(_c_row(ret), file=f)
                row.append(ret)
            print("},", file=f)
            TABKS.append(row)
        print("};", file=f)
        # TABKS_BIT[bit index][round]: bitwise KS (to flip bits in the key fast)
        print("uint64_t TABKS_BIT[32][8] = {", file=f)
        for i in range(32):
            ret = _ks_delta(Bin(1 << i, 32).bytes)
            print(_c_row(ret), file=f)
            TABKS_BIT.append(ret)
        print("};", file=f)
import sys, hashlib
from sock import Sock
from mydes import block_to_lr
from bint import Bin
# Connect to the challenge service and answer its proof-of-work prompt.
# The parsing assumes a banner like
#   "... such that <hashname>(X)... = <hex suffix> and len(X) = <n>"
# NOTE(review): inferred from the read_until() markers; confirm on the live banner.
f = Sock("66.172.10.203 13371", timeout=36000)
# f = Sock("94.232.173.132 12431", timeout=36000)
f.read_until(b"such that ")
# hashlib algorithm name (the text right before "(").
h = f.read_until(b"(")[:-1]
print("h", h)
f.read_until(b" = ")
# Required hexdigest suffix.
suf = f.read_until(b" and len(X) = ").split()[0]
print("suf", suf)
suf = suf.split()[0].decode()
# Required length of X.
l = f.read_line().strip()
print("l", l)
h = h.decode("ascii")
l = int(l)
print("hash", h, "suf", suf, "len", l)
hf = getattr(hashlib, h)
# Brute force: hex counter left-padded with b"x" to the required length.
for s in range(2**40):
    s = (b"%x" % s).rjust(l, b"x")
    if hf(s).hexdigest().endswith(suf):
        break
print("ok", s)
f.send_line(s)
def encrypt(pt):
    """Run the service's "E" command on pt and return the reply as 8-byte blocks.

    pt may be bytes, a list of ints (one per byte), or a list of chunks that
    are each converted with bytes(). It is sent hex-encoded. If the service
    runs TColors.Color.encrypt (presumably), the first returned block is the
    IV and the last one encrypts the padding.
    """
    if isinstance(pt, bytes):
        payload = pt
    elif isinstance(pt[0], int):
        payload = bytes(pt)
    else:
        payload = b"".join(map(bytes, pt))
    f.send_line("E")
    f.read_until("| Send plaintext in hex format")
    f.read_line()
    f.send_line(payload.hex())
    reply = bytes.fromhex(f.read_line().decode().strip())
    return [reply[pos:pos + 8] for pos in range(0, len(reply), 8)]
def xor(a, b):
    """Bytewise XOR of two byte strings, truncated to the shorter one."""
    out = bytearray()
    for x, y in zip(a, b):
        out.append(x ^ y)
    return bytes(out)
def get23(block):
    """Leak the 2DES and 3DES images of one plaintext block.

    The block is sent twice in a row. With Color.encrypt's chaining (the mask
    for block i > 0 is the previous middle value A2), the reply is
    IV, E3(y) ^ IV, E3(y) ^ y, padding. Here y = E2(E1(block)).
    Returns (y, z) with z = E3(y).
    """
    iv, first, second, _pad = encrypt(block + block)
    z = xor(first, iv)
    return xor(second, z), z
# Large query of 2**14 zero bytes; only the number of returned blocks is used.
res = encrypt([0] * 2**14)
print("res ok", len(res))
'''
iv 0 0
DES DES
x x
DES DES
y y
DES DES
z z
iv z + iv z + y
'''
# Leak the 2DES image y of the zero block (and yQ of the 0x01 block, which is
# used later to reject false key candidates).
y, z = get23([0] * 8)
yQ, zQ = get23([1] * 8)
print("GOT y", y)
# stage2 meets in the middle between the zero block (l1 = r1 = 0 there) and
# y, given in block_to_lr() form as two 48-bit hex numbers.
l1, r1 = block_to_lr(y)
print("please run:\nstage2.cpp %012x %012x" % (l1, r1))
from TColors import DES, DoubleExtend
# Read "k1 k2" candidates (hex 32-bit seeds, as printed by stage2). A pair is
# accepted only if DES_k2(DES_k1(.)) reproduces both leaked middle values:
# y for the zero block and yQ for the 0x01 block. The accepted dk1/dk2 stay
# global and are used by oracle3/oracle3batch.
while True:
    print("input?")
    k1, k2 = input().split()
    k1 = bytes.fromhex(k1)
    k2 = bytes.fromhex(k2)
    dk1 = DES(DoubleExtend(k1))
    dk2 = DES(DoubleExtend(k2))
    x = b"\x00" * 8
    x = dk1.encrypt(x)
    x = dk2.encrypt(x)
    print("enc 0 ->", x.hex(), "vs", y.hex())
    a = b"\x01" * 8
    a = dk1.encrypt(a)
    a = dk2.encrypt(a)
    print("enc 1 ->", a.hex(), "vs", yQ.hex())
    if x == y and a == yQ:
        print("phew match")
        break
    else:
        print("ouch")
def oracle3(yy):
    """Return z = E3(yy) for a chosen middle value yy (8 bytes or 8 ints).

    The plaintext whose 2DES image is yy is found with the recovered dk2
    and dk1, then sent through get23(). The recomputed y is printed next to
    yy as a sanity check.
    """
    if not isinstance(yy, bytes):
        assert len(yy) == 8
        yy = bytes(yy)
    plaintext = dk1.decrypt(dk2.decrypt(yy))
    y, z = get23(plaintext)
    print("check?", y, yy, y == yy)
    return z
def oracle3batch(pt):
    """Batch version of oracle3.

    pt is a concatenation of 8-byte middle values y_i. The function returns
    the concatenation of z_i = E3(y_i), obtained from a single "E" query.
    Ciphertext block i is z_i ^ IV for i = 0 and z_i ^ y_{i-1} afterwards
    (the mode chains the previous middle value), and is unmasked accordingly.
    """
    assert isinstance(pt, bytes)
    print("prepro...")
    ys = [pt[pos:pos + 8] for pos in range(0, len(pt), 8)]
    xs = [dk1.decrypt(dk2.decrypt(y)) for y in ys]
    print("encrypt...")
    iv, *blocks, _pad = encrypt(b"".join(xs))
    print("encrypt ok")
    assert len(blocks) == len(xs)
    masks = [iv] + ys[:-1]
    zs = [xor(block, mask) for block, mask in zip(blocks, masks)]
    print("post pro ok")
    return b"".join(zs)
# Sanity check of oracle3 on a fixed block (prints the y comparison).
z = oracle3([1, 2, 3, 4, 5, 6, 7, 8])
print("z", z)
# dunno why pipes
# $ mkfifo p1 p2
# Channel to the differential-linear attack (presumably des.py): requests
# are read from p1 and answers are written to p2.
f1 = open("p1", "rb")
f2 = open("p2", "wb")
from struct import unpack
from TColors import DES, DoubleExtend
print("go proxy loop")
itr = 0
# Serve E3 queries for the attack process. A request on p1 is an 8-byte
# little-endian length n followed by n bytes of concatenated 8-byte middle
# values; the concatenated E3 outputs are written back to p2. The header
# b"flagflag" instead announces the recovered 8-byte key.
while True:
    n = f1.read(8)
    if n == b"flagflag":
        key = f1.read(8)
        print("GOT KEY", key, key.hex())
        break
    if len(n) != 8:
        # Peer closed the pipe: fail clearly instead of with a struct.error.
        raise EOFError("p1 closed while waiting for a request header")
    n, = unpack("<Q", n)
    print("got signal", n, "query #", itr)
    chunks = []
    received = 0
    while received < n:
        chunk = f1.read(n - received)
        if not chunk:
            # read() returns b"" at EOF; without this check the loop spins forever.
            raise EOFError("p1 closed after %d of %d request bytes" % (received, n))
        chunks.append(chunk)
        received += len(chunk)
    s = b"".join(chunks)
    print("read all", n)
    t = oracle3batch(s)
    f2.write(t)
    f2.flush()
    print("sent all", n)
    itr += 1
# Submit the recovered key. PC1 never selects the low bit of a key byte, so
# the attack cannot determine those 8 bits; try all 2**8 combinations.
try:
    for mask in range(256):
        flipped = list(key)
        for pos in range(8):
            flipped[pos] ^= (mask >> pos) & 1
        curkey = bytes(flipped)
        print("S")
        print(curkey.hex())
        f.send_line("S")
        f.send_line(curkey.hex())
        print(f.read_one())
except Exception as err:
    print("err", err)
f.interact()
'''
MitM stage2:
real 23m36.570s
user 15m51.992s
sys 3m19.286s
Differential linear:
real 14m15.204s
user 2m46.636s
sys 0m0.260s
b'| Send third round key in hex format.\n| Congratz! You got the flag: ASIS{T0_ki55_iN_car5_4nd_d0wnTown_baRs___W4s_aLL_w3_need3d____Y0u_dr3w_s7ars_ar0Und_mY_sc4rs___Bu7_nOw_I_am_bl33ding!}\n'
'''
#include "common.cpp"
// Stage 1 of the meet-in-the-middle (precomputation).
// For every 32-bit seed k1 of the first DES layer, encrypt (l1, r1) = (0, 0)
// with k1's round keys and record ((l << 32) ^ r, k1). The key space is split
// into 2**RELAX slices; each slice is sorted and written to dump<k1hi>.
int main(int argc, char *argv[]) {
    l1 = 0;
    r1 = 0;
    // RELAX = 2 => 2**2 tables of 16 GiB (RAM requirement), 2**2 slowdown
    // RELAX = 3 => 2**3 tables of 8 GiB (RAM requirement), 2**3 slowdown
    // NOTE: stage2 reads the dumps with RELAX1 = 2; keep the two in sync.
    int RELAX = 2;
    const uint64_t nlow = 1ull << (32 - RELAX);  // keys per slice
    FOR(k1hi, 0, 1ull << RELAX) {
        // Widen before shifting (k1hi << 30 would overflow if FOR declares an int).
        const uint64_t hi_bits = (uint64_t)k1hi << (32 - RELAX);
        vector<pair<uint64_t, uint64_t>> values1(nlow);
        printf("size %zu\n", sizeof(values1[0]));
        printf("K1hi %llu STEP 1\n", (unsigned long long)k1hi);
        uint64_t curkeys[8];
        expand(hi_bits, curkeys);
        // curkeys now belongs to the key whose low bits are all zero. The
        // Gray-code walk below starts at index 1, so store this key here;
        // previously slot 0 stayed {0, 0} and the key was never searched.
        {
            uint64_t l = l1;
            uint64_t r = r1;
            encrypt(l, r, curkeys);
            values1[0] = {(l << 32) ^ r, hi_bits};
        }
        uint64_t prev = 0;
        // Gray codes: consecutive indices differ in one seed bit, whose
        // effect on the round keys is XOR-ed in from TABKS_BIT.
        FOR(index, 1ull, nlow) {
            uint32_t k1lo = index ^ (index >> 1);
            int bit_pos = bitpos(k1lo ^ prev);
            prev = k1lo;
            FORN(i, 8) curkeys[i] ^= TABKS_BIT[bit_pos][i];
            uint64_t l = l1;
            uint64_t r = r1;
            encrypt(l, r, curkeys);
            uint64_t h = (l << 32) ^ r;
            values1[k1lo] = {h, k1lo | hi_bits};
        }
        printf("sorting\n");
        sort(whole(values1));
        printf("sorting ok\n");
        char path[64];
        snprintf(path, sizeof(path), "dump%llu", (unsigned long long)k1hi);
        printf("dumping to %s\n", path);
        FILE *fd = fopen(path, "wb");
        if (!fd) {
            perror(path);
            return 1;
        }
        size_t written = fwrite(values1.data(), sizeof(values1[0]), values1.size(), fd);
        fclose(fd);
        if (written != values1.size()) {
            fprintf(stderr, "short write to %s\n", path);
            return 1;
        }
        printf("dumping ok\n");
    }
    return 0;
}
#include "common.cpp"
// Stage 2 of the meet-in-the-middle. Decrypts the target (l2, r2) -- the
// 2DES image of the zero block, in block_to_lr() form from solve.py -- under
// every second-layer seed k2. The sorted results are then merge-joined with
// the sorted stage-1 dumps of E_k1(0), and each collision is verified with enc().
int main(int argc, char *argv[]) {
    if (argc < 3) {
        fprintf(stderr, "usage: %s L2 R2  (48-bit hex values printed by solve.py)\n", argv[0]);
        return 1;
    }
    l1 = 0;
    r1 = 0;
    l2 = strtoull(argv[1], 0, 16);
    r2 = strtoull(argv[2], 0, 16);
    // Layout of the stage1_precomp dumps: 2**RELAX1 files, each holding
    // 2**(32 - RELAX1) sorted 16-byte entries. Must equal RELAX there.
    const int RELAX1 = 2;
    const int NFILES1 = 1 << RELAX1;
    // Each file is streamed in NCHUNKS chunks of 2**NELEM1 entries.
    const int NELEM1 = 26;
    int NCHUNKS = (1ull << (32 - RELAX1 - NELEM1));
    printf("NCHUNKS per 1/%d files: %d\n", NFILES1, NCHUNKS);
    vector<pair<uint64_t, uint64_t>> values1(1 << NELEM1); // 1024 MB
    // Unrelated to stage1_precomp's RELAX: the k2 side is processed in
    // 2**RELAX passes of 2**(32 - RELAX) in-RAM entries each.
    // RELAX = 2 => 2**2 tables of 16 GiB (RAM requirement), 2**2 slowdown
    // RELAX = 3 => 2**3 tables of 8 GiB (RAM requirement), 2**3 slowdown
    int RELAX = 3;
    const uint64_t nlow = 1ull << (32 - RELAX);
    vector<pair<uint64_t, uint64_t>> values2(nlow);
    FORN(k2hi, 1ull << RELAX) {
        const uint64_t hi_bits = (uint64_t)k2hi << (32 - RELAX);
        printf("k2hi %llu / %llu\n", (unsigned long long)k2hi, 1ull << RELAX);
        uint64_t curkeys[8];
        expand(hi_bits, curkeys);
        // Slot 0 (low bits all zero): the Gray-code walk below starts at
        // index 1 and would otherwise leave this key out.
        {
            uint64_t l = l2;
            uint64_t r = r2;
            decrypt(l, r, curkeys);
            values2[0] = {(l << 32) ^ r, hi_bits};
        }
        uint64_t prev = 0;
        // Grey codes to flip bits
        FOR(index, 1ull, nlow) {
            uint32_t k2lo = index ^ (index >> 1);
            int bit_pos = bitpos(k2lo ^ prev);
            prev = k2lo;
            FORN(i, 8) curkeys[i] ^= TABKS_BIT[bit_pos][i];
            uint64_t l = l2;
            uint64_t r = r2;
            decrypt(l, r, curkeys);
            uint64_t h = (l << 32) ^ r;
            values2[k2lo] = {h, k2lo | hi_bits};
        }
        printf("sorting\n");
        sort(whole(values2));
        printf("sorting ok\n");
        FORN(k1hi, NFILES1) {
            char path[64];
            snprintf(path, sizeof(path), "dump%d", (int)k1hi);
            printf("sieving from %s\n", path);
            FILE *fd = fopen(path, "rb");
            if (!fd) {
                perror(path);
                return 1;
            }
            int todo = NCHUNKS;
            auto it2 = values2.begin();
            auto it2end = values2.end();
            uint64_t nfalse = 0;
            while (todo-- && it2 != it2end) {
                // Keep the read outside assert(): under NDEBUG it would vanish.
                size_t got = fread((void *)values1.data(), values1.size() * sizeof(values1[0]), 1, fd);
                if (got != 1) {
                    fprintf(stderr, "short read from %s\n", path);
                    fclose(fd);
                    return 1;
                }
                auto it1 = values1.begin();
                auto it1end = values1.end();
                while (it1 != it1end && it2 != it2end) {
                    if (it1->first == it2->first) {
                        // Try it1 against every k2 entry with this value, so
                        // duplicates on the k2 side are not skipped.
                        for (auto jt = it2; jt != it2end && jt->first == it1->first; ++jt) {
                            uint32_t k1 = it1->second;
                            uint32_t k2 = jt->second;
                            uint64_t l = l1;
                            uint64_t r = r1;
                            enc(l, r, k1);
                            enc(l, r, k2);
                            if (l == l2 && r == r2) {
                                printf("FOUND %08x %08x\n", k1, k2);
                            }
                            else {
                                nfalse++;
                            }
                        }
                        it1++;
                    }
                    else if (it1->first < it2->first) {
                        it1++;
                    }
                    else {
                        it2++;
                    }
                }
                printf("+");
                fflush(stdout);
            }
            fclose(fd);
            printf("False positives: %llu\n", (unsigned long long)nfalse);
        }
    }
    return 0;
}
#!/usr/bin/python
import os, functools, struct
# Nominal sizes (presumably in bytes); not referenced by the code shown.
KEY_SIZE = 8
BLOCK_SIZE = 8
# Standard DES tables: 1-based bit positions counted from the most
# significant bit, so every lookup in this file subtracts 1.
# Initial permutation.
IP = [
    58, 50, 42, 34, 26, 18, 10, 2,
    60, 52, 44, 36, 28, 20, 12, 4,
    62, 54, 46, 38, 30, 22, 14, 6,
    64, 56, 48, 40, 32, 24, 16, 8,
    57, 49, 41, 33, 25, 17, 9 ,1,
    59, 51, 43, 35, 27, 19, 11, 3,
    61, 53, 45, 37, 29, 21, 13, 5,
    63, 55, 47, 39, 31, 23, 15, 7,
]
# Final permutation (inverse of IP).
IP_INV = [
    40, 8, 48, 16, 56, 24, 64, 32,
    39, 7, 47, 15, 55, 23, 63, 31,
    38, 6, 46, 14, 54, 22, 62, 30,
    37, 5, 45, 13, 53, 21, 61, 29,
    36, 4, 44, 12, 52, 20, 60, 28,
    35, 3, 43, 11, 51, 19, 59, 27,
    34, 2, 42, 10, 50, 18, 58, 26,
    33, 1, 41, 9, 49, 17, 57,25,
]
# Expansion E: 32-bit half block -> 48 bits.
E = [
    32, 1, 2, 3, 4, 5,
    4, 5, 6, 7, 8, 9,
    8, 9, 10, 11, 12, 13,
    12, 13, 14, 15, 16, 17,
    16, 17, 18, 19, 20, 21,
    20, 21, 22, 23, 24, 25,
    24, 25, 26, 27, 28, 29,
    28, 29, 30, 31, 32, 1,
]
# PC1 split into the 28-bit halves C and D. Key positions 8, 16, ..., 64
# (the low bit of each byte) are never selected.
PC1_C = [
    57, 49, 41, 33, 25, 17, 9,
    1, 58, 50, 42, 34, 26, 18,
    10, 2, 59, 51, 43, 35, 27,
    19, 11, 3, 60, 52, 44, 36,
]
PC1_D = [
    63, 55, 47, 39, 31, 23, 15,
    7, 62, 54, 46, 38, 30, 22,
    14, 6, 61, 53, 45, 37, 29,
    21, 13, 5, 28, 20, 12, 4,
]
# PC2: the 48 round-key bits taken from the 56-bit C||D.
PC2 = [
    14, 17, 11, 24, 1, 5,
    3, 28, 15, 6, 21, 10,
    23, 19, 12, 4, 26, 8,
    16, 7, 27, 20, 13, 2,
    41, 52, 31, 37, 47, 55,
    30, 40, 51, 45, 33, 48,
    44, 49, 39, 56, 34, 53,
    46, 42, 50, 36, 29, 32,
]
# Left rotation of C and D before each round (16 entries; KeyScheduler
# only uses the first 8).
KS_SHIFTS = [1,1,2,2,2,2,2,2,1,2,2,2,2,2,2,1]
# Permutation P applied to the 32 S-box output bits.
P = [
    16, 7, 20, 21,
    29, 12, 28, 17,
    1, 15, 23, 26,
    5, 18, 31, 10,
    2, 8, 24, 14,
    32, 27, 3, 9,
    19, 13, 30, 6,
    22, 11, 4, 25,
]
# S-boxes as 4 rows x 16 columns. CipherFunction takes the row from the
# outer two input bits and the column from the inner four.
S1 = [
    [14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7],
    [ 0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8],
    [ 4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0],
    [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13],
]
S2 = [
    [15, 1, 8, 14, 6, 11, 3, 4, 9, 7, 2, 13, 12, 0, 5, 10],
    [3, 13, 4, 7, 15, 2, 8, 14, 12, 0, 1, 10, 6, 9, 11, 5],
    [0, 14, 7, 11, 10, 4, 13, 1, 5, 8, 12, 6, 9, 3, 2, 15],
    [13, 8, 10, 1, 3, 15, 4, 2, 11, 6, 7, 12, 0, 5, 14, 9],
]
S3 = [
    [10, 0, 9, 14, 6, 3, 15, 5, 1, 13, 12, 7, 11, 4, 2, 8],
    [13, 7, 0, 9, 3, 4, 6, 10, 2, 8, 5, 14, 12, 11, 15, 1],
    [13, 6, 4, 9, 8, 15, 3, 0, 11, 1, 2, 12, 5, 10, 14, 7],
    [1, 10, 13, 0, 6, 9, 8, 7, 4, 15, 14, 3, 11, 5, 2, 12],
]
S4 = [
    [7, 13, 14, 3, 0, 6, 9, 10, 1, 2, 8, 5, 11, 12, 4, 15],
    [13, 8, 11, 5, 6, 15, 0, 3, 4, 7, 2, 12, 1, 10, 14, 9],
    [10, 6, 9, 0, 12, 11, 7, 13, 15, 1, 3, 14, 5, 2, 8, 4],
    [3, 15, 0, 6, 10, 1, 13, 8, 9, 4, 5, 11, 12, 7, 2, 14],
]
S5 = [
    [2, 12, 4, 1, 7, 10, 11, 6, 8, 5, 3, 15, 13, 0, 14, 9],
    [14, 11, 2, 12, 4, 7, 13, 1, 5, 0, 15, 10, 3, 9, 8, 6],
    [4, 2, 1, 11, 10, 13, 7, 8, 15, 9, 12, 5, 6, 3, 0, 14],
    [11, 8, 12, 7, 1, 14, 2, 13, 6, 15, 0, 9, 10, 4, 5, 3],
]
S6 = [
    [12, 1, 10, 15, 9, 2, 6, 8, 0, 13, 3, 4, 14, 7, 5, 11],
    [10, 15, 4, 2, 7, 12, 9, 5, 6, 1, 13, 14, 0, 11, 3, 8],
    [9, 14, 15, 5, 2, 8, 12, 3, 7, 0, 4, 10, 1, 13, 11, 6],
    [4, 3, 2, 12, 9, 5, 15, 10, 11, 14, 1, 7, 6, 0, 8, 13],
]
S7 = [
    [4, 11, 2, 14, 15, 0, 8, 13, 3, 12, 9, 7, 5, 10, 6, 1],
    [13, 0, 11, 7, 4, 9, 1, 10, 14, 3, 5, 12, 2, 15, 8, 6],
    [1, 4, 11, 13, 12, 3, 7, 14, 10, 15, 6, 8, 0, 5, 9, 2],
    [6, 11, 13, 8, 1, 4, 10, 7, 9, 5, 0, 15, 14, 2, 3, 12],
]
S8 = [
    [13, 2, 8, 4, 6, 15, 11, 1, 10, 9, 3, 14, 5, 0, 12, 7],
    [1, 15, 13, 8, 10, 3, 7, 4, 12, 5, 6, 11, 0, 14, 9, 2],
    [7, 11, 4, 1, 9, 12, 14, 2, 0, 6, 10, 13, 15, 3, 5, 8],
    [2, 1, 14, 7, 4, 10, 8, 13, 15, 12, 9, 0, 3, 5, 6, 11],
]
SBOXES = [S1, S2, S3, S4, S5, S6, S7, S8]
def Pad(data, bs = 8):
    """Append 0x80 then zero bytes up to the next multiple of bs.

    A whole extra block is added when data is already aligned, so at least
    one byte is always appended.
    """
    add_len = bs - len(data) % bs
    return data + b'\x80' + b'\x00' * (add_len - 1)
def Unpad(data, bs = 8):
    """Remove Pad()'s padding: trailing zero bytes plus the 0x80 marker.

    bs is accepted for symmetry with Pad and is not used.
    """
    i = 1
    # Indexing bytes yields ints in Python 3; comparing with b'\x00' was
    # always False, so only the last byte used to be stripped.
    while i < len(data) and data[-i] == 0:
        i += 1
    # data[-i] is now the 0x80 marker (for well-formed input); drop it too.
    return data[:-i]
def Xor(b1, b2):
    """Element-wise XOR of two bit lists (truncated to the shorter one)."""
    out = []
    for x, y in zip(b1, b2):
        out.append(x ^ y)
    return out
def BlockXor(b1, b2):
    """Bytewise XOR of two byte strings, returned as bytes (shorter length)."""
    return bytes(x ^ y for x, y in zip(b1, b2))
def Concat(*vectors):
    """Return a new list holding all given lists joined in order."""
    joined = []
    for vec in vectors:
        joined = joined + vec
    return joined
def Str2Bits(s):
    """Return the bits of the byte string s as a list of ints, MSB first."""
    assert (isinstance(s, bytes))
    return [int(bit) for byte in s for bit in format(byte, '08b')]
def Bits2Str(v):
    """Pack a list of bits (MSB first) into bytes; a trailing partial byte is dropped."""
    whole = len(v) // 8
    return bytes(
        int(''.join(map(str, v[8 * k:8 * k + 8])), 2) for k in range(whole)
    )
def DoubleExtend(data):
    """Spread a 32-bit seed over a 64-bit key.

    Output bit 2k is the complement of seed bit k and output bit 2k+1 is
    seed bit k. Only the first 4 bytes of data are used.
    """
    seed_bits = Str2Bits(data)
    spread = []
    for k in range(32):
        spread += [1 - seed_bits[k], seed_bits[k]]
    return Bits2Str(spread)
def Expand(v):
    """Expand a 32-bit half block to 48 bits using the E table."""
    assert (len(v) == 32)
    return [v[pos - 1] for pos in E]
def LeftShift(v, t=1):
    """Rotate the list v left by t positions (returns a new list)."""
    head, tail = v[:t], v[t:]
    return tail + head
def KeyScheduler(key):
    """Yield round keys (lists of 48 bits) for a 64-bit key bit list.

    Only the first 8 rounds of the DES schedule are produced, so the DES
    class in this file is an 8-round variant.
    """
    assert (len(key) == 64)
    # Only 56 bits are used: PC1 skips the low (parity) bit of every byte.
    c_half = [key[pos - 1] for pos in PC1_C]
    d_half = [key[pos - 1] for pos in PC1_D]
    for shift in KS_SHIFTS[:8]:
        c_half = LeftShift(c_half, shift)
        d_half = LeftShift(d_half, shift)
        merged = Concat(c_half, d_half)
        yield [merged[pos - 1] for pos in PC2]
def CipherFunction(key, inp):
    """DES round function f(R, K) on bit lists.

    key -- 48-bit round key; inp -- 32-bit half block. Returns 32 bits.
    """
    assert (len(key) == 48)
    assert (len(inp) == 32)
    # Confusion step: E(R) xor K, then one S-box per 6-bit group.
    mixed = Xor(Expand(inp), key)
    substituted = []
    for si, sbox in enumerate(SBOXES):
        six = mixed[6 * si:6 * si + 6]
        # Outer bits select the row, the inner four bits the column.
        row = 2 * int(six[0]) + int(six[5])
        col = int(''.join(map(str, six[1:5])), 2)
        substituted.extend(int(b) for b in format(sbox[row][col], '04b'))
    # Diffusion step.
    return [substituted[pos - 1] for pos in P]
class DES(object):
    """DES using the 8-round KeyScheduler of this file, on 8-byte blocks.

    key -- 8 bytes, or an already-converted list of 64 bits.
    encrypt/decrypt accept bytes or a 64-bit list and return bytes.
    """

    def __init__(self, key):
        self.key = Str2Bits(key) if isinstance(key, bytes) else key
        assert (len(self.key) == 64)

    def _feistel(self, block, round_keys):
        """IP, one Feistel round per key, undo the last swap, IP^-1."""
        if isinstance(block, bytes):
            block = Str2Bits(block)
        # Initial permutation.
        permuted = [block[pos - 1] for pos in IP]
        left, right = permuted[:32], permuted[32:]
        # Feistel network.
        for ki in round_keys:
            left, right = right, Xor(left, CipherFunction(ki, right))
        # Final permutation.
        joined = Concat(right, left)
        return Bits2Str([joined[pos - 1] for pos in IP_INV])

    def encrypt(self, plaintext):
        return self._feistel(plaintext, KeyScheduler(self.key))

    def decrypt(self, ciphertext):
        # Same network with the round keys in reverse order.
        return self._feistel(ciphertext, reversed(list(KeyScheduler(self.key))))
class Color(object):
    """Three DES layers E1, E2, E3 with the challenge's chaining mode.

    encrypt: each padded block P_i gives A2_i = E2(E1(P_i)) and
    C_i = E3(A2_i) xor M_i, with M_0 = iv and M_i = A2_{i-1} (the previous
    middle value, not the previous ciphertext). It returns
    hex(iv || C_0 || C_1 ...).
    decrypt takes the raw ciphertext blocks (without the iv) plus the iv
    and returns the unpadded plaintext as hex.
    """

    def __init__(self, k1, k2, k3):
        self.E1 = DES(k1)
        self.E2 = DES(k2)
        self.E3 = DES(k3)

    def encrypt(self, data, iv):
        data = Pad(data)
        # Collect parts and join once: the former `ciphertext = iv` followed
        # by `ciphertext += ...` extended a caller-supplied bytearray iv in place.
        parts = [iv]
        mask = iv
        for i in range(0, len(data), 8):
            a2 = self.E2.encrypt(self.E1.encrypt(data[i:i+8]))
            parts.append(BlockXor(self.E3.encrypt(a2), mask))
            mask = a2
        return b''.join(parts).hex()

    def decrypt(self, data, iv):
        parts = []
        mask = iv
        for i in range(0, len(data), 8):
            a2 = self.E3.decrypt(BlockXor(data[i:i+8], mask))
            parts.append(self.E1.decrypt(self.E2.decrypt(a2)))
            mask = a2
        return Unpad(b''.join(parts)).hex()
class ThreeColors(object):
    """Fresh random keys for one challenge instance.

    k1 and k2 each carry only 32 random bits (os.urandom(4) spread to 8
    bytes by DoubleExtend); k3 is a full random 8-byte key.
    """

    def __init__(self):
        seed1 = os.urandom(4)
        self.k1 = DoubleExtend(seed1)
        seed2 = os.urandom(4)
        self.k2 = DoubleExtend(seed2)
        self.k3 = os.urandom(8)
        self.color = Color(self.k1, self.k2, self.k3)

    def get_keys(self):
        """Return the tuple (k1, k2, k3)."""
        keys = (self.k1, self.k2, self.k3)
        return keys
if __name__ == '__main__':
    # Local check of the final step: with known k1/k2 seeds and a candidate
    # k3, try all 256 settings of the low bit of each k3 byte and decrypt a
    # sample ciphertext (IV || blocks) with each.
    k1 = DoubleExtend(bytes.fromhex("8a1ec18c"))
    k2 = DoubleExtend(bytes.fromhex("ffc2120b"))
    base_k3 = bytes.fromhex("da10be06e822cc4e")
    ct = bytes.fromhex("0dc5b4f4ef0ab3db685907aa56962c8e6aeb051ddd7e6a10")
    for mask in range(256):
        # Shift a copy: shifting `mask` itself made the final print always show 0.
        bits = mask
        key = list(base_k3)
        for i in range(8):
            key[i] ^= (bits & 1)
            bits >>= 1
        key = bytes(key)
        print(key.hex())
        c = Color(k1, k2, key)
        print(mask, c.decrypt(data=ct[8:], iv=ct[:8]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment