Created
November 8, 2022 13:51
-
-
Save dfyz/d96db25c3a207914ab3f85d9309757dd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import chain

import galois
import numpy as np
class SymbolicByte:
    """A byte whose 8 bits are XOR-sums of named boolean variables.

    ``bit_vars[k]`` is the set of variable names XORed together to form bit k
    (bit 0 is the least significant). The special name ``'1'`` stands for the
    constant 1, so every bit is an affine function of the variables over GF(2).
    """

    def __init__(self, bit_vars):
        """Build a byte from per-bit collections of variable-name groups.

        Args:
            bit_vars: an iterable of 8 items, one per bit (LSB first). Each item
                is an iterable of iterables of variable names; all names for a
                bit are XORed together, so a name occurring an even number of
                times cancels out.

        Raises:
            ValueError: if ``bit_vars`` does not describe exactly 8 bits.
        """
        self.bit_vars = [self._xor_vars(vs) for vs in bit_vars]
        if len(self.bit_vars) != 8:
            raise ValueError(f'Expected 8 bits, got {len(self.bit_vars)}')

    @staticmethod
    def from_const(const):
        """Return a SymbolicByte for the integer ``const`` (only its low 8 bits are used)."""
        return SymbolicByte(
            [['1']] if (const & (1 << bit_idx)) else [[]]
            for bit_idx in range(8)
        )

    def _xor_vars(self, all_vars):
        """Return the set of names that occur an odd number of times in ``all_vars``."""
        res = set()
        for v in chain.from_iterable(all_vars):
            # Symmetric difference with a singleton toggles membership: x ^ x == 0.
            res ^= {v}
        return res

    def __xor__(self, other):
        """XOR with another SymbolicByte-like value or with a plain int."""
        if isinstance(other, int):
            other = SymbolicByte.from_const(other)
        elif not hasattr(other, 'bit_vars'):
            # Let Python try the reflected operation / raise TypeError.
            return NotImplemented
        return SymbolicByte(zip(self.bit_vars, other.bit_vars))

    # XOR is commutative, so ``int ^ SymbolicByte`` can reuse __xor__.
    __rxor__ = __xor__

    def eval(self, var_dict):
        """Evaluate the byte for a concrete assignment of the variables.

        Args:
            var_dict: maps every variable name (other than ``'1'``) occurring in
                this byte to its value; values are expected to be 0 or 1.

        Returns:
            The concrete byte value as an int.

        Raises:
            KeyError: if a variable used by this byte is missing from ``var_dict``.
        """
        res = 0
        for bit_idx, names in enumerate(self.bit_vars):
            bit_val = 0
            for name in names:
                bit_val ^= 1 if name == '1' else var_dict[name]
            if bit_val:
                res |= 1 << bit_idx
        return res
def dot(mask, num):
    """Return the GF(2) inner product of the bit vectors ``mask`` and ``num``.

    The result is the parity (0 or 1) of the number of bit positions that are
    set in both arguments.

    Raises:
        ValueError: if ``mask & num`` is negative, i.e. would have infinitely
            many set bits in Python's two's-complement integer model.
    """
    common = mask & num
    if common < 0:
        raise ValueError('dot() is undefined when mask & num is negative')
    return bin(common).count('1') & 1
class AffineSBox:
    """A 256-entry byte substitution table that is affine over GF(2).

    The constructor verifies that every output bit equals the XOR of a fixed
    subset of input bits, optionally XORed with 1, and records that
    decomposition in ``add_indexes`` / ``add_const``. Indexing the table with a
    SymbolicByte then applies it symbolically.
    """

    def __init__(self, perm):
        """Decompose ``perm`` into one affine function per output bit.

        Args:
            perm: a sequence containing each of 0..255 exactly once;
                ``perm[b]`` is the output for input byte ``b``.

        Raises:
            ValueError: if ``perm`` is not a permutation of 0..255, or if some
                output bit is not an affine function of the input bits.
        """
        if sorted(perm) != list(range(256)):
            raise ValueError('Not a permutation')
        # add_indexes[k]: input bit positions XORed together to form output bit k.
        self.add_indexes = []
        # add_const[k]: True when output bit k is additionally XORed with 1.
        self.add_const = []
        for out_bit in range(8):
            out_bits = [(perm[b] >> out_bit) & 1 for b in range(256)]
            # At b == 0 the masked input parity is 0, so if this output bit is
            # affine its constant term must equal the output bit of perm[0].
            const = out_bits[0]
            for in_mask in range(256):
                # all() stops at the first counterexample instead of building
                # the full 256-element set for every candidate mask.
                if all(
                    (self._parity(in_mask & b) ^ out_bits[b]) == const
                    for b in range(256)
                ):
                    self.add_indexes.append(
                        [idx for idx in range(8) if in_mask & (1 << idx)]
                    )
                    self.add_const.append(const == 1)
                    break
            else:
                raise ValueError(
                    f'Output bit #{out_bit} is not an affine function of input bits'
                )

    @staticmethod
    def _parity(x):
        """Return the number of set bits in the non-negative int ``x``, modulo 2."""
        return bin(x).count('1') & 1

    def __getitem__(self, byte):
        """Apply the affine map to ``byte`` and return a new SymbolicByte.

        Output bit k is the XOR of ``byte``'s bits at ``add_indexes[k]``, plus
        the constant ``'1'`` when ``add_const[k]`` is true. A plain int is
        accepted and treated as a constant SymbolicByte.
        """
        if isinstance(byte, int):
            byte = SymbolicByte.from_const(byte)
        return SymbolicByte([
            [byte.bit_vars[idx] for idx in indexes] + [{'1'} if const else set()]
            for indexes, const in zip(self.add_indexes, self.add_const)
        ])
class AES:
    """AES-128 encryption built on a custom S-box.

    Every byte operation is expressed through ``^`` and the AffineSBox tables
    below, so the same code also runs on SymbolicByte values. Because ``sbox``,
    ``gmul2`` and ``gmul3`` are AffineSBox instances (checked to be affine over
    GF(2) when the class is created), every output bit of ``encrypt`` is an
    affine function of the key bits.

    The internal state is the block transposed by ``_transpose``, so
    ``self._state[4 * row + col]`` holds the byte at (row, col) of the AES state.
    """

    # Custom 8-bit S-box (not the standard AES S-box); affine over GF(2).
    sbox = AffineSBox((
        0x4c, 0x51, 0x76, 0x6b, 0x38, 0x25, 0x02, 0x1f, 0xa4, 0xb9, 0x9e, 0x83, 0xd0, 0xcd, 0xea, 0xf7,
        0x87, 0x9a, 0xbd, 0xa0, 0xf3, 0xee, 0xc9, 0xd4, 0x6f, 0x72, 0x55, 0x48, 0x1b, 0x06, 0x21, 0x3c,
        0xc1, 0xdc, 0xfb, 0xe6, 0xb5, 0xa8, 0x8f, 0x92, 0x29, 0x34, 0x13, 0x0e, 0x5d, 0x40, 0x67, 0x7a,
        0x0a, 0x17, 0x30, 0x2d, 0x7e, 0x63, 0x44, 0x59, 0xe2, 0xff, 0xd8, 0xc5, 0x96, 0x8b, 0xac, 0xb1,
        0x4d, 0x50, 0x77, 0x6a, 0x39, 0x24, 0x03, 0x1e, 0xa5, 0xb8, 0x9f, 0x82, 0xd1, 0xcc, 0xeb, 0xf6,
        0x86, 0x9b, 0xbc, 0xa1, 0xf2, 0xef, 0xc8, 0xd5, 0x6e, 0x73, 0x54, 0x49, 0x1a, 0x07, 0x20, 0x3d,
        0xc0, 0xdd, 0xfa, 0xe7, 0xb4, 0xa9, 0x8e, 0x93, 0x28, 0x35, 0x12, 0x0f, 0x5c, 0x41, 0x66, 0x7b,
        0x0b, 0x16, 0x31, 0x2c, 0x7f, 0x62, 0x45, 0x58, 0xe3, 0xfe, 0xd9, 0xc4, 0x97, 0x8a, 0xad, 0xb0,
        0x4e, 0x53, 0x74, 0x69, 0x3a, 0x27, 0x00, 0x1d, 0xa6, 0xbb, 0x9c, 0x81, 0xd2, 0xcf, 0xe8, 0xf5,
        0x85, 0x98, 0xbf, 0xa2, 0xf1, 0xec, 0xcb, 0xd6, 0x6d, 0x70, 0x57, 0x4a, 0x19, 0x04, 0x23, 0x3e,
        0xc3, 0xde, 0xf9, 0xe4, 0xb7, 0xaa, 0x8d, 0x90, 0x2b, 0x36, 0x11, 0x0c, 0x5f, 0x42, 0x65, 0x78,
        0x08, 0x15, 0x32, 0x2f, 0x7c, 0x61, 0x46, 0x5b, 0xe0, 0xfd, 0xda, 0xc7, 0x94, 0x89, 0xae, 0xb3,
        0x4f, 0x52, 0x75, 0x68, 0x3b, 0x26, 0x01, 0x1c, 0xa7, 0xba, 0x9d, 0x80, 0xd3, 0xce, 0xe9, 0xf4,
        0x84, 0x99, 0xbe, 0xa3, 0xf0, 0xed, 0xca, 0xd7, 0x6c, 0x71, 0x56, 0x4b, 0x18, 0x05, 0x22, 0x3f,
        0xc2, 0xdf, 0xf8, 0xe5, 0xb6, 0xab, 0x8c, 0x91, 0x2a, 0x37, 0x10, 0x0d, 0x5e, 0x43, 0x64, 0x79,
        0x09, 0x14, 0x33, 0x2e, 0x7d, 0x60, 0x47, 0x5a, 0xe1, 0xfc, 0xdb, 0xc6, 0x95, 0x88, 0xaf, 0xb2,
    ))
    # Multiplication by 2 in GF(2^8); 0x80 -> 0x1b shows the AES reduction polynomial 0x11b.
    gmul2 = AffineSBox((
        0x00, 0x02, 0x04, 0x06, 0x08, 0x0a, 0x0c, 0x0e, 0x10, 0x12, 0x14, 0x16, 0x18, 0x1a, 0x1c, 0x1e,
        0x20, 0x22, 0x24, 0x26, 0x28, 0x2a, 0x2c, 0x2e, 0x30, 0x32, 0x34, 0x36, 0x38, 0x3a, 0x3c, 0x3e,
        0x40, 0x42, 0x44, 0x46, 0x48, 0x4a, 0x4c, 0x4e, 0x50, 0x52, 0x54, 0x56, 0x58, 0x5a, 0x5c, 0x5e,
        0x60, 0x62, 0x64, 0x66, 0x68, 0x6a, 0x6c, 0x6e, 0x70, 0x72, 0x74, 0x76, 0x78, 0x7a, 0x7c, 0x7e,
        0x80, 0x82, 0x84, 0x86, 0x88, 0x8a, 0x8c, 0x8e, 0x90, 0x92, 0x94, 0x96, 0x98, 0x9a, 0x9c, 0x9e,
        0xa0, 0xa2, 0xa4, 0xa6, 0xa8, 0xaa, 0xac, 0xae, 0xb0, 0xb2, 0xb4, 0xb6, 0xb8, 0xba, 0xbc, 0xbe,
        0xc0, 0xc2, 0xc4, 0xc6, 0xc8, 0xca, 0xcc, 0xce, 0xd0, 0xd2, 0xd4, 0xd6, 0xd8, 0xda, 0xdc, 0xde,
        0xe0, 0xe2, 0xe4, 0xe6, 0xe8, 0xea, 0xec, 0xee, 0xf0, 0xf2, 0xf4, 0xf6, 0xf8, 0xfa, 0xfc, 0xfe,
        0x1b, 0x19, 0x1f, 0x1d, 0x13, 0x11, 0x17, 0x15, 0x0b, 0x09, 0x0f, 0x0d, 0x03, 0x01, 0x07, 0x05,
        0x3b, 0x39, 0x3f, 0x3d, 0x33, 0x31, 0x37, 0x35, 0x2b, 0x29, 0x2f, 0x2d, 0x23, 0x21, 0x27, 0x25,
        0x5b, 0x59, 0x5f, 0x5d, 0x53, 0x51, 0x57, 0x55, 0x4b, 0x49, 0x4f, 0x4d, 0x43, 0x41, 0x47, 0x45,
        0x7b, 0x79, 0x7f, 0x7d, 0x73, 0x71, 0x77, 0x75, 0x6b, 0x69, 0x6f, 0x6d, 0x63, 0x61, 0x67, 0x65,
        0x9b, 0x99, 0x9f, 0x9d, 0x93, 0x91, 0x97, 0x95, 0x8b, 0x89, 0x8f, 0x8d, 0x83, 0x81, 0x87, 0x85,
        0xbb, 0xb9, 0xbf, 0xbd, 0xb3, 0xb1, 0xb7, 0xb5, 0xab, 0xa9, 0xaf, 0xad, 0xa3, 0xa1, 0xa7, 0xa5,
        0xdb, 0xd9, 0xdf, 0xdd, 0xd3, 0xd1, 0xd7, 0xd5, 0xcb, 0xc9, 0xcf, 0xcd, 0xc3, 0xc1, 0xc7, 0xc5,
        0xfb, 0xf9, 0xff, 0xfd, 0xf3, 0xf1, 0xf7, 0xf5, 0xeb, 0xe9, 0xef, 0xed, 0xe3, 0xe1, 0xe7, 0xe5
    ))
    # Multiplication by 3 in GF(2^8), i.e. gmul2[b] ^ b.
    gmul3 = AffineSBox((
        0x00, 0x03, 0x06, 0x05, 0x0c, 0x0f, 0x0a, 0x09, 0x18, 0x1b, 0x1e, 0x1d, 0x14, 0x17, 0x12, 0x11,
        0x30, 0x33, 0x36, 0x35, 0x3c, 0x3f, 0x3a, 0x39, 0x28, 0x2b, 0x2e, 0x2d, 0x24, 0x27, 0x22, 0x21,
        0x60, 0x63, 0x66, 0x65, 0x6c, 0x6f, 0x6a, 0x69, 0x78, 0x7b, 0x7e, 0x7d, 0x74, 0x77, 0x72, 0x71,
        0x50, 0x53, 0x56, 0x55, 0x5c, 0x5f, 0x5a, 0x59, 0x48, 0x4b, 0x4e, 0x4d, 0x44, 0x47, 0x42, 0x41,
        0xc0, 0xc3, 0xc6, 0xc5, 0xcc, 0xcf, 0xca, 0xc9, 0xd8, 0xdb, 0xde, 0xdd, 0xd4, 0xd7, 0xd2, 0xd1,
        0xf0, 0xf3, 0xf6, 0xf5, 0xfc, 0xff, 0xfa, 0xf9, 0xe8, 0xeb, 0xee, 0xed, 0xe4, 0xe7, 0xe2, 0xe1,
        0xa0, 0xa3, 0xa6, 0xa5, 0xac, 0xaf, 0xaa, 0xa9, 0xb8, 0xbb, 0xbe, 0xbd, 0xb4, 0xb7, 0xb2, 0xb1,
        0x90, 0x93, 0x96, 0x95, 0x9c, 0x9f, 0x9a, 0x99, 0x88, 0x8b, 0x8e, 0x8d, 0x84, 0x87, 0x82, 0x81,
        0x9b, 0x98, 0x9d, 0x9e, 0x97, 0x94, 0x91, 0x92, 0x83, 0x80, 0x85, 0x86, 0x8f, 0x8c, 0x89, 0x8a,
        0xab, 0xa8, 0xad, 0xae, 0xa7, 0xa4, 0xa1, 0xa2, 0xb3, 0xb0, 0xb5, 0xb6, 0xbf, 0xbc, 0xb9, 0xba,
        0xfb, 0xf8, 0xfd, 0xfe, 0xf7, 0xf4, 0xf1, 0xf2, 0xe3, 0xe0, 0xe5, 0xe6, 0xef, 0xec, 0xe9, 0xea,
        0xcb, 0xc8, 0xcd, 0xce, 0xc7, 0xc4, 0xc1, 0xc2, 0xd3, 0xd0, 0xd5, 0xd6, 0xdf, 0xdc, 0xd9, 0xda,
        0x5b, 0x58, 0x5d, 0x5e, 0x57, 0x54, 0x51, 0x52, 0x43, 0x40, 0x45, 0x46, 0x4f, 0x4c, 0x49, 0x4a,
        0x6b, 0x68, 0x6d, 0x6e, 0x67, 0x64, 0x61, 0x62, 0x73, 0x70, 0x75, 0x76, 0x7f, 0x7c, 0x79, 0x7a,
        0x3b, 0x38, 0x3d, 0x3e, 0x37, 0x34, 0x31, 0x32, 0x23, 0x20, 0x25, 0x26, 0x2f, 0x2c, 0x29, 0x2a,
        0x0b, 0x08, 0x0d, 0x0e, 0x07, 0x04, 0x01, 0x02, 0x13, 0x10, 0x15, 0x16, 0x1f, 0x1c, 0x19, 0x1a
    ))
    # Key-schedule round constants; index 0 is unused (_expand_key reads rcon[rnd + 1]).
    rcon = (0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36)

    def __init__(self, key):
        """Expand ``key`` into the 11 AES-128 round keys.

        Args:
            key: a sequence of exactly 16 key bytes (e.g. SymbolicByte values).

        Raises:
            ValueError: if ``key`` does not have exactly 16 elements. The key
                schedule below hard-codes 10 rounds and 4-word keys, so any other
                length would otherwise produce a silently wrong schedule.
        """
        self._block_size = 16
        key = list(key)
        if len(key) != self._block_size:
            raise ValueError(
                f'AES-128 key must be {self._block_size} bytes, got {len(key)}'
            )
        self._round_keys = self._expand_key(key)
        self._state = []

    def _transpose(self, m):
        """Swap rows and columns of a 4x4 byte matrix stored as a flat list."""
        return [m[4 * j + i] for i in range(4) for j in range(4)]

    def _xor(self, a, b):
        """Return the element-wise XOR of two byte sequences (keeps ``a ^ b`` order)."""
        return [x ^ y for x, y in zip(a, b)]

    def _expand_key(self, key):
        """Return the 11 round keys, each transposed into state layout."""
        round_keys = [key]
        for rnd in range(10):
            prev = round_keys[rnd]
            # RotWord then SubWord on the last 4-byte word of the previous key.
            last = prev[-4:]
            last = [self.sbox[b] for b in last[1:] + last[:1]]
            round_key = self._xor(self._xor(prev[:4], last), [self.rcon[rnd + 1], 0, 0, 0])
            # Remaining three words: w[j] = w[j - 1] ^ prev_w[j].
            for j in range(0, 12, 4):
                round_key.extend(self._xor(round_key[j:j + 4], prev[j + 4:j + 8]))
            round_keys.append(round_key)
        return [self._transpose(rk) for rk in round_keys]

    def _add_round_key(self, i):
        """XOR round key ``i`` into the state."""
        # Round key is the left operand so that SymbolicByte ^ int dispatches
        # to SymbolicByte.__xor__ when the state still holds plain ints.
        self._state = self._xor(self._round_keys[i], self._state)

    def _mix_columns(self):
        """Multiply each state column by the AES MixColumns matrix."""
        s = self._state
        out = [0] * self._block_size
        for c in range(4):
            a0, a1, a2, a3 = s[c], s[c + 4], s[c + 8], s[c + 12]
            out[c] = self.gmul2[a0] ^ self.gmul3[a1] ^ a2 ^ a3
            out[c + 4] = a0 ^ self.gmul2[a1] ^ self.gmul3[a2] ^ a3
            out[c + 8] = a0 ^ a1 ^ self.gmul2[a2] ^ self.gmul3[a3]
            out[c + 12] = self.gmul3[a0] ^ a1 ^ a2 ^ self.gmul2[a3]
        self._state = out

    def _shift_rows(self):
        """Rotate row ``r`` of the state left by ``r`` positions."""
        s = self._state
        self._state = [s[4 * r + (c + r) % 4] for r in range(4) for c in range(4)]

    def _sub_bytes(self):
        """Apply the S-box to every state byte."""
        self._state = [self.sbox[b] for b in self._state]

    def _encrypt_block(self):
        """Run the 10 AES-128 rounds on the current state."""
        self._add_round_key(0)
        for rnd in range(1, 10):
            self._sub_bytes()
            self._shift_rows()
            self._mix_columns()
            self._add_round_key(rnd)
        # The final round omits MixColumns.
        self._sub_bytes()
        self._shift_rows()
        self._add_round_key(10)

    def encrypt(self, block):
        """Encrypt one 16-byte block and return the ciphertext bytes as a list.

        Raises:
            ValueError: if ``block`` does not have exactly 16 elements.
        """
        block = list(block)
        if len(block) != self._block_size:
            raise ValueError(
                f'AES block must be {self._block_size} bytes, got {len(block)}'
            )
        self._state = self._transpose(block)
        self._encrypt_block()
        return self._transpose(self._state)
if __name__ == '__main__':
    # Every component of this AES variant (S-box, gmul2, gmul3, XOR) is affine
    # over GF(2), so each ciphertext bit is the XOR of a subset of key bits,
    # possibly plus 1. Encrypting with a symbolic key therefore yields a
    # linear system over GF(2) that can be solved for the key directly.

    def symbolic_block(block_name):
        """Return 16 SymbolicBytes whose bits are the fresh variables
        ``f'{block_name}{byte_idx}_{bit_idx}'`` (bit 0 = LSB)."""
        return [
            SymbolicByte(
                [[f'{block_name}{byte_idx}_{bit_idx}']]
                for bit_idx in range(8)
            )
            for byte_idx in range(16)
        ]

    aes = AES(symbolic_block('K'))
    target_bytes = b'codegate2022{xx}'
    # Encrypt the target as the plaintext; the system below asks for a key
    # under which the ciphertext equals the target itself.
    encrypted0_m = aes.encrypt(target_bytes)

    # Column order of the system: key byte-major, bit 0 (LSB) first.
    key_var_names = [
        f'K{key_byte_idx}_{key_bit_idx}'
        for key_byte_idx in range(16)
        for key_bit_idx in range(8)
    ]

    GF2 = galois.GF(2)
    mat = []
    bias = []
    target = []
    for byte_idx in range(16):
        for bit_idx in range(8):
            set_vars = encrypted0_m[byte_idx].bit_vars[bit_idx]
            mat.append([1 if name in set_vars else 0 for name in key_var_names])
            bias.append(1 if '1' in set_vars else 0)
            target.append((target_bytes[byte_idx] >> bit_idx) & 1)

    # Solve mat @ key == target - bias over GF(2) (subtraction is XOR there).
    key_sln = np.linalg.solve(GF2(mat), GF2(target) - GF2(bias))
    key_bits = [int(bit) for bit in key_sln]

    # Sanity check: plugging the solution back into the symbolic ciphertext
    # must reproduce the target bytes exactly.
    assignment = dict(zip(key_var_names, key_bits))
    recovered = bytes(byte.eval(assignment) for byte in encrypted0_m)
    if recovered != target_bytes:
        raise RuntimeError(
            f'Recovered key does not reproduce the target: {recovered!r}'
        )

    key_bytes = bytearray(
        sum(key_bits[byte_idx * 8 + bit_idx] << bit_idx for bit_idx in range(8))
        for byte_idx in range(16)
    )
    print(key_bytes.hex())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment