Skip to content

Instantly share code, notes, and snippets.

@zrax-x
Last active December 10, 2020 07:00
Show Gist options
  • Save zrax-x/de52a31c360bda960a08c90b6c7a4501 to your computer and use it in GitHub Desktop.
Save zrax-x/de52a31c360bda960a08c90b6c7a4501 to your computer and use it in GitHub Desktop.
太湖杯2020 (Taihu Cup 2020 CTF), AEGIS, 差分攻击 (differential attack)
'''
Based on https://sasdf.github.io/ctf/writeup/2020/google/crypto/oracle/ and
https://github.com/sasdf/ctf/blob/master/writeup/2020/google/crypto/oracle/_files/solve.py
'''
import aegis
import base64
import sys
import aes
import itertools
import re
from collections import defaultdict
from telnetlib import Telnet
def print_block(s, n=1):
    """Print *s* as hex, *n* 16-byte blocks per row, then a blank line.

    Each row starts with ``s_`` plus the zero-padded byte offset of the
    row's first block (e.g. ``s_0032``), followed by one hex string per
    block separated by spaces.
    """
    row_len = 16 * n
    for offset in range(0, len(s), row_len):
        cells = []
        for k in range(n):
            start = offset + 16 * k
            cells.append(s[start:start + 16].hex())
        print('s_%04d' % offset, *cells)
    print('')
def gen_diff(i, d, b=b'\0'*16*6, s=16):
    """Build a chosen plaintext carrying difference byte *d* in chunk *i*.

    The raw buffer is *i* chunks of zero bytes (chunk size *s*), then one
    chunk of byte *d* repeated, then 96 trailing zero bytes. It is XORed
    with *b* (96 zero bytes by default) and cut to at most 96 bytes, i.e.
    six 16-byte blocks. With the defaults and 0 <= i < 6, the result is an
    all-zero plaintext whose block *i* is filled with *d*.

    NOTE(review): assumes aes.xor_bytes is a byte-wise XOR that stops at the
    shorter input, so a shorter *b* also shortens the result. Confirm
    against the aes module.
    """
    limit = 16 * 6
    pieces = (b'\0' * i * s, bytes([d] * s), b'\0' * limit)
    return aes.xor_bytes(b''.join(pieces), b)[:limit]
def inv_diff(d1):
    """Map a 16-byte AES round-output difference back to S-box outputs.

    Undoes MixColumns and then ShiftRows, the linear part of a round, so
    each returned byte is the difference at the corresponding SubBytes
    output. SubBytes itself is not inverted.
    """
    grid = aes.bytes2matrix(d1)
    aes.inv_mix_columns(grid)
    aes.inv_shift_rows(grid)
    return aes.matrix2bytes(grid)
def inv_state(d0, d1):
    """Return ``inv_r(d0 XOR d1)``.

    XORs the two 16-byte values, then undoes MixColumns, ShiftRows and
    SubBytes of one AES round. These are exactly the steps inv_r performs,
    so this delegates to it instead of repeating them.
    """
    return inv_r(aes.xor_bytes(d0, d1))
def inv_r(d1):
    """Invert one keyless AES round on a 16-byte value.

    Applies inverse MixColumns, inverse ShiftRows and inverse SubBytes, in
    that order, so the result x satisfies
    MixColumns(ShiftRows(SubBytes(x))) == d1.
    """
    state = aes.bytes2matrix(d1)
    for undo in (aes.inv_mix_columns, aes.inv_shift_rows, aes.inv_sub_bytes):
        undo(state)  # each step mutates the matrix; return values are unused
    return aes.matrix2bytes(state)
# mc_table[(i, p, v)] = c: a lone nonzero byte c at input position i, passed
# through ShiftRows then MixColumns, yields the nonzero byte v at output
# position p. Presumably meant to trace a single output-byte difference back
# to the input position and value that caused it.
# NOTE(review): neither table is read anywhere else in this gist as shown.
mc_table = {}
for i in range(16):
    for c in range(1, 256):
        s = bytearray(16)
        s[i] = c
        s = aes.bytes2matrix(s)
        aes.shift_rows(s)
        aes.mix_columns(s)
        s = aes.matrix2bytes(s)
        for p in range(16):
            if s[p] != 0:
                # Check the full (i, p, v) key about to be inserted. A bare
                # (p, v) pair can never be found among 3-tuple keys, so
                # checking it would never detect a collision.
                assert (i, p, s[p]) not in mc_table
                mc_table[(i, p, s[p])] = c

# sub_table[delta] lists every S-box input x with S(x) ^ S(x ^ 1) == delta,
# i.e. the candidate inputs for which input difference 0x01 produces output
# difference delta.
sub_table = defaultdict(list)
for i in range(256):
    d = aes.s_box[i] ^ aes.s_box[i^1]
    sub_table[d].append(i)
# Connection to the challenge service. It is opened when the script starts
# and used by solve_phase1() through this module-level name.
# NOTE(review): telnetlib is deprecated by PEP 594 and removed in Python
# 3.13, so this script needs Python <= 3.12 as written.
remote = Telnet(host='122.112.209.168', port=10090)
def solve_phase1():
    """Recover the five AEGIS state words at time 1 and submit them.

    Uses the module-level ``remote`` connection. It makes chosen-plaintext
    encryption queries, recovers state bytes from ciphertext differences,
    reads one leaked state word, derives the remaining words, and sends
    ``S0 || S1 || S2 || S3 || S4`` back as one base64 line. The server's
    reply line is printed.

    NOTE(review): the state-word meanings below (s0_1, s4_1, ...) come from
    the variable names. They assume the AEGIS-128 update
    ``S0' = R(S4) ^ S0 ^ m``, ``Si' = R(S(i-1)) ^ Si`` for i = 1..4, and
    ciphertext ``C = P ^ S1 ^ S4 ^ (S2 & S3)``, where R is the keyless AES
    round. Confirm this against the server's aegis implementation.
    """
    # First server line; presumably the IV/nonce (per the name). It is read
    # only to keep the protocol in sync; the value is unused.
    iv = remote.read_until(b'\n')
    def enc(aad, pt):
        """Ask the server to encrypt *pt* with associated data *aad*.

        Sends base64(pt) and then base64(aad), one per line. Reads back a
        base64 ciphertext line and a base64 tag line. Only the ciphertext is
        returned; the tag is read just to consume it.
        """
        remote.write(base64.b64encode(pt).replace(b'\n', b'') + b'\n')
        remote.write(base64.b64encode(aad).replace(b'\n', b'') + b'\n')
        ct = base64.b64decode(remote.read_until(b'\n').strip())
        tag = base64.b64decode(remote.read_until(b'\n').strip())
        return ct
    def _recover_s(d1, d2):
        """Recover 16 S-box input bytes from two round-output differences.

        recover_s passes differences caused by input differences 0x01 (*d1*)
        and 0x02 (*d2*). inv_diff undoes MixColumns and ShiftRows, leaving
        per-byte S-box output differences y1, y2. For each byte, the first c
        with S(c)^S(c^1) == y1 and S(c)^S(c^2) == y2 is taken. If several
        candidates match, the first one is used without warning.

        Raises:
            ValueError: if no candidate matches for some byte.
        """
        d1, d2 = inv_diff(d1), inv_diff(d2)
        res = []
        for y1, y2 in zip(d1, d2):
            for c in range(256):
                if (aes.s_box[c] ^ aes.s_box[c^1]) == y1 and (aes.s_box[c] ^ aes.s_box[c^2]) == y2:
                    res.append(c)
                    break
            else:
                raise ValueError('Not found')
        return bytes(res)
    def recover_s(i, base):
        """Recover 16 state bytes by injecting differences into block *i*.

        Encrypts gen_diff(i, 1) and gen_diff(i, 2), i.e. the zero plaintext
        with block *i* filled with 0x01 or 0x02 bytes. Each ciphertext is
        XORed with *base*, the ciphertext of the all-zero plaintext. The
        solve uses ciphertext block i + 2, presumably the first block where
        the injected difference reaches the keystream after one AES round.
        The caller treats recover_s(0, base) as S0 at time 1.
        """
        ct1 = enc(b'', gen_diff(i, 1))
        ct2 = enc(b'', gen_diff(i, 2))
        d1 = aes.xor_bytes(ct1, base)
        d2 = aes.xor_bytes(ct2, base)
        off = (i + 2) * 16
        s0 = _recover_s(d1[off:off+16], d2[off:off+16])
        # NOTE(review): an S4 recovery from the following block was sketched
        # here but left disabled. Its slice d1[off+16:off+16] was empty anyway.
        return s0
    def _and(a, b):
        """Byte-wise AND of two equal-length byte strings."""
        assert len(a) == len(b)
        return bytes(x & y for x, y in zip(a, b))
    # Ciphertext of six all-zero plaintext blocks: the reference that every
    # difference query is XORed against.
    base = enc(b'', b'\0' * 16*6)
    print('0')
    e0 = recover_s(0, base)
    print(e0)
    print('1')
    e1 = recover_s(1, base)
    # With zero message blocks, S0_2 ^ S0_1 is presumably R(S4_1).
    rs4_1 = aes.xor_bytes(e1, e0)
    print('2')
    e2 = recover_s(2, base)
    # Likewise S0_3 ^ S0_2 is presumably R(S4_2).
    rs4_2 = aes.xor_bytes(e2, e1)
    # The server sends one state word after the 'leak:' marker. This script
    # treats it as S2 at time 1.
    remote.read_until(b'leak:')
    s2_1 = base64.b64decode(remote.read_until(b'\n').strip())
    print(s2_1)
    s4_1 = inv_r(rs4_1)
    s4_2 = inv_r(rs4_2)
    # S4_2 = R(S3_1) ^ S4_1, hence R(S3_1) = S4_1 ^ S4_2.
    rs3_1 = aes.xor_bytes(s4_1, s4_2)
    s3_1 = inv_r(rs3_1)
    # Ciphertext block 1 of the zero plaintext is S1 ^ S4 ^ (S2 & S3) at
    # time 1. Solve it for S1.
    s1_1 = aes.xor_bytes(aes.xor_bytes(base[16:32], s4_1), _and(s2_1, s3_1))
    s0_1 = e0
    S = [s0_1, s1_1, s2_1, s3_1, s4_1]
    # Submit S0..S4 concatenated and base64-encoded on a single line.
    remote.write(base64.b64encode(b''.join(S)).replace(b'\n', b'') + b'\n')
    result = remote.read_until(b'\n').strip()
    print(result)
# Run the attack, then print whatever the server sends until it closes the
# connection. Close our side explicitly, even if the attack raises, instead
# of leaking the socket until interpreter exit.
try:
    solve_phase1()
    print(remote.read_all())
finally:
    remote.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment