Skip to content

Instantly share code, notes, and snippets.

@matthw
Created May 1, 2023 17:19
Show Gist options
  • Save matthw/58b5f3a52737d4bcd8ffe2661a8fc3a2 to your computer and use it in GitHub Desktop.
Save matthw/58b5f3a52737d4bcd8ffe2661a8fc3a2 to your computer and use it in GitHub Desktop.
FCSC2023 - Chaussette
from unicorn import *
from unicorn.x86_const import *
from capstone import *
from pwn import *
import copy
from z3 import *
import sys
import time
# Largest unsigned 64-bit value (2**64 - 1, all 64 bits set).
UINT_MAX = 0xffffffffffffffff
def rol(val, r_bits, max_bits):
    """Rotate ``val`` left by ``r_bits`` inside a ``max_bits``-wide word.

    The rotation count is reduced modulo ``max_bits`` and the result is
    truncated to ``max_bits`` bits.
    """
    word_mask = 2 ** max_bits - 1
    count = r_bits % max_bits
    # Bits shifted out on the left are dropped by the mask ...
    shifted_up = (val << count) & word_mask
    # ... and re-enter on the right from the top of the (masked) word.
    wrapped_around = (val & word_mask) >> (max_bits - count)
    return shifted_up | wrapped_around
def ror(val, r_bits, max_bits):
    """Rotate ``val`` right by ``r_bits`` inside a ``max_bits``-wide word.

    Example with a 4-bit word and a rotation of 1: 0b1001 --> 0b1100.
    The rotation count is reduced modulo ``max_bits`` and the result is
    truncated to ``max_bits`` bits.
    """
    word_mask = 2 ** max_bits - 1
    count = r_bits % max_bits
    # Low bits fall off the right end of the (masked) word ...
    shifted_down = (val & word_mask) >> count
    # ... and come back in at the top; the mask drops anything above the word.
    wrapped_around = (val << (max_bits - count)) & word_mask
    return shifted_down | wrapped_around
def dump(blk):
    """Debug helper: print a separator line, then every entry of ``blk``.

    Each entry is printed on its own line; nothing is returned.
    """
    separator = "-------------------------------------"
    print(separator)
    for entry in blk:
        print(entry)
class Lolz:
def __init__(self, data):
self.blocks = []
self.state = {
"rax": BitVec('rax', 64),
"rdi": BitVec('rdi', 64),
"rsi": BitVec('rsi', 64)
}
self.data = data
def meep(self):
insns = self.disasm(self.data)
self.parse(insns)
i = 0
for b in self.blocks:
#print('==============================')
i += 1
print("\r%s -> %d/%d"%(time.asctime(), i, len(self.blocks)), end='')
self.eval_block(b, self.state)
#print(self.state)
print("")
#print(self.state)
out = {}
for x in ('rsi', 'rdi'):
if type(self.state[x]) is int:
out[x] = self.state[x]
return out
def disasm(self, code):
insns = []
md = Cs(CS_ARCH_X86, CS_MODE_64)
for i in md.disasm(code, 0x0):
insns.append([i.mnemonic, i.op_str.replace(",", "").split()])
#print("%s\t%s" %(i.mnemonic, i.op_str))
return insns
def parse(self, insns):
# walk backward up to movabs
block = []
for (i, ops) in insns[::-1]:
block.insert(0, (i, ops))
# movabs is the start of a block
if i == 'movabs' and ops[0] == 'rax':
self.blocks.append(block)
block = []
# append first block
self.blocks.append(block)
def eval_block(self, block, st):
#dump(block)
# initialize new state
old_state = copy.copy(st)
# symbolyze inputs
st['rdi'] = BitVec('rdi', 64)
st['rsi'] = BitVec('rsi', 64)
for (i, dat) in block:
match i:
case 'movabs':
reg = dat[0]
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
st[reg] = int(imm, 16)
case 'xor':
st[dat[0]] ^= st[dat[1]]
case 'add':
st[dat[0]] += st[dat[1]]
#st[dat[0]] &= UINT_MAX
case 'sub':
st[dat[0]] -= st[dat[1]]
#st[dat[0]] &= UINT_MAX
case 'ror':
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
if type(st[dat[0]]) is int:
st[dat[0]] = ror(st[dat[0]], int(imm, 16), 64)
else:
st[dat[0]] = RotateRight(st[dat[0]], int(imm, 16))
case 'rol':
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
if type(st[dat[0]]) is int:
st[dat[0]] = rol(st[dat[0]], int(imm, 16), 64)
else:
st[dat[0]] = RotateLeft(st[dat[0]], int(imm, 16))
case 'dec':
st[dat[0]] -= 1
#st[dat[0]] &= UINT_MAX
case 'inc':
st[dat[0]] += 1
#st[dat[0]] &= UINT_MAX
case 'neg':
st[dat[0]] *= -1
#st[dat[0]] &= UINT_MAX
case 'or':
try:
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
st[dat[0]] |= int(imm, 16)
except:
st[dat[0]] |= st[dat[1]]
case 'and':
try:
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
st[dat[0]] &= int(imm, 16)
except:
st[dat[0]] &= st[dat[1]]
case 'shr':
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
if type(st[dat[0]]) is int:
#st[dat[0]] = LShR(st[dat[0]], int(imm, 16))
st[dat[0]] >>= int(imm, 16)
else:
st[dat[0]] = LShR(st[dat[0]], int(imm, 16))
case 'shl':
imm = dat[1]
if imm.startswith('0x'):
imm = imm[2:]
st[dat[0]] <<= int(imm, 16)
st[dat[0]] &= UINT_MAX
case 'mul':
st['rax'] *= st[dat[0]]
#st['rax'] &= UINT_MAX
case 'mov':
st[dat[0]] = st[dat[1]]
case 'xchg':
(st[dat[0]], st[dat[1]]) = (st[dat[1]], st[dat[0]])
case 'ret':
self.solve_block(st, None, end=True)
return
case _:
print(i, dat)
raise
self.solve_block(st, old_state)
def solve_block(self, state, prev_state, end=False):
#print("state: %r"%state)
#print("prev_state: %r"%prev_state)
s = Solver()
if end == True:
s.add(state['rax'] == 0)
else:
for r in ('rdi', 'rsi'):
s.add(state[r] == prev_state[r])
#print(s.sexpr())
assert s.check() == sat
m = s.model()
#print(m)
for v in m:
state[v.name()] = m[v].as_long()
def main():
    """Play the remote challenge until the server sends its end marker.

    Each round sends the current ``n1``/``n2`` as two packed 64-bit values
    and reads back a packed 64-bit size.  A size of 0xffffffffffffffff ends
    the session: up to 4096 further bytes are printed and the loop stops.
    Otherwise ``size`` bytes of machine code are read, saved to
    ``<round>.bin``, and solved with ``Lolz``; the solved ``rdi``/``rsi``
    (when available) become the next ``n1``/``n2``.
    """
    n1 = 9243637858070793867
    n2 = 12
    io = remote("challenges.france-cybersecurity-challenge.fr", 2250)
    try:
        r = 0
        while True:
            print("sending: %d %d" % (n1, n2))
            io.send(p64(n1))
            io.send(p64(n2))
            # recvn waits for exactly 8 bytes; a plain recv(8) may return a
            # short read, which u64 cannot unpack.
            size = u64(io.recvn(8))
            print("receving: %x" % size)
            if size == 0xffffffffffffffff:
                print(io.recv(4096))
                break
            code = b''
            while len(code) < size:
                code += io.recv(size - len(code))
            print(hex(len(code)))
            # Keep a copy of every received blob for offline analysis.
            with open('%d.bin' % r, 'wb') as fh:
                fh.write(code)
            r += 1
            print("try: %d" % r)
            out = Lolz(code).meep()
            if 'rdi' in out:
                n1 = out['rdi']
            if 'rsi' in out:
                n2 = out['rsi']
    finally:
        # Release the connection even when solving or receiving fails.
        io.close()
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # Offline mode: solve a previously saved blob given on the command
        # line (this code used to be unreachable behind sys.exit).
        with open(sys.argv[1], 'rb') as fh:
            lolz = Lolz(fh.read())
        print(lolz.meep())
    else:
        main()
        # NOTE(review): exit status 1 is kept from the original script even
        # though main() returned normally - confirm whether 0 was intended.
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment