# Source: GitHub Gist ben-sb/abd300461666b095578f58259fbf9b60
# Author: @ben-sb
# Created: May 20, 2024 19:23
from ghidra.program.model.lang import OperandType
class Graph:
    """Directed graph with labelled edges, keyed by node name.

    ``node_map`` maps a node name to its node object and ``edge_map`` maps an
    edge's ``name`` string to the edge object, so adding the same node name or
    the same edge name twice does not create duplicates.
    """

    def __init__(self):
        self.node_map = {}
        self.edge_map = {}

    def add_node(self, name):
        """Return the node registered under *name*, creating it if needed."""
        node = self.node_map.get(name)
        if node is None:
            node = Node(name)
            self.node_map[name] = node
        return node

    def add_edge(self, source, target, label):
        """Connect *source* to *target* with *label* unless that edge exists.

        A new edge is also appended to ``source.outgoing_edges`` and
        ``target.incoming_edges``.
        """
        edge = Edge(source, target, label)
        if edge.name not in self.edge_map:
            self.edge_map[edge.name] = edge
            source.outgoing_edges.append(edge)
            target.incoming_edges.append(edge)

    @staticmethod
    def _escape_label(label):
        """Return *label* as text that is safe inside a double-quoted DOT label."""
        text = '%s' % (label,)
        # Backslashes first, otherwise the backslash added for '"' would be doubled.
        return text.replace('\\', '\\\\').replace('"', '\\"')

    def to_graphviz(self):
        """Return the graph as Graphviz DOT text (no trailing newline).

        Nodes without incoming edges are coloured green and, failing that,
        nodes without outgoing edges are coloured red.
        """
        lines = ['Digraph G {']
        # colour start and end nodes
        for node in self.node_map.values():
            if not node.incoming_edges:
                lines.append(" %s [color=\"green\"]" % node.name)
            elif not node.outgoing_edges:
                lines.append(" %s [color=\"red\"]" % node.name)
        for edge in self.edge_map.values():
            lines.append(
                " %s -> %s [label=\"%s\"]"
                % (edge.source.name, edge.target.name, self._escape_label(edge.label))
            )
        lines.append('}')
        return '\n'.join(lines)

    def print_graphviz(self):
        """Print the DOT text produced by :meth:`to_graphviz`."""
        print(self.to_graphviz())
class Node:
    """Graph vertex identified by ``name`` that records its adjacent edges."""

    def __init__(self, name):
        # Both edge lists start empty and are independent list objects.
        self.outgoing_edges = list()
        self.incoming_edges = list()
        self.name = name
class Edge:
    """Directed connection from ``source`` to ``target`` carrying ``label``."""

    def __init__(self, source, target, label):
        self.source, self.target, self.label = source, target, label
        # Text built from both endpoint names and the label.
        self.name = "%s --(%s)--> %s" % (source.name, label, target.name)
class PathFinder:
    """Collect every walk of exactly ``length`` edges from ``start`` to ``end``.

    The search runs in the constructor; results are stored in ``found_paths``
    as lists of edge objects, in depth-first discovery order. Nodes may be
    visited more than once, so the length bound is what stops the recursion.
    """

    def __init__(self, start, end, length):
        self.end = end
        self.length = length
        self.found_paths = []
        # Seed one search per outgoing edge. Passing the whole edge list as a
        # single initial path would treat sibling edges as consecutive steps
        # and would fail on a start node without outgoing edges.
        for edge in start.outgoing_edges:
            self.find([edge])

    def find(self, path):
        """Extend *path* (a list of edges) depth-first up to ``self.length``.

        A path is recorded when it has exactly ``self.length`` edges and its
        last edge ends at ``self.end``. An empty path is ignored.
        """
        if not path:
            return
        node = path[-1].target
        if len(path) < self.length:
            for edge in node.outgoing_edges:
                # Copy so sibling branches never share a list.
                self.find(path + [edge])
        elif len(path) == self.length and node == self.end:
            self.found_paths.append(path)
# wrapper class used to store functions and their disassembled instructions together
class Function:
    """Pair a function object with its instruction list and an address index.

    ``addr_map`` maps ``str(instr.getAddress())`` to the position of that
    instruction in ``instrs``.
    """

    def __init__(self, function, instrs):
        self.function = function
        self.instrs = instrs
        self.addr_map = self.build_addr_map()

    def build_addr_map(self):
        """Return a dict from each instruction's address string to its index."""
        return {
            str(instruction.getAddress()): position
            for position, instruction in enumerate(self.instrs)
        }
def extract_comparison_and_target_addr(function, index):
    """Decode the compare at *index* and the jump that follows it.

    Returns ``(comparison, jump_target)``: the second operand of the CMP
    parsed as a base-16 integer, and the instruction index returned by
    ``extract_jump_target`` for the instruction at ``index + 1``.

    Raises ``Exception`` when the instruction is not a two-operand CMP whose
    second operand is a scalar. The first operand is not checked here.
    """
    instr = function.instrs[index]
    # check instr matches expected type: CMP EAX,0x?
    is_scalar_compare = (
        instr.getMnemonicString() == 'CMP'
        and instr.getNumOperands() == 2
        and instr.getOperandType(1) & OperandType.SCALAR
    )
    if not is_scalar_compare:
        raise Exception('Failed to extract comparison and target from %s'%instr)
    comparison = int(instr.getDefaultOperandRepresentation(1), 16)
    return (comparison, extract_jump_target(function, index + 1))
def extract_jump_target(function, index):
    """Return the index of the instruction reached when the compare matched.

    The instruction at *index* must be ``JZ`` or ``JNZ``:

    * ``JNZ`` - the matching path falls through, so ``index + 1`` is returned.
    * ``JZ``  - the matching path takes the jump; the text after ``0x`` in the
      first operand is looked up in ``function.addr_map``.

    Raises ``Exception`` for any other mnemonic, or when the jump address
    cannot be resolved to an instruction index.
    """
    instr = function.instrs[index]
    mnemonic = instr.getMnemonicString()
    # check instr matches expected type: JZ addr | JNZ addr
    if mnemonic not in ('JZ', 'JNZ'):
        raise Exception('Failed to extract jump target from %s'%instr)
    if mnemonic == 'JNZ':
        # we want to take matching path so just fall through to next instr
        return index + 1
    # JZ: want to go to target of jump
    jump_addr = instr.getDefaultOperandRepresentation(0)
    pieces = jump_addr.split('0x')
    if len(pieces) < 2 or pieces[1] not in function.addr_map:
        raise Exception('Failed to find index for jump to %s' % jump_addr)
    return function.addr_map[pieces[1]]
def extract_node(function, index):
    """Add one graph edge for the character check that follows *index*.

    Scans forward from ``index + 1`` for ``CMP EAX,<scalar>`` (the expected
    character), follows its matching branch to the next ``CMP`` (whose scalar
    becomes the source node's number), then follows that branch until either

    * ``CALL 0x0010c6e0`` - adds an edge to the ``million_dollars`` node, or
    * ``MOV ESI,<scalar>`` followed by a ``CALL`` to a known briefcase
      function - adds an edge to ``<callee name>_<scalar>``.

    Returns ``None`` in all cases, including when ``CALL 0x0010e080`` is seen
    first (nothing is added then). Reads the module-level ``graph``,
    ``function_addr_index_map`` and ``briefcase_functions``.

    Raises ``Exception`` when the expected instruction pattern is not found.
    """
    instrs = function.instrs
    i = index + 1
    while i < len(instrs):
        instr = instrs[i]
        # ignore call to deref, this happens only first time
        if str(instr) == 'CALL 0x0010e080':
            return
        is_char_compare = (
            instr.getMnemonicString() == 'CMP'
            and instr.getNumOperands() == 2
            and instr.getDefaultOperandRepresentation(0) == 'EAX'
            and instr.getOperandType(1) & OperandType.SCALAR
        )
        if not is_char_compare:
            i += 1
            continue

        char_code, i = extract_comparison_and_target_addr(function, i)
        char = chr(char_code)

        # find next comparison
        while i < len(instrs) and instrs[i].getMnemonicString() != 'CMP':
            i += 1
        if i >= len(instrs):
            raise Exception('Ran out of instructions looking for next comparison')
        num, i = extract_comparison_and_target_addr(function, i)
        node_name = '%s_%s' % (str(function.function.getName()), num)

        # find either instr setting up args for call to briefcase_no_x, or call to million dollars
        while True:
            if i >= len(instrs):
                raise Exception('Ran out of instructions looking for briefcase call or million_dollars call')
            next_instr = instrs[i]
            # call to million dollars, leaf node
            if str(next_instr) == 'CALL 0x0010c6e0':
                node = graph.add_node(node_name)
                end_node = graph.add_node('million_dollars')
                graph.add_edge(node, end_node, char)
                return
            # setting up second arg to briefcase_no_x
            if (next_instr.getMnemonicString() == 'MOV'
                    and next_instr.getDefaultOperandRepresentation(0) == 'ESI'):
                break
            i += 1

        next_num = int(next_instr.getDefaultOperandRepresentation(1), 16)

        # the instruction right after the MOV must be the briefcase call
        if i + 1 >= len(instrs):
            raise Exception('Expected briefcase function call, got end of instructions')
        call_instr = instrs[i + 1]
        if call_instr.getMnemonicString() != 'CALL':
            raise Exception('Expected briefcase function call, got %s'%call_instr)

        # find which briefcase function we are calling
        call_operand = call_instr.getDefaultOperandRepresentation(0)
        pieces = call_operand.split('0x')
        if len(pieces) < 2 or pieces[1] not in function_addr_index_map:
            shown = pieces[1] if len(pieces) > 1 else call_operand
            raise Exception('Expected briefcase function address, got %s' % shown)
        target_func = briefcase_functions[function_addr_index_map[pieces[1]]]

        source_node = graph.add_node(node_name)
        target_name = '%s_%s' % (str(target_func.function.getName()), next_num)
        target_node = graph.add_node(target_name)
        graph.add_edge(source_node, target_node, char)
        return
    raise Exception('Ran out of instructions looking for comparison')
if __name__ == '__main__':
    # Script entry point. graph, function_addr_index_map and briefcase_functions
    # must stay module-level names because extract_node reads them as globals.
    currentProgram = getCurrentProgram()
    listing = currentProgram.getListing()
    function_manager = currentProgram.getFunctionManager()

    briefcase_function_names = {'briefcase_no_%d' % number for number in range(1, 6)}
    # maps entry-point address string -> index into briefcase_functions
    function_addr_index_map = {}
    briefcase_functions = []
    graph = Graph()

    # find all relevant briefcase functions
    for function in function_manager.getFunctions(True):
        if function.getName() not in briefcase_function_names:
            continue
        entry_point = str(function.getEntryPoint())
        # the index this function is about to occupy in briefcase_functions
        function_addr_index_map[entry_point] = len(briefcase_functions)
        body_instructions = listing.getInstructions(function.getBody(), True)
        briefcase_functions.append(Function(function, list(body_instructions)))

    # extract nodes and edges to build graph
    for function in briefcase_functions:
        # debugging aid, dumps each function's listing to a file:
        # open('functions/%s.txt'%function.function.getName(), 'w').write('\n'.join(['%s: %s'%(i.getAddress(), str(i)) for i in function.instrs]))
        for index, instr in enumerate(function.instrs):
            # this pattern is loading the next char from the input string
            if str(instr) == 'MOV EAX,dword ptr [RSP + 0x28]':
                extract_node(function, index)

    # add an edge from million_dollars to our end node with char '}', as we didn't handle this
    million_node = graph.node_map['million_dollars']
    end_node = graph.add_node('end')
    graph.add_edge(million_node, end_node, '}')
    graph.print_graphviz()

    start_node = graph.node_map['briefcase_no_1_0']
    desired_length = 25  # flag is 25 characters long
    finder = PathFinder(start_node, end_node, desired_length)

    # all of these paths work, only one of them is the right flag though
    for path in finder.found_paths:
        print(''.join(edge.label for edge in path))
# [GitHub page footer] Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment