Create a gist now

Instantly share code, notes, and snippets.

Embed
What would you like to do?
0CTF 2018 Quals - zer0SPN (Crypto 550)
'''
In the challenge we have a "toy block cipher". It is an SPN cipher with:
- 4 rounds
- 8 8-bit S-Boxes (64-bit block)
- bit permutations as linear layer
We are given 2^16 random plaintext/ciphertext pairs.
In contrast to the zer0TC challenge, the bit permutation is strong and provides full diffusion.
The S-Box is weak both differentially and linearly.
Since we have known plaintexts, the way to go is linear cryptanalysis.
We attack the first round: its round key is the master key itself, so we avoid having to invert the key schedule.
First, we need to find good 3-round linear trails. This can be done using various algorithms/tools.
For example:
Masks after first round: [64, 0, 0, 0, 0, 0, 0, 0],
Masks on ciphertexts: [242, 0, 0, 0, 0, 0, 242, 0],
Bias: 2^-5.675513
We need a bias > 2^-8 because we only have 2^16 known pairs
(the amount of data needed grows roughly as bias^-2).
In practice it is much easier if the bias is around 2^-6:
then the right key byte will be the top candidate in our list with high probability.
The attack procedure:
1. Guess the first key byte (k) of the master key
2. Partially encrypt the first byte of all plaintexts: x' = S(x^k).
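3. Compute the linear approximation bit: c = scalar(x', in_mask) xor scalar(ct, out_mask),
   where scalar(u, m) is the parity of the bits of u selected by m.
4. Compute the bias of c over all pairs (i.e. how unbalanced the distribution of 0/1 is).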
The right key byte should be in the top candidates sorted by the bias.
After we recover a couple of key bytes,
we can use linear trails which have more active S-Boxes in the first round.
The only constraint is that each new trail may involve just one key byte not yet recovered, so we still guess a single byte at a time.
Finally, we get the flag: flag{48667ec1a5fb3383}
'''
import math
from zer0SPN import zer0SPN, sbox
# Inverse S-box table: sboxinv[y] is the x for which sbox[x] == y.
# (Not referenced by the rest of this script.)
sboxinv = [sbox.index(y) for y in range(256)]
def hw(x):
    """Return the Hamming weight (number of set bits) of the integer x.

    x must be non-negative: a negative int has infinitely many set bits in
    two's complement, so a ValueError is raised instead of looping forever.
    """
    if x < 0:
        raise ValueError("hw() requires a non-negative integer, got %r" % (x,))
    # bin() gives e.g. '0b1011'; counting '1' characters is iterative, so wide
    # integers no longer exhaust the recursion limit (one frame per bit).
    return bin(x).count("1")
# scalar[mask][value] is the GF(2) dot product of two bytes: the parity of the
# bits of `value` selected by `mask`. One lookup gives a byte's contribution to
# a linear approximation. (The table is symmetric in its two indices.)
scalar = []
for mask in range(256):
    scalar.append([hw(mask & value) & 1 for value in range(256)])
# True: attack the challenge's "data" file. False: self-test on pairs
# generated locally under a known key (round keys are then printed for checking).
USE_CHALLENGE_DATA = True
if USE_CHALLENGE_DATA:
    # "data" holds 2^16 records of 8 plaintext bytes followed by 8 ciphertext
    # bytes, as written by zer0SPN.py's __main__. It must be read in binary
    # mode: text mode may translate or truncate raw bytes on some platforms.
    pairs = []
    with open("data", "rb") as f:
        for _ in range(2**16):
            pt = list(bytearray(f.read(8)))
            ct = list(bytearray(f.read(8)))
            if len(pt) != 8 or len(ct) != 8:
                raise ValueError("data file is truncated: expected 2**16 "
                                 "plaintext/ciphertext pairs")
            pairs.append((pt, ct))
    # The real round keys are unknown; zeros keep the diagnostic prints working.
    TEST_ROUND_KEYS = [[0] * 8 for _ in range(5)]
else:
    from os import urandom
    key = "abcdefgh"
    cipher = zer0SPN(key)
    pairs = []
    for _ in range(2**16):
        plaintext = bytearray(urandom(8))
        ciphertext = cipher.encrypt(plaintext)
        pairs.append((tuple(plaintext), tuple(ciphertext)))
    # Split the expanded key into its 8-byte round keys for comparison output.
    TEST_ROUND_KEYS = tuple(
        tuple(cipher.roundkey[i:i+8])
        for i in range(0, len(cipher.roundkey), 8)
    )
def sl(bias):
    """Format a bias as a power of two, e.g. 0.25 -> "2^-2.000000".

    Returns "??" when the bias is not positive, since log2 is undefined there.
    """
    try:
        # log2(bias): math.log takes the base as its SECOND argument.
        return "2^%f" % math.log(bias, 2)
    except ValueError:
        # math.log raises ValueError for bias <= 0.
        return "??"
KEY = [None] * 8
# Trails with single-bytes on the plaintext side (after first round)
# To verify correlation we need to guess one key byte
masks = [
# -5.675513
[64, 0, 0, 0, 0, 0, 0, 0],
[242, 0, 0, 0, 0, 0, 242, 0],
# -5.798370
[0, 64, 0, 0, 0, 0, 0, 0],
[138, 0, 0, 0, 0, 0, 138, 0],
# -7.445192
# [0, 0, 0, 0, 64, 0, 0, 0],
# [224, 0, 0, 0, 0, 0, 224, 0],
]
for i in xrange(0, len(masks), 2):
inmask, outmask = masks[i:i+2]
inposes = [i for i in xrange(8) if inmask[i]]
outposes = [i for i in xrange(8) if outmask[i]]
inpos = inposes[0]
main_inmask = inmask[inpos]
results = []
for k in xrange(256):
cor = [0, 0], [0, 0]
for pt, ct in pairs:
out = 0
for i in outposes:
out ^= scalar[outmask[i]][ct[i]]
v = pt[inpos]
v ^= k
v = sbox[v]
inp = scalar[main_inmask][v]
cor[inp][out] += 1
a, b = cor[0]
a, b = min(a, b), max(a, b)
bias = abs(float(a) / (len(pairs)/2.0) - 0.5)
results.append((bias, k))
results.sort()
KEY[inpos] = results[-1][1]
print "INPOS", inpos, "%02x" % TEST_ROUND_KEYS[0][inpos]
for bias, k in results[-10:]:
print "%02x" % k, sl(bias)
print
print "KEY", KEY
print "KEY MID", KEY
# Trails with multiple bytes on the plaintext side (after first round)
# To verify correlation we need to guess one key byte + use already known
# (order of trails is important)
masks = [
# 4 -5.482868
[0, 64, 0, 0, 64, 0, 0, 0],
[106, 0, 0, 0, 0, 0, 106, 0],
# 6 -4.754170
[64, 64, 0, 0, 0, 0, 64, 0],
[178, 0, 0, 0, 0, 0, 178, 0],
# 2 -5.653971
[0, 64, 64, 0, 0, 0, 64, 0],
[68, 0, 0, 0, 0, 0, 68, 0],
# 5 -5.745579
[0, 64, 64, 0, 0, 64, 0, 0],
[25, 0, 0, 0, 0, 0, 25, 0],
# 3 -5.725353
[194, 194, 0, 194, 194, 194, 194, 0],
[0, 0, 0, 0, 117, 0, 0, 0],
# 7 -5.353585
[0, 0, 0, 25, 25, 0, 0, 25],
[0, 130, 0, 0, 0, 0, 0, 0],
]
for i in xrange(0, len(masks), 2):
inmask, outmask = masks[i:i+2]
inposes = [i for i in xrange(8) if inmask[i] and KEY[i] is not None]
outposes = [i for i in xrange(8) if outmask[i]]
inpos = [i for i in xrange(8) if inmask[i] and KEY[i] is None]
assert len(inpos) == 1
inpos = inpos[0]
main_inmask = inmask[inpos]
results = []
for k in xrange(256):
cor = [0, 0], [0, 0]
for pt, ct in pairs:
out = 0
for i in outposes:
out ^= scalar[outmask[i]][ct[i]]
v = pt[inpos]
v ^= k
v = sbox[v]
inp = scalar[main_inmask][v]
for i in inposes:
v = pt[i]
v ^= KEY[i]
v = sbox[v]
inp ^= scalar[inmask[i]][v]
cor[inp][out] += 1
a, b = cor[0]
a, b = min(a, b), max(a, b)
bias = abs(float(a) / (len(pairs)/2.0) - 0.5)
# print "%02x" % k, cor, sl(bias)
results.append((bias, k))
results.sort()
KEY[inpos] = results[-1][1]
print "INPOS", inpos, "%02x" % TEST_ROUND_KEYS[0][inpos]
for bias, k in results[-10:]:
print "%02x" % k, sl(bias)
print "KEY", KEY
print
print "flag{%s}" % "".join("%02x" % c for c in KEY)
#!/usr/bin/env python
# coding=utf-8
# Round constants for key_schedule(), which XORs rcon[i] into the first byte of
# each new round-key word for i = 1..rounds, so rcon[0] is never used here.
# The values are the AES Rcon sequence (successive powers of x in GF(2^8)).
rcon = [0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a]
# 8-bit S-box, used as a lookup table by substitute() and key_schedule().
# NOTE(review): presumably a permutation of 0..255 (the solve script inverts it
# with sbox.index) -- confirm if that property is relied on.
sbox = [62, 117, 195, 179, 20, 210, 41, 66, 116, 178, 152, 143, 75, 105, 254, 1, 158, 95, 101, 175, 191, 166, 36, 24, 50, 39, 190, 120, 52, 242, 182, 185, 61, 225, 140, 38, 150, 80, 19, 109, 246, 252, 40, 13, 65, 236, 124, 186, 214, 86, 235, 100, 97, 49, 197, 154, 176, 199, 253, 69, 88, 112, 139, 77, 184, 45, 133, 104, 15, 54, 177, 244, 160, 169, 82, 148, 73, 30, 229, 35, 79, 137, 157, 180, 248, 163, 241, 231, 81, 94, 165, 9, 162, 233, 18, 85, 217, 84, 7, 55, 63, 171, 56, 118, 237, 132, 136, 22, 90, 221, 103, 161, 205, 11, 255, 14, 122, 47, 71, 201, 99, 220, 83, 74, 173, 76, 144, 16, 155, 126, 60, 96, 44, 234, 17, 215, 107, 138, 159, 183, 251, 3, 198, 0, 89, 170, 131, 151, 219, 29, 230, 32, 187, 125, 134, 64, 12, 202, 164, 247, 25, 223, 222, 119, 174, 67, 147, 146, 206, 51, 243, 53, 121, 239, 68, 130, 70, 203, 211, 111, 108, 113, 8, 106, 57, 240, 21, 93, 142, 238, 167, 5, 128, 72, 189, 192, 193, 92, 10, 204, 87, 145, 188, 172, 224, 226, 207, 27, 218, 48, 33, 28, 123, 6, 37, 59, 4, 102, 114, 91, 23, 209, 34, 42, 2, 196, 141, 208, 181, 245, 43, 78, 213, 216, 232, 46, 98, 26, 212, 58, 115, 194, 200, 129, 227, 249, 127, 149, 135, 228, 31, 153, 250, 156, 168, 110]
# Bit permutation used by permutation(): output bit i is input bit ptable[i],
# with bits numbered MSB-first across the 8-byte block. The mapping
# 8 * (i % 8) + i // 8 transposes the 8x8 bit matrix: bit c of byte r
# becomes bit r of byte c.
ptable = list(8 * (bit % 8) + bit // 8 for bit in range(64))
def s2b(s):
    """Expand a byte string into its bits, most-significant bit of each byte first.

    Accepts anything bytearray() accepts (bytearray, byte str, list of ints in
    0..255) and returns a list of 8 * len(s) ints, each 0 or 1. Empty input
    yields an empty list.
    """
    # Shift each byte directly instead of round-tripping through a hex string,
    # which failed on empty input and relied on Python 2's str.encode('hex').
    return [(byte >> shift) & 1
            for byte in bytearray(s)
            for shift in range(7, -1, -1)]
def b2s(b):
    """Pack a list of bits (MSB of each byte first) into a bytearray; inverse of s2b.

    len(b) must be a multiple of 8, otherwise ValueError is raised. An empty
    list yields an empty bytearray.
    """
    if len(b) % 8:
        raise ValueError("bit count %d is not a multiple of 8" % len(b))
    packed = bytearray()
    for start in range(0, len(b), 8):
        value = 0
        for bit in b[start:start + 8]:
            value = 2 * value + bit
        packed.append(value)
    return packed
def addkey(a, b):
    """XOR two byte sequences element-wise and return the result as a bytearray.

    This is the cipher's key-addition layer. zip() stops at the shorter input,
    so the result has min(len(a), len(b)) bytes.
    """
    # (Removed a stray `global flag`: nothing here reads or assigns `flag`.)
    return bytearray(i ^ j for i, j in zip(a, b))
def substitute(a):
    """Pass every byte of `a` through the S-box; return the bytes as a new bytearray."""
    substituted = bytearray()
    for byte in a:
        substituted.append(sbox[byte])
    return substituted
def permutation(a):
    """Apply the ptable bit permutation to an 8-byte block and return a bytearray.

    Output bit i (MSB-first over the block) is input bit ptable[i].
    """
    assert len(a) == 8
    source_bits = s2b(a)
    return b2s([source_bits[src] for src in ptable])
class zer0SPN(object):
    '''0ops Substitution–Permutation Network

    A toy SPN over 8-byte blocks: per round a key XOR, a byte-wise S-box layer
    and a bit permutation, with round keys expanded from the master key.
    '''
    def __init__(self, key, key_size=8, rounds=4):
        # key must be exactly key_size bytes; it also serves as round key 0.
        assert len(key) == key_size
        self.key = key
        self.key_size = key_size
        self.rounds = rounds
        self.key_schedule()
    def key_schedule(self):
        # Expand self.key into self.roundkey: the key itself followed by one
        # key_size-byte block per round (for key_size a multiple of 4), so
        # get_roundkey(0) returns the master key.
        roundkey = bytearray(self.key)
        # tmp always equals the last 4 bytes appended to roundkey.
        tmp = roundkey[-4:]
        for i in xrange(1, self.rounds+1):
            # Rotate left by one byte, substitute through the S-box and XOR the
            # round constant into the first byte. The generator's `i` is local
            # to it, so rcon[i] below still sees the round index.
            tmp = tmp[1:] + tmp[:1]
            tmp = bytearray(sbox[i] for i in tmp)
            tmp[0] ^= rcon[i]
            # Emit key_size/4 new 4-byte words, each XORed with the word that
            # sits key_size bytes before the current end. roundkey grows on
            # every pass, so the negative index advances 4 bytes each time.
            for j in range(self.key_size/4):
                for k in range(4):
                    tmp[k] ^= roundkey[-self.key_size+k]
                roundkey += tmp
        self.roundkey = roundkey
    def get_roundkey(self, k):
        # Return round key k (k <= rounds) as a key_size-byte slice.
        assert k <= self.rounds
        return self.roundkey[self.key_size*k:self.key_size*(k+1)]
    def encrypt(self, plain):
        # Encrypt one key_size-byte block and return it as a bytearray.
        # Each round: XOR round key i, S-box layer, bit permutation (skipped in
        # the last round), then a final XOR with round key `rounds`.
        assert len(plain) == self.key_size
        block = bytearray(plain)
        for i in xrange(self.rounds):
            block = addkey(block, self.get_roundkey(i))
            block = substitute(block)
            if i != self.rounds - 1:
                # Permutation in the last round is of no purpose: it is a
                # fixed, key-independent bit shuffle of the output.
                block = permutation(block)
        # Relies on the loop variable surviving the loop (i == rounds - 1), so
        # this is round key `rounds`; rounds must therefore be >= 1.
        block = addkey(block, self.get_roundkey(i+1))
        return block
if __name__ == '__main__':
from secret import secret
from os import urandom
from struct import pack
print "Your flag is flag{%s}" % secret.encode('hex')
f = open('data', 'wb')
for _ in xrange(65536):
c = zer0SPN(secret)
plaintext = bytearray(urandom(8))
f.write(pack('8B', *plaintext))
ciphertext = c.encrypt(plaintext)
f.write(pack('8B', *ciphertext))
f.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment