
@vient
Last active July 1, 2021 13:19

AVX2 Encoder

Intro

This is the write-up for the task "AVX2 Encoder" from TCTF (0CTF) Finals 2017.

Description

We are given the following files:

avx2_encoder.exe: PE32+ executable (console) x86-64, for MS Windows
flag.jpg.enc: data
key.dat: data

The executable takes an input file and produces two files: the encrypted input and key.dat, which records data used during encryption.

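Pseudocode of the encryption process:

key_f = open('key.dat', 'wb')
SBox = list(range(256))
# block: the current group of 256 32-byte input blocks
for _ in range(32):                      # 32 rounds
    random_shuffle(SBox)
    key_f.write(bytes(SBox))             # one 256-byte SBox is saved per round
    for i in range(256):
        block[i] = encryptors[SBox[i]](block[i])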
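To make the round structure concrete, here is a small pure-Python model of the scheme above. It is a sketch built on assumptions. The 256 AVX2 cases are replaced by eight hypothetical toy encryptors (a 64-bit add followed by a rotate). Each block is a single integer. random_shuffle is modelled with random.Random.shuffle. What the model shows is the decryption order: rounds are undone from last to first, and in each round block i goes through the inverse of encryptor SBox[i] for that round.

```python
import random

MASK64 = (1 << 64) - 1
N_ENC = 8      # toy stand-in for the 256 encryptors
N_ROUNDS = 32  # one SBox per round, as in key.dat

def rotl(x, r):
    return ((x << r) | (x >> (64 - r))) & MASK64

def rotr(x, r):
    return ((x >> r) | (x << (64 - r))) & MASK64

# Hypothetical toy encryptors: encryptor k adds a constant, then rotates by k + 1 bits.
def encrypt_one(k, x):
    return rotl((x + 0x9E3779B97F4A7C15 * (k + 1)) & MASK64, k + 1)

def decrypt_one(k, y):
    return (rotr(y, k + 1) - 0x9E3779B97F4A7C15 * (k + 1)) & MASK64

def encrypt(blocks, seed=0):
    """Apply N_ROUNDS rounds; return the ciphertext and the per-round SBoxes."""
    rng = random.Random(seed)
    sbox = list(range(N_ENC))
    key = []
    blocks = list(blocks)
    for _ in range(N_ROUNDS):
        rng.shuffle(sbox)
        key.append(list(sbox))          # what the real program writes to key.dat
        for i in range(len(blocks)):
            blocks[i] = encrypt_one(sbox[i], blocks[i])
    return blocks, key

def decrypt(blocks, key):
    """Undo the rounds in reverse order using the recorded SBoxes."""
    blocks = list(blocks)
    for sbox in reversed(key):
        for i in range(len(blocks)):
            blocks[i] = decrypt_one(sbox[i], blocks[i])
    return blocks

plain = [i * 0x0101010101010101 for i in range(N_ENC)]
cipher, key = encrypt(plain)
assert decrypt(cipher, key) == plain
```

Each SBox is a permutation, so every encryptor is used exactly once per round. This is what makes it possible to group the pending blocks by encryptor later on.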

In fact, the encryptors are not separate functions but cases of one big switch. They use AVX2 instructions to transform the data, so the decompiler does not produce clean code. Here is an example of decompiled encryption code:

case 0:
  __asm
  {
    vmovdqu ymm0, cs:ymmword_14000C1E0+10E0h; jumptable 00000001400016A9 case 0
    vmovdqu ymm1, ymmword ptr [rbx]
  }
  i = 67i64;
  do
  {
    __asm
    {
      vpshufhw ymm0, ymm0, 2Dh
      vpsrlq  ymm2, ymm1, 1Dh
      vpshuflw ymm5, ymm0, 93h
      vpsllq  ymm0, ymm1, 23h
      vpermq  ymm1, ymm0, 39h
      vpxor   ymm2, ymm1, ymm2
      vpxor   ymm3, ymm2, ymm5
      vpsllq  ymm0, ymm3, 3Bh
      vpsrlq  ymm4, ymm3, 5
      vpermq  ymm1, ymm0, 39h
      vpxor   ymm2, ymm1, ymm4
      vpshufhw ymm3, ymm2, 93h
      vpshuflw ymm0, ymm3, 2Dh
      vpsllq  ymm2, ymm0, 3Eh
      vpsrlq  ymm0, ymm0, 2
      vpermq  ymm1, ymm0, 93h
      vpxor   ymm2, ymm1, ymm2
      vpsubq  ymm3, ymm2, ymm5
      vpsrlq  ymm0, ymm3, 16h
      vpsllq  ymm4, ymm3, 2Ah
      vpermq  ymm1, ymm0, 93h
      vpshufhw ymm0, ymm5, 6Ch
      vpshuflw ymm0, ymm0, 0D2h
      vpxor   ymm1, ymm1, ymm4
    }
    --i;
  }
  while ( i );
  __asm { vmovdqu ymmword ptr [rbx], ymm1 }
  break;

Here rbx points to the 32-byte input block being encrypted.
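A plain-Python emulator of the instructions that appear in the encryptors is handy for checking the decompiled listing, and later the Z3 encodings, against concrete values. This is a sketch. A YMM register is modelled as a list of four 64-bit lanes, with lane 0 as the least significant qword. Only the immediate forms seen in the listing are covered.

```python
MASK64 = (1 << 64) - 1

def _words(q):
    # the four 16-bit words of a qword, least significant first
    return [(q >> (16 * k)) & 0xFFFF for k in range(4)]

def _from_words(ws):
    return sum(w << (16 * k) for k, w in enumerate(ws))

def _shuffle_words(q, imm):
    ws = _words(q)
    return _from_words([ws[(imm >> (2 * k)) & 3] for k in range(4)])

def vpxor(a, b):
    return [x ^ y for x, y in zip(a, b)]

def vpaddq(a, b):
    return [(x + y) & MASK64 for x, y in zip(a, b)]

def vpsubq(a, b):
    return [(x - y) & MASK64 for x, y in zip(a, b)]

def vpsllq(a, imm):
    return [(x << imm) & MASK64 for x in a]   # counts >= 64 give 0, as on hardware

def vpsrlq(a, imm):
    return [x >> imm for x in a]              # logical shift: lanes are non-negative

def vpshuflw(a, imm):
    # shuffle the words of the low qword of each 128-bit half (lanes 0 and 2)
    return [_shuffle_words(q, imm) if i % 2 == 0 else q for i, q in enumerate(a)]

def vpshufhw(a, imm):
    # shuffle the words of the high qword of each 128-bit half (lanes 1 and 3)
    return [_shuffle_words(q, imm) if i % 2 == 1 else q for i, q in enumerate(a)]

def vpermq(a, imm):
    # destination lane i takes the source lane selected by bits 2i+1..2i of imm
    return [a[(imm >> (2 * i)) & 3] for i in range(4)]

print(vpermq([1, 2, 3, 4], 0x39))   # [2, 3, 4, 1]
```

To replay one iteration of the case 0 loop, apply these functions in listing order and keep one variable per register.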

Solution

We are given all the code as well as the "secret key" that was used (the order of encryptors for each block). With this information there are several ways to reverse the encryption. The first one that comes to mind is to literally invert all the AVX2 code somehow. The problem is that the data flow is not linear. In some places the data is split into two registers and combined a bit later (for example, vpsrlq and vpsllq on the same register, with the two halves recombined by vpxor a few instructions later). To overcome this problem easily (i.e., to avoid thinking), we can use a tool that solves everything for us, given the right formulation of the problem and the input data. That tool is the Z3 Theorem Prover.

The solution goes as follows:

  • Load all data (program, SBoxes, encrypted file, magic constants for encryptors)
  • Make solver for given encryptor
  • Feed it with magic and encrypted result
  • Get decrypted result

To implement this flow, I rewrote every AVX2 instruction used by the encryptors as a corresponding Z3 expression. First of all, I created some essentials: register definitions and easy access to their parts (index 0 is the least significant part):

def make_ymm(name):
    return BitVec(name, 32*8)

def vec_slice(vec, from_, to):
    return Extract(vec.size() - from_ - 1, vec.size() - to, vec)

def i64(ymm, index):
    index = ymm.size() // 64 - index - 1
    return vec_slice(ymm, index * 64, (index + 1) * 64)

Next come the AVX2 operations. Take vpshuflw as an example (_mm256_shufflelo_epi16: it shuffles the 16-bit words in the low 64 bits of each 128-bit half of a YMM register). Given the destination register, the source register and the imm value, it returns the corresponding condition in the form of one big And:

def vpshuflw(reg1, reg2, imm):
    cond = True
    cond = And(cond, i64(reg1, 1) == i64(reg2, 1))
    cond = And(cond, i64(reg1, 3) == i64(reg2, 3))
    for i in range(4):
        idx, imm = imm & 3, imm >> 2
        cond = And(cond, i16(reg1, i + 0) == i16(reg2, idx + 0))
        cond = And(cond, i16(reg1, i + 8) == i16(reg2, idx + 8))
    return cond

Now we can parse the encryptor ASM code and chain the operations one by one using these AVX2 conditions. We use SSA form to follow the data flow correctly, i.e., ymm0 = 1; ymm0 ^= 2 becomes YMM0_0 = 1; YMM0_1 = YMM0_0 ^ 2.
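The SSA renaming can be sketched in plain Python, independently of Z3. Every write to a register creates a new version, and every read refers to the latest one. The instruction tuples follow the (mnemonic, dst, src1, src2_or_imm) shape used in the text. The function name to_ssa is ours.

```python
def to_ssa(prog, regs=('ymm0', 'ymm1', 'ymm2', 'ymm3', 'ymm4', 'ymm5')):
    version = {r: 0 for r in regs}          # current version of each register

    def read(operand):
        # registers map to their latest version; immediates pass through
        return '{}_{}'.format(operand, version[operand]) if operand in version else operand

    out = []
    for op, dst, *srcs in prog:
        srcs = [read(x) for x in srcs]      # read sources before bumping dst
        version[dst] += 1
        out.append((op, '{}_{}'.format(dst, version[dst]), *srcs))
    return out

prog = [('vpsrlq', 'ymm2', 'ymm1', 0x1D),
        ('vpsllq', 'ymm0', 'ymm1', 0x23),
        ('vpxor', 'ymm2', 'ymm0', 'ymm2')]
for line in to_ssa(prog):
    print(line)
# ('vpsrlq', 'ymm2_1', 'ymm1_0', 29)
# ('vpsllq', 'ymm0_1', 'ymm1_0', 35)
# ('vpxor', 'ymm2_2', 'ymm0_1', 'ymm2_1')
```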

Example

vpshuflw ymm0, ymm0, 0D2h

First, create a solver:

s = Solver()

Next, take the first instruction, vpshuflw ymm0, ymm0, 0D2h, and parse it into the tuple op = ("vpshuflw", "ymm0", "ymm0", 0xD2). Look up the latest versions of all source registers, then append a new version of the destination register and take it as well:

reg2 = regs[op[2]][-1]
regs[op[1]].append(make_ymm('{}_{}'.format(op[1], len(regs[op[1]]))))
reg1 = regs[op[1]][-1]

Finally, add the constraint to the solver:

s.add(AVX2[op[0]](reg1, reg2, op[3]))

Here AVX2 is a dict mapping each mnemonic to its function, so AVX2[op[0]] is vpshuflw.
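Turning a listing line from the __asm block into such a tuple mostly means splitting on whitespace and commas, then converting IDA-style hex immediates such as 0D2h. The full script below does the same thing in parse_program and make_solver. Here is a minimal standalone sketch. The names parse_operand and parse_line are ours, and the function is meant only for instruction lines from inside the __asm block.

```python
def parse_operand(tok):
    """Registers stay strings; IDA immediates ('0D2h', '5') become ints (hex)."""
    if tok.startswith('ymm'):
        return tok
    return int(tok[:-1] if tok.endswith('h') else tok, 16)

def parse_line(line):
    tokens = line.replace(',', ' ').split()
    if len(tokens) < 3:
        return None                     # not an instruction with operands
    op, *operands = tokens
    return (op, *(parse_operand(t) for t in operands))

print(parse_line('vpshuflw ymm0, ymm0, 0D2h'))   # ('vpshuflw', 'ymm0', 'ymm0', 210)
print(parse_line('vpxor   ymm2, ymm1, ymm2'))    # ('vpxor', 'ymm2', 'ymm1', 'ymm2')
```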

Thoughts about optimization

The solution can be optimized in several ways. First, it parallelizes trivially. In each round every block is encrypted with one of 256 encryptors, so the blocks can be decrypted in up to 256 threads, one per encryptor. Second, building a solver takes time, so it is faster to build it once per encryptor and reuse the cached version (with push/pop contexts on the cached solver).

Now comes the sad part. Each encryptor is a block of instructions repeated in a loop up to ~250 times. If the loop body is typically about 25 instructions long, the unrolled loop has roughly 6000 instructions. The solver built for this code takes up to 3 GB of RAM, so keeping 16 threads fully busy would need 16 * 3 = 48 GB of RAM. In other words, our parallelization efforts are severely limited by the installed hardware.

Execution speed combined with these limitations gives terrible performance: in single-thread mode the solution takes about 130 hours to finish. I used 8 threads and computed the answer in ~16 hours. Needless to say, that was too slow, and we did not get the flag in time. On the other hand, it was a nice example of using Z3 in a real situation, even if it was not fast.

Edit: As you can see, when creating a solver I do not add constraints one by one but combine them into one big And. It turns out that inserting the line cond = simplify(cond) before line 175 of the original script reduces execution time by 40% (~80 hours on one core instead of 130). That line is right before s = Solver() in make_solver, where the accumulated cond is handed to the solver.

import os
import struct
from multiprocessing import Pool
from functools import lru_cache
import sys
# location of the Z3 Python bindings on the author's machine
sys.path.append(r"C:\tools\z3-4.4.1-x64-win\bin")
from z3 import *

# directory with prog.txt, key.dat, init.bin and flag.jpg.enc
PATH = r"C:\Users\vient\Desktop\ctf\0ctf\FINALS\02-seam-carve"
POOL_SIZE = 8

def make_ymm(name):
    return BitVec(name, 32*8)

def slice(vec, from_, to):
    # positions are counted from the most significant end of the vector
    return Extract(vec.size() - from_ - 1, vec.size() - to, vec)

# iN(ymm, index): N-bit part number `index`, counted from the least significant end
def i8(ymm, index):
    index = ymm.size() // 8 - index - 1
    return slice(ymm, index * 8, (index + 1) * 8)

def i16(ymm, index):
    index = ymm.size() // 16 - index - 1
    return slice(ymm, index * 16, (index + 1) * 16)

def i32(ymm, index):
    index = ymm.size() // 32 - index - 1
    return slice(ymm, index * 32, (index + 1) * 32)

def i64(ymm, index):
    index = ymm.size() // 64 - index - 1
    return slice(ymm, index * 64, (index + 1) * 64)
# ============================== AVX2 ========================================
# Each function returns a Z3 condition stating that reg1 holds the result of
# the instruction applied to reg2 and reg3 (or an immediate).
# Arithmetic on 64-bit bit-vectors already wraps modulo 2**64.
def vpxor(reg1, reg2, reg3):
    return reg1 == (reg2 ^ reg3)

def vpsubq(reg1, reg2, reg3):
    cond = True
    for i in range(4):
        cond = And(cond, i64(reg1, i) == (i64(reg2, i) - i64(reg3, i)))
    return cond

def vpaddq(reg1, reg2, reg3):
    cond = True
    for i in range(4):
        cond = And(cond, i64(reg1, i) == (i64(reg2, i) + i64(reg3, i)))
    return cond

def vpsllq(reg1, reg2, imm):
    cond = True
    for i in range(4):
        cond = And(cond, i64(reg1, i) == (i64(reg2, i) << imm))
    return cond

def vpsrlq(reg1, reg2, imm):
    cond = True
    for i in range(4):
        # !!! LShR instead of >>: on Z3 bit-vectors >> is a signed (arithmetic) shift
        cond = And(cond, i64(reg1, i) == LShR(i64(reg2, i), imm))
    return cond

def vpshufhw(reg1, reg2, imm):
    # low qword of each 128-bit half (qwords 0 and 2) is copied unchanged
    cond = True
    cond = And(cond, i64(reg1, 0) == i64(reg2, 0))
    cond = And(cond, i64(reg1, 2) == i64(reg2, 2))
    for i in range(4):
        idx, imm = imm & 3, imm >> 2
        cond = And(cond, i16(reg1, i + 4) == i16(reg2, idx + 4))
        cond = And(cond, i16(reg1, i + 12) == i16(reg2, idx + 12))
    return cond

def vpshuflw(reg1, reg2, imm):
    # high qword of each 128-bit half (qwords 1 and 3) is copied unchanged
    cond = True
    cond = And(cond, i64(reg1, 1) == i64(reg2, 1))
    cond = And(cond, i64(reg1, 3) == i64(reg2, 3))
    for i in range(4):
        idx, imm = imm & 3, imm >> 2
        cond = And(cond, i16(reg1, i + 0) == i16(reg2, idx + 0))
        cond = And(cond, i16(reg1, i + 8) == i16(reg2, idx + 8))
    return cond

def vpermq(reg1, reg2, imm):
    cond = True
    for i in range(4):
        idx, imm = imm & 3, imm >> 2
        cond = And(cond, i64(reg1, i) == i64(reg2, idx))
    return cond

AVX2 = {
    'vpxor': vpxor,
    'vpsubq': vpsubq,
    'vpaddq': vpaddq,
    'vpsllq': vpsllq,
    'vpsrlq': vpsrlq,
    'vpshufhw': vpshufhw,
    'vpshuflw': vpshuflw,
    'vpermq': vpermq,
}

# matching C intrinsics, for reference
AVX2_intr = {
    'vpxor': '_mm256_xor_si256',
    'vpsubq': '_mm256_sub_epi64',
    'vpaddq': '_mm256_add_epi64',
    'vpsllq': '_mm256_slli_epi64',
    'vpsrlq': '_mm256_srli_epi64',
    'vpshufhw': '_mm256_shufflehi_epi16',
    'vpshuflw': '_mm256_shufflelo_epi16',
    'vpermq': '_mm256_permute4x64_epi64',
}
# ============================================================================
@lru_cache(maxsize=1024)
def parse_program(text):
    lines = text.split('\n')
    prog = []
    for line in lines:
        q = line.split()
        q = [x.split(',')[0] for x in q]
        if len(q) < 3:
            continue
        prog.append(tuple(q))
    return prog

@lru_cache(maxsize=1)
def make_solver(prog):
    # SSA: every register keeps a list of its versions
    regs = {
        'ymm0': [make_ymm('ymm0_0')],
        'ymm1': [make_ymm('ymm1_0')],
        'ymm2': [make_ymm('ymm2_0')],
        'ymm3': [make_ymm('ymm3_0')],
        'ymm4': [make_ymm('ymm4_0')],
        'ymm5': [make_ymm('ymm5_0')],
        'ymm6': [make_ymm('ymm6_0')],
    }
    cond = True
    for line in prog:
        op, reg1, reg2 = line[:3]
        reg2 = regs[reg2][-1]  # take last version of register
        try:
            t = line[3]
            if t[-1] == 'h':
                t = t[:-1]
            third = int('0x' + t, 16)  # imm
        except ValueError:
            third = regs[line[3]][-1]  # not an imm: take last version of register
        regs[reg1].append(make_ymm('{}_{}'.format(reg1, len(regs[reg1]))))
        reg1 = regs[reg1][-1]  # new version of the destination register
        cond = And(cond, AVX2[op](reg1, reg2, third))
    # print(cond)
    # print(simplify(cond))
    # exit(0)
    s = Solver()
    s.add(cond)
    return (s, regs)

def reverse(text, n, init, end_):
    text = text + '\n'
    text = text * n  # unroll the loop body n times
    s, regs = make_solver(tuple(parse_program(text)))
    s.push()
    for reg_name in init:
        reg = regs[reg_name][0]  # magic constant loaded before the loop
        for i in range(4):
            s.add(i64(reg, i) == init[reg_name][i])
    for reg_name in end_:
        reg = regs[reg_name][-1]  # encrypted block after the loop
        for i in range(4):
            s.add(i64(reg, i) == end_[reg_name][i])
    # print('[*] System created')
    assert s.check() == sat
    model = s.model()
    # for reg_name in regs:
    #     reg = regs[reg_name][-1]
    #     if not model[reg]:
    #         continue
    #     print(reg, ' \t=', hex(model[reg].as_long())[2:].zfill(256 // 4))
    # print('=' * 82)
    # for reg_name in end_:
    #     reg = regs[reg_name][0]
    #     print(reg, ' \t=', hex(model[reg].as_long())[2:].zfill(256 // 4))
    reg = regs[list(end_.keys())[0]][0]  # first version of the data register = input block
    val = model.eval(reg, model_completion=True).as_long()
    res = []
    for _ in range(4):
        res.append(val % 2**64)
        val //= 2**64
    s.pop()
    return res
def load_program(path):
    with open(path, 'r') as f:
        a = f.readlines()
    res = []
    for i in range(256):
        it = [j for j in range(len(a)) if 'case {}:'.format('0x' + hex(i)[2:].upper() if i >= 10 else i) in a[j]][0]
        while '__asm' not in a[it]:
            it += 1
        it += 2
        init_reg = a[it][a[it].index('ymm'):][:4]
        init_off = int(a[it][a[it].index('ymmword_') + 13:][:4], 16) - 0xC1E0
        end_reg = a[it + 1][a[it + 1].index('ymm'):][:4]
        # print('Case', i, 'init in', init_reg, 'from', hex(init_off))
        # loop counter, e.g. "i = 67i64;"
        while '=' not in a[it]:
            it += 1
        itt = a[it].index('=')
        while a[it][itt] not in '0123456789':
            itt += 1
        nitt = itt
        while a[it][nitt] in '0123456789':
            nitt += 1
        repeats = int(a[it][itt:nitt])
        # loop body: lines of the second __asm block
        while '__asm' not in a[it]:
            it += 1
        it += 2
        nit = it
        while '}' not in a[nit]:
            nit += 1
        text = a[it:nit]
        res.append({
            'init_reg': init_reg,
            'init_idx': init_off // 32,
            'end_reg': end_reg,
            'repeats': repeats,
            'text': ''.join(text)
        })
    return res

def load_sboxes(path):
    with open(path, 'rb') as f:
        a = f.read()
    res = []
    for i in range(0, len(a), 256):
        res.append([a[i:i + 256][j] for j in range(256)])
    return res

def load_init(path):
    with open(path, 'rb') as f:
        a = f.read()
    res = []
    for i in range(0, len(a), 32):
        val = int.from_bytes(a[i:i + 32], byteorder='little')
        t = []
        for _ in range(4):
            t.append(val % 2**64)
            val //= 2**64
        res.append(t)
    return res

def load_end(path):
    with open(path, 'rb') as f:
        a = f.read()
    a = a[8:]  # !!! skip size and CRC
    res = []
    for i in range(0, len(a), 32):
        val = int.from_bytes(a[i:i + 32], byteorder='little')
        t = []
        for _ in range(4):
            t.append(val % 2**64)
            val //= 2**64
        res.append(t)
    return res
def worker(jobs):
    # it: index of the encryptor to undo; jobs: blocks currently waiting for it
    it, jobs, PROG, SBOXES, INIT_VALUES = jobs
    res = {i: [] for i in range(256)}
    res['solved'] = {}
    for job in jobs:
        enc = job['enc']
        sbox_idx = job['sbox_idx']
        enc_idx = job['enc_idx']
        steps = job['steps']
        cur = PROG[it]
        enc = reverse(
            cur['text'],
            cur['repeats'],
            init={cur['init_reg']: INIT_VALUES[cur['init_idx']]},
            end_={cur['end_reg']: enc}
        )
        if steps == 1:
            # all 32 rounds undone: store the plaintext block
            res['solved'][sbox_idx * 256 + enc_idx] = struct.pack('<QQQQ', *enc)
        else:
            # the SBox of the previous round selects the next encryptor to undo
            sbox = SBOXES[sbox_idx * 32 + steps - 2]
            job['steps'] -= 1
            job['enc'] = enc
            res[sbox[enc_idx]].append(job)
    return res
def main():
    path = PATH
    PROG = load_program(os.path.join(path, 'prog.txt'))
    print('[*] Program loaded')
    SBOXES = load_sboxes(os.path.join(path, 'key.dat'))
    print('[*] SBoxes loaded')
    INIT_VALUES = load_init(os.path.join(path, 'init.bin'))
    print('[*] Init values loaded')
    end_values = load_end(os.path.join(path, 'flag.jpg.enc'))
    print('[*] End values loaded')
    # print(list(map(hex, end_values[0])))
    total = 0
    queues = [[] for _ in range(256)]
    for i, enc in enumerate(end_values):
        total += 32
        sbox_idx, enc_idx = i // 256, i % 256
        sbox = SBOXES[sbox_idx * 32 + 31]  # SBox of the last round
        first_worker = sbox[enc_idx]
        queues[first_worker].append({
            'enc': enc,
            'sbox_idx': sbox_idx,
            'enc_idx': enc_idx,
            'steps': 32})
    it, empty = 0, 0
    res = {}
    while any(len(x) > 0 for x in queues):
        print('Total', total, 'left')
        total -= sum(len(x) for x in queues)
        with Pool(POOL_SIZE) as p:
            q = p.map(worker, [(i, x, PROG, SBOXES, INIT_VALUES) for i, x in enumerate(queues)])
        queues = [[] for _ in range(256)]
        for worker_res in q:
            res.update(worker_res['solved'])
            for i in range(256):
                queues[i] += worker_res[i]
    with open(os.path.join(path, 'flag.jpg'), 'wb') as f:
        for i in range(100000):
            if i not in res:
                break
            _ = f.write(res[i])

if __name__ == '__main__':
    main()