Skip to content

Instantly share code, notes, and snippets.

@mak

mak/arch.py Secret

Last active November 11, 2024 04:33
Flare-On 2024 – challenge 10 (c10)
import enum
from binaryninja import LowLevelILLabel
from binaryninja.architecture import Architecture
from binaryninja.function import RegisterInfo, InstructionInfo, InstructionTextToken
from binaryninja.enums import InstructionTextTokenType, FlagRole, LowLevelILFlagCondition, BranchType
from binaryninja.log import log_info, log_warn
from dataclasses import dataclass
from typing import Optional, Iterator
# VM opcode numbering. The functional IntEnum API numbers members from 1 in
# declaration order, so PUSH == 0x01, ADD_IMM == 0x0a, JMP/JNZ/JZ == 0x0e-0x10
# and LTc == 0x16 -- the raw byte values disasm() compares against.
# NOTE(review): the standalone emulator script orders the comparison opcodes
# EQ, LT, LE, GT, GE while this list uses EQ, GT, GE, LT, LE; only one of the
# two can match the real bytecode -- confirm which.
Opcode = enum.IntEnum(
    "Opcode",
    "PUSH PUSH_DATA ADD_DATA STORE_AT LOAD STORE DUP DROP_STACK ADD "
    "ADD_IMM SUB DIV MULT JMP JNZ JZ EQ GT GE LT LE LTc "
    "RET HLT NOP XOR OR AND MOD SHL SHR ROL4 ROR4 ROL2 "
    "ROR2 ROL1 ROR1 PUTCHAR",
)
@dataclass
class Instr:
mnem : Opcode
arg: Optional[int] = None
is_addr: Optional[bool] = False
size : int = 1
def make_binop(self, il):
a = il.pop(8)
b = il.pop(8)
op = self.get_op(il)
return il.push(8, op(8, a, b))
def get_op(self, il):
match self.mnem:
case Opcode.SHL:
return il.shift_left
case Opcode.SHR:
return il.logical_shift_right
case Opcode.ADD | Opcode.SUB | Opcode.MULT:
return getattr(il, self.mnem.name.lower())
case Opcode.XOR | Opcode.OR | Opcode.AND:
return getattr(il, f'{self.mnem.name.lower()}_expr')
case Opcode.DIV:
return getattr(il, f'{self.mnem.name.lower()}_unsigned')
case Opcode.EQ:
return il.compare_equal
case Opcode.GT:
return il.compare_unsigned_greater_than
case Opcode.GE:
return il.compare_unsigned_greater_than_equal
case Opcode.LT:
return il.compare_unsigned_less_than
case Opcode.GE:
return il.compare_unsigned_less_than_equal
def disasm(data: bytes, addr) -> Optional[Instr]:
op = data[0]
try:
i = Instr(mnem=Opcode(op))
except ValueError:
log_warn(f"Wrong opcode {op:x} @ {addr:x}")
return None
if op == 0x16 or 1 <= op <= 4 or op == 0xa or 0xe <= op <= 0x10:
i.arg = data[2] + (data[1] << 8)
i.is_addr = 0xe <= op <= 0x10 or op == 2 or op == 3
i.size = 3
# log_info(f'{i} @ {addr:x}')
return i
class F2024_C10(Architecture):
name = 'C4T'
address_size = 2 # 16-bit addresses
default_int_size = 2 # 1-byte integers
instr_alignment = 1 # no instruction alignment
max_instr_length = 3
regs = {
'SP': RegisterInfo('SP', 8),
'TMP': RegisterInfo('TMP', 8)
}
stack_pointer = 'SP'
def get_instruction_info(self, data, addr):
i = disasm(data, addr)
if not i:
return None
result = InstructionInfo()
result.length = i.size
if i.mnem == Opcode.JMP:
result.add_branch(BranchType.UnconditionalBranch,i.arg)
elif i.mnem == Opcode.JNZ or i.mnem == Opcode.JZ:
result.add_branch(BranchType.TrueBranch, i.arg)
result.add_branch(BranchType.FalseBranch, addr + result.length)
return result
def get_instruction_text(self, data, addr):
i = disasm(data, addr)
if not i:
return None
tokens = [InstructionTextToken(InstructionTextTokenType.InstructionToken, i.mnem.name)]
if i.arg is not None:
tokens.append(InstructionTextToken(InstructionTextTokenType.TextToken, ' '))
if i.is_addr:
tokens.append(InstructionTextToken(InstructionTextTokenType.PossibleAddressToken, hex(i.arg)))
else:
tokens.append(InstructionTextToken(InstructionTextTokenType.IntegerToken, str(i.arg)))
return tokens, i.size
#
def get_instruction_low_level_il(self, data: bytes, addr: int, il: 'lowlevelil.LowLevelILFunction') -> Optional[
int]:
i = disasm(data, addr)
if not i:
return None
new_il = None
match i.mnem:
case Opcode.PUSH:
new_il = il.push(8, il.const(2, i.arg))
case Opcode.NOP:
new_il = il.nop()
case Opcode.SHL | Opcode.SHR | Opcode.ADD | Opcode.SUB | Opcode.XOR | Opcode.OR | Opcode.AND | Opcode.MULT | Opcode.EQ | Opcode.GT | Opcode.LT | Opcode.LE |Opcode.LE :
new_il = i.make_binop(il)
case Opcode.ADD_IMM:
new_il = il.push(il.add(il.pop()), i.arg)
case Opcode.JMP:
new_il = il.jump(il.const_pointer(2, i.arg))
case Opcode.STORE:
il.append(il.set_reg(9, "TMP", il.pop(8)))
new_il = il.store(8, il.add(8, il.const(8,0x8000),il.mult(8,il.pop(8), il.const(8,8))), il.reg(8, "TMP"))
case Opcode.LOAD:
new_il = il.push(8, il.load(8, il.add(8,il.const(8, 0x8000),il.mult(8,il.pop(8), il.const(8, 8)))))
case Opcode.JZ | Opcode.JNZ:
t = LowLevelILLabel()
f = LowLevelILLabel()
v = il.pop(8)
v = il.compare_equal(8, v, 0 ) if i.mnem == Opcode.JZ else il.compare_not_equal(8, v, 0)
il.append(il.if_expr(v, t, f))
il.mark_label(t)
il.append(il.jump(il.const_pointer(2, i.arg)))
il.mark_label(f)
case _:
new_il = il.unimplemented()
if new_il:
il.append(new_il)
return i.size
F2024_C10.register()
import sys
import enum
import operator
from inspect import stack
from typing import Optional
from dataclasses import dataclass
# Opcode numbering for the standalone emulator. The functional IntEnum API
# numbers members from 1 in declaration order, so PUSH == 0x01,
# ADD_IMM == 0x0a, JMP/JNZ/JZ == 0x0e-0x10 and LTc == 0x16 -- the raw byte
# values disasm_one() compares against.
# NOTE(review): comparisons are ordered EQ, LT, LE, GT, GE here, whereas the
# Binary Ninja plugin orders them EQ, GT, GE, LT, LE -- confirm which one
# matches the real bytecode.
Opcode = enum.IntEnum(
    "Opcode",
    "PUSH PUSH_DATA ADD_DATA STORE_AT LOAD STORE DUP DROP_STACK ADD "
    "ADD_IMM SUB DIV MUL JMP JNZ JZ EQ LT LE GT GE LTc "
    "RET HLT NOP XOR OR AND MOD SHL SHR ROL4 ROR4 ROL2 "
    "ROR2 ROL1 ROR1 PUTCHAR",
)
def rol(val, r_bits, max_bits):
    """Rotate ``val`` left by ``r_bits`` inside a ``max_bits``-wide word."""
    mask = (1 << max_bits) - 1
    shift = r_bits % max_bits
    return ((val << shift) & mask) | ((val & mask) >> (max_bits - shift))


def ror(val, r_bits, max_bits):
    """Rotate ``val`` right by ``r_bits`` inside a ``max_bits``-wide word.

    Example: ror(0b1001, 1, 4) == 0b1100.
    """
    mask = (1 << max_bits) - 1
    shift = r_bits % max_bits
    return ((val & mask) >> shift) | ((val << (max_bits - shift)) & mask)
@dataclass
class Instr:
    """One decoded VM instruction."""

    mnem: Opcode             # decoded opcode
    arg: int | None = None   # 16-bit operand of 3-byte instructions, else None
    is_addr: bool = False    # operand should be treated as an address
    size: int = 1            # encoded length in bytes (1 or 3)
def disasm_one(data: bytes, addr) -> Optional[Instr]:
    """Decode the single instruction starting at offset ``addr`` of ``data``.

    Returns None, after printing a warning to stderr, when ``addr`` is outside
    ``data``, the opcode byte is unknown, or the operand runs past the end.
    """
    if not 0 <= addr < len(data):
        print(f"Address {addr:x} is outside the code ({len(data):x} bytes)", file=sys.stderr)
        return None
    op = data[addr]
    try:
        i = Instr(mnem=Opcode(op))
    except ValueError:
        # log_warn is Binary Ninja's logger and is not imported in this
        # standalone script, so warn on stderr instead.
        print(f"Wrong opcode {op:x} @ {addr:x}", file=sys.stderr)
        return None
    # LTc, PUSH..STORE_AT, ADD_IMM and JMP/JNZ/JZ carry a 2-byte operand.
    if op == 0x16 or 1 <= op <= 4 or op == 0xa or 0xe <= op <= 0x10:
        if addr + 2 >= len(data):
            print(f"Truncated instruction {op:x} @ {addr:x}", file=sys.stderr)
            return None
        i.arg = (data[addr + 1] << 8) | data[addr + 2]  # big-endian operand
        i.is_addr = 0xe <= op <= 0x10 or op in (2, 3)
        i.size = 3
    return i
def disasm(data, addr):
    """Yield decoded instructions from offset ``addr`` to the end of ``data``.

    Stops early at the first instruction disasm_one() cannot decode instead
    of crashing on its None result.
    """
    pc = addr
    while pc < len(data):
        instr = disasm_one(data, pc)
        if instr is None:
            return
        yield instr
        pc += instr.size
# Load the VM bytecode image named by the first command-line argument. It is
# kept as a bytearray because the key bytes are patched into it in place.
with open(sys.argv[1], 'rb') as f:
    code = bytearray(f.read())
def get_op(i):
match i.mnem:
case Opcode.SHL:
return lambda a, b: a << b
case Opcode.SHR:
return lambda a, b : a >> b
case Opcode.ADD | Opcode.SUB | Opcode.MUL | Opcode.XOR:
return getattr(operator, i.mnem.name.lower())
case Opcode.OR | Opcode.AND:
return getattr(operator, f'{i.mnem.name.lower()}_')
case Opcode.DIV:
return lambda a, b: a // b
case Opcode.MOD:
return lambda a, b: a % b
case Opcode.EQ | Opcode.GT |Opcode.LE|Opcode.LT|Opcode.GE:
return lambda a, b: int(getattr(operator, i.mnem.name.lower())(a, b))
case Opcode.ROL4 | Opcode.ROL2 | Opcode.ROL1:
return lambda a, b: rol(a, b, int(i.mnem.name[-1])*8)
case Opcode.ROR4 | Opcode.ROR2 | Opcode.ROR1:
return lambda a, b: ror(a, b, int(i.mnem.name[-1])*8)
# Patch the command-line key into the bytecode: for every (offset, index)
# pair, key byte ``index`` is written to ``code[offset]``.
key = sys.argv[2].encode()
offsets = ((5, 0), (4, 1), (12, 2), (11, 3), (19, 4), (18, 5), (26, 6), (25, 7), (33, 8), (32, 9), (40, 10), (39, 11),
           (47, 12), (46, 13), (54, 14), (53, 15))
_KEY_LEN = 1 + max(j for _, j in offsets)
if len(key) < _KEY_LEN:
    # Fail with a clear message instead of an IndexError mid-patch.
    sys.exit(f"key must be at least {_KEY_LEN} bytes long, got {len(key)}")
for i, j in offsets:
    code[i] = key[j]
def run(code, break_addr):
STACK = []
MEMORY = {}
SP = 0
def pop():
return STACK.pop()
def push(v):
STACK.append(v)
def eval_binop(i):
a = pop()
b = pop()
op = get_op(i)
return push(op(b, a))
pc = 0
while True:
i = disasm_one(code, pc)
# print(hex(pc), i)
# print(list(map(hex, STACK)))
if pc == break_addr:
return STACK, MEMORY
# if pc in (0xff, 0x133, 0x195, 0x1c9, 0x0x249):
# print('XXX', hex(pc), list(map(hex, STACK)))
match i.mnem:
case Opcode.PUSH:
push(i.arg)
case Opcode.NOP:
pass
case Opcode.SHL | Opcode.SHR | Opcode.ADD | Opcode.SUB | Opcode.XOR | Opcode.OR | Opcode.AND | Opcode.MUL:
eval_binop(i)
case Opcode.GT | Opcode.LT | Opcode.LE | Opcode.LE | Opcode.DIV | Opcode.MOD:
eval_binop(i)
case Opcode.ROL1 | Opcode.ROR1 |Opcode.ROL2 | Opcode.ROR2 |Opcode.ROL4 | Opcode.ROR4:
eval_binop(i)
case Opcode.ADD_IMM:
v = pop()
push(v + i.arg)
case Opcode.STORE:
what = pop()
where = pop()
MEMORY[where] = what
case Opcode.LOAD:
where = pop()
push(MEMORY[where])
case Opcode.JMP:
pc = i.arg
continue
case Opcode.JZ | Opcode.JNZ:
v = pop()
# if pc == 0x1fe:
# pc = i.arg
# continue
f = v == 0 if i.mnem == Opcode.JZ else v != 0
if f:
pc = i.arg
continue
case Opcode.PUTCHAR:
print(chr(pop()), end='')
case Opcode.HLT:
break
case Opcode.EQ :
eval_binop(i)
case _:
raise NotImplementedError(str(i))
pc += i.size
def stage1(inp):
    """djb2-style hash of inp[0:4]: start at 5381, then h = h * 33 + byte.

    The result is not truncated to a fixed width.
    """
    acc = 0x1505
    for idx in range(4):
        acc = (acc << 5) + acc + inp[idx]
    return acc
def stage2(inp):
    """ROR-13 additive hash of inp[0:4].

    Each step rotates the running value right by 13 within 32 bits and then
    adds the next byte; the final sum is not masked to 32 bits.
    """
    acc = 0
    for idx in range(4):
        rotated = ror(acc, 13, 32)
        acc = rotated + inp[idx]
    return acc
def stage3(inp):
    """Adler-32 checksum of inp[0:8], returned as (sum2 << 16) | sum1."""
    lo, hi = 1, 0
    for idx in range(8):
        lo = (lo + inp[idx]) % 65521
        hi = (hi + lo) % 65521
    return hi << 16 | lo
def stage4(inp):
    """32-bit FNV-1 hash of inp[0:16].

    FNV-1 multiplies by the prime before XOR-ing in each byte (FNV-1a uses
    the opposite order).
    """
    p = 0x1000193   # 32-bit FNV prime
    h = 0x811c9dc5  # 32-bit FNV offset basis
    for i in range(16):
        h = (h * p) % 0x100000000
        h = h ^ inp[i]
    return h
# Run the key-patched program to completion (pc never equals -1, so the
# breakpoint never fires) and dump the final VM stack in hex.
stack, _ = run(code, -1)
print([hex(v) for v in stack])

# Earlier brute force of the first four key bytes: patch each printable
# candidate in, stop at pc 0x133 and report the candidate whose top two
# stack entries are equal there.
# for a in range(0x21, 0x7f):
#     for b in range(0x21, 0x7f):
#         for c in range(0x21, 0x7f):
#             for d in range(0x21, 0x7f):
#                 foo = [a, b, c, d]
#                 for i, k in offsets[:4]:
#                     code[i] = foo[k]
#                 stack, _ = run(code, 0x133)
#                 if stack[-1] == stack[-2]:
#                     print(a, b, c, d)
#                     break
import z3
def stage1(inp):
    """djb2-style hash of inp[0:4] (h = h * 33 + byte, starting at 5381).

    Works on plain ints or on z3 bit-vectors; no width truncation is applied.
    """
    acc = 0x1505
    for idx in range(4):
        acc = (acc << 5) + acc + inp[idx]
    return acc
def ror(val, r_bits, max_bits):
    """Rotate ``val`` right by ``r_bits`` inside a ``max_bits``-wide word.

    Only ``val`` may be a z3 bit-vector; ``r_bits`` and ``max_bits`` must be
    plain ints because they are used to compute shift counts.
    """
    mask = 2 ** max_bits - 1
    shift = r_bits % max_bits
    low_part = (val & mask) >> shift
    high_part = val << (max_bits - shift) & mask
    return low_part | high_part
def stage2(inp):
    """ROR-13 additive hash of inp[4:8].

    Each step rotates the running value right by 13 within 32 bits and adds
    the next byte; the caller masks the result to 32 bits.
    """
    acc = 0
    for idx in range(4, 8):
        acc = ror(acc, 13, 32) + inp[idx]
    return acc
def stage3(inp):
    """Adler-32 checksum of inp[8:16], returned as (sum2 << 16) | sum1."""
    lo, hi = 1, 0
    for idx in range(8, 16):
        lo = (lo + inp[idx]) % 65521
        hi = (hi + lo) % 65521
    return (hi << 16) | lo
def stage4(inp):
    """32-bit FNV-1 hash of inp[0:16] (multiply by the prime, then XOR byte)."""
    prime = 0x1000193
    acc = 0x811c9dc5
    for idx in range(16):
        acc = ((acc * prime) & 0xffffffff) ^ inp[idx]
    return acc
def all_smt(s, initial_terms):
    """Yield z3 models of solver ``s`` with distinct assignments to ``initial_terms``.

    Generator. After yielding a model ``m``, it recurses once per term: inside
    a ``push``/``pop`` scope it forbids that term's value in ``m`` while
    pinning every earlier term to its value in ``m``. Every assertion it adds
    is popped again, so ``s`` is left as it was found once iteration finishes.
    """
    def block_term(s, m, t):
        # Forbid the value that term t takes in model m.
        s.add(t != m.eval(t, model_completion=True))
    def fix_term(s, m, t):
        # Pin term t to the value it takes in model m.
        s.add(t == m.eval(t, model_completion=True))
    def all_smt_rec(terms):
        if z3.sat == s.check():
            m = s.model()
            yield m
            for i in range(len(terms)):
                s.push()
                block_term(s, m, terms[i])
                # term[i] should not be the same as that in m
                for j in range(i):
                    fix_term(s, m, terms[j])
                yield from all_smt_rec(terms[i:])
                # we are yet to discover all the assignments for term[i]. Using term[i+1:] means we are skipping
                # past the first satisfying assignment of term[i]
                # Note that term[i] might be multivalued and not binary
                s.pop()
    yield from all_smt_rec(list(initial_terms))
# Solve for the full 16-byte key with z3. The first 8 bytes are already known;
# stage3 (over bytes 8..15) and stage4 (over all 16 bytes) pin down the rest.
s = z3.Solver()
inp = [z3.BitVec(f'a{i}', 64) for i in range(16)]
# Restrict every key byte to 0x2a-0x39, 0x41-0x5a or 0x61-0x7a.
# NOTE(review): the lower bound 0x29 also admits "*+,-./" besides digits; if
# only digits were intended it should be 0x2f -- confirm.
for a in inp:
    s.add(z3.Or(z3.And(a > 0x29, a < 0x3a), z3.And(a > 0x40, a < 0x5b), z3.And(a > 0x60, a < 0x7b)))
keyprefix = b'VerYDumB'
for i, b in enumerate(keyprefix):
    s.add(inp[i] == b)
# stage1/stage2 only read bytes 0..7, which the prefix above already fixes.
#s.add(stage1(inp)&0xffffffff == 0x7c8df4cb)
#s.add(stage2(inp)&0xffffffff == 0x8b681d82)
s.add(stage3(inp) & 0xffffffff == 0xf910374)
s.add(stage4(inp) & 0xffffffff == 0x31f009d2)
# An explicit check instead of `assert`, which is stripped under `python -O`.
if s.check() != z3.sat:
    raise SystemExit("no key satisfies the constraints")
m = s.model()
x = bytes(m[a].as_long() for a in inp)
print(x)
# Cross-check the model by hashing the concrete key bytes.
print(stage4(x) & 0xffffffff == 0x31f009d2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment