Skip to content

Instantly share code, notes, and snippets.

@pieceofsummer
Created October 24, 2021 21:58
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save pieceofsummer/335cf841ed5fc8155a4067541660ceb6 to your computer and use it in GitHub Desktop.
Save pieceofsummer/335cf841ed5fc8155a4067541660ceb6 to your computer and use it in GitHub Desktop.
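#Deobfuscate hash-based WinAPI resolution: rewrite exception-raising resolve_api stubs
#into resolve_api_hlp calls, then patch resolved calls to named import stub functions.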
#@author
#@category _NEW_
#@keybinding
#@menupath
#@toolbar
from ghidra.program.model.address import *
from ghidra.program.model.listing import *
from ghidra.program.model.symbol import *
from ghidra.program.model.data import *
from ghidra.program.model.pcode import *
from jarray import array
import struct
# Name and base address of the scratch memory block that hosts the synthesized
# resolve_api_hlp thunk and, after it, the import stub functions.
HELPER_BLOCK_NAME = 'helper'
HELPER_BLOCK_ADDRESS = 0x80000000

# Program-wide managers cached once for the rest of the script.
address_factory, datatype_manager, function_manager = (
    currentProgram.getAddressFactory(),
    currentProgram.getDataTypeManager(),
    currentProgram.getFunctionManager(),
)
def get_address(address):
    """Return the Address at integer offset `address` in the default address space."""
    default_space = address_factory.getDefaultAddressSpace()
    return default_space.getAddress(address)
def get_data_type(type):
    """Return the data type named `type` from the program's root data type category.

    Only the root category is searched; callers in this script pass built-in
    names such as 'uint' and 'void *'.

    Raises LookupError if no type with that name exists there.  The check is an
    explicit raise rather than an ``assert`` so it is not stripped under -O.
    """
    root_path = datatype_manager.getRootCategory().getCategoryPath()
    data_type = datatype_manager.getDataType(root_path, type)
    if data_type is None:
        raise LookupError('type not found: %s' % type)
    return data_type
def create_argument(name, type, location=None):
    """Build a ParameterImpl named `name` for the current program.

    :param name: parameter name.
    :param type: a data type name (resolved with get_data_type) or a DataType.
    :param location: a register name string (resolved with
        currentProgram.getRegister), or a value passed through unchanged to
        ParameterImpl (None, or a stack offset such as 0).
    """
    storage = currentProgram.getRegister(location) if isinstance(location, str) else location
    data_type = get_data_type(type) if isinstance(type, str) else type
    return ParameterImpl(name, data_type, storage, currentProgram)
def create_arguments(*args):
    """Return the given parameters as a list suitable for Function.replaceParameters.

    The values are routed through a typed Java array of
    ghidra.program.model.listing.Variable, which should make Jython reject
    elements that are not Variables, and then converted back to a list.
    """
    variable_class = ghidra.program.model.listing.Variable
    typed = array(args, variable_class)
    return typed.tolist()
def prepare_function(fn, return_type, calling_convention, name, *args):
    """Rename `fn` and replace its signature.

    :param fn: the Function to update.
    :param return_type: a data type name (resolved with get_data_type) or a DataType.
    :param calling_convention: convention name such as '__fastcall', or None.
        With a convention, parameter storage is recomputed
        (DYNAMIC_STORAGE_ALL_PARAMS); with None, the storage each parameter
        carries is used as-is (CUSTOM_STORAGE).
    :param name: new function name.
    :param args: parameters built with create_argument().
    """
    resolved_return = get_data_type(return_type) if isinstance(return_type, str) else return_type
    update_types = Function.FunctionUpdateType
    if calling_convention:
        storage_mode = update_types.DYNAMIC_STORAGE_ALL_PARAMS
    else:
        storage_mode = update_types.CUSTOM_STORAGE
    fn.setName(name, SourceType.DEFAULT)
    fn.setCallingConvention(calling_convention)
    fn.replaceParameters(create_arguments(*args), storage_mode, True, SourceType.DEFAULT)
    fn.setReturnType(resolved_return, SourceType.DEFAULT)
# =====================================================================================================================================================
# Stage 0: Prepare resolve functions
# =====================================================================================================================================================
# The API resolver of the analysed sample sits at a hard-coded address; it takes
# the library hash in ECX and the function hash in EDX.
resolve_addr = get_address(0x4054b0)
resolve_func = getFunctionAt(resolve_addr)
if resolve_func is None:
    raise Exception('No function defined at %s; define resolve_api there first' % resolve_addr)
prepare_function(resolve_func, 'void *', '__fastcall', 'resolve_api', create_argument('library_hash', 'uint', 'ECX'), create_argument('function_hash', 'uint', 'EDX'))

# Scratch block for the resolve_api_hlp thunk and the Stage 2 import stubs.
helper_addr = get_address(HELPER_BLOCK_ADDRESS)
if getMemoryBlock(HELPER_BLOCK_NAME) is None:
    currentProgram.getMemory().createInitializedBlock(HELPER_BLOCK_NAME, helper_addr, 0x10000, 0, monitor, False)

# resolve_api_hlp: xchg ecx, edx; call resolve_api; ret
# It receives the two hashes in swapped registers and forwards them to
# resolve_api. The call displacement is relative to the end of the 7-byte
# "xchg; call" prefix.
code = struct.pack('<BBBiB', 0x87, 0xd1, 0xe8, resolve_addr.subtract(helper_addr) - 7, 0xc3)
# clearListing's end address is inclusive: clear exactly len(code) bytes.
clearListing(helper_addr, helper_addr.add(len(code) - 1))
setBytes(helper_addr, code)
disassemble(helper_addr)
createFunction(helper_addr, 'resolve_api_hlp')
helper_func = getFunctionAt(helper_addr)
if helper_func is None:
    raise Exception('Failed to create resolve_api_hlp at %s' % helper_addr)
prepare_function(helper_func, 'void *', None, 'resolve_api_hlp', create_argument('library_hash', 'uint', 'EDX'), create_argument('function_hash', 'uint', 'ECX'))

# Import stubs start at a pointer-aligned slot after the helper.
# NOTE: when the helper already ends on an aligned address this still advances
# by a full pointer (leaving one empty slot). It is kept so reruns find stubs
# created by earlier runs at the same addresses.
import_addr = helper_func.getBody().getMaxAddress().add(1)
import_addr = import_addr.add(import_addr.getPointerSize() - import_addr.getOffset() % import_addr.getPointerSize())
# skip already defined imports (every stub starts with a RET, 0xc3)
while getByte(import_addr) & 0xff == 0xc3:
    import_addr = import_addr.add(import_addr.getPointerSize())

# Force-define initializers as functions: walk the zero-terminated pointer
# table at 0x445154 (hard-coded for this sample).
initializer_addr = get_address(0x445154)
pointer_size = initializer_addr.getPointerSize()
while True:
    if pointer_size == 4:
        # getInt is signed; mask to obtain the unsigned 32-bit pointer.
        target = getInt(initializer_addr) & 0xffffffff
    else:
        target = getLong(initializer_addr)
    if not target:
        break
    addr = get_address(target)
    createFunction(addr, 'FUN_%s' % addr)
    initializer_addr = initializer_addr.add(pointer_size)
# =====================================================================================================================================================
# Stage 1: Scan defined functions and rewrite exception-raising sequences (xor reg, reg followed by a memory read through reg or by div reg) into resolve_api_hlp calls
# =====================================================================================================================================================
def rewrite_call(addr, end_addr, helper_addr):
    """Overwrite the exception-raising sequence at `addr` with a resolver call.

    Writes ``call resolve_api_hlp; call eax`` (7 bytes) at `addr`. That is a
    call to the thunk at `helper_addr`, followed by an indirect call through
    the value it returns in EAX.

    NOTE: Stage 2 later rebinds the module-level name ``rewrite_call`` to a
    different function; this version is the one in effect during Stage 1.

    :param addr: address of the XOR that starts the sequence.
    :param end_addr: exclusive upper bound of the containing function (the
        caller passes the entry point of the next function).
    :param helper_addr: entry point of resolve_api_hlp.
    """
    # call rel32 (displacement relative to the end of the 5-byte call); call eax
    code = struct.pack('<BiBB', 0xe8, helper_addr.subtract(addr) - 5, 0xff, 0xd0)
    # Instructions following the exception block may be incorrect, so clear up
    # to 64 bytes past `addr` and let disassemble() rebuild them. clearListing's
    # end is inclusive, so stop one byte before `end_addr`. This leaves the next
    # function's first instruction intact. Always cover the bytes we are about
    # to write so setBytes does not conflict with an existing instruction.
    to_clear = max(len(code) - 1, min(64, end_addr.subtract(addr) - 1))
    clearListing(addr, addr.add(to_clear))
    setBytes(addr, code)
    disassemble(addr)
def process_function(fn, helper_addr):
    """Patch the exception-raising resolver sequences inside `fn`.

    Walks instructions from the first instruction of `fn` up to, but not
    including, the entry point of the next function. It looks for
    ``xor reg, reg`` immediately followed by either:
      * a MOV whose second operand is a dynamic memory operand
        (OperandType.DYNAMIC, 0x400000) built only from ``reg``, or
      * ``div reg`` (OperandType.REGISTER, 0x200).
    Each match is overwritten by rewrite_call(). When a JMP has no instruction
    after it, a forward jump target inside the scan range is followed.

    :param fn: function to scan.
    :param helper_addr: entry point of resolve_api_hlp, passed to rewrite_call().
    :return: True if at least one sequence was rewritten.
    """
    next_fn = getFunctionAfter(fn)
    if next_fn is not None:
        end_addr = next_fn.getEntryPoint()
    else:
        # Nothing follows fn: bound the scan by its own body instead of
        # failing on None.
        end_addr = fn.getBody().getMaxAddress().add(1)
    changed = False
    instr = getFirstInstruction(fn)
    while instr and instr.getMinAddress() < end_addr:
        opcode = instr.getMnemonicString()
        if opcode == 'XOR':
            reg, reg2 = instr.getRegister(0), instr.getRegister(1)
            if reg and reg2 and reg == reg2:
                next_instr = instr.getNext()
                if not next_instr:
                    # no instructions defined after XOR???
                    break
                next_opcode = next_instr.getMnemonicString()
                reg_ref = None
                if next_opcode == 'MOV':
                    # 0x400000 == OperandType.DYNAMIC (register-computed memory operand)
                    if next_instr.getOperandType(1) & 0x400000:
                        reg_ref = next_instr.getOpObjects(1)
                elif next_opcode == 'DIV':
                    # 0x200 == OperandType.REGISTER
                    if next_instr.getOperandType(0) == 0x200:
                        reg_ref = next_instr.getOpObjects(0)
                if reg_ref and len(reg_ref) == 1 and reg_ref[0] == reg:
                    start = instr.getMinAddress()
                    rewrite_call(start, end_addr, helper_addr)
                    changed = True
                    # The listing at `start` was cleared and re-disassembled;
                    # continue from the new instruction, not the stale object.
                    instr = getInstructionAt(start)
                    if not instr:
                        break
        elif opcode == 'JMP':
            if not instr.getNext():
                # No instructions defined after the JMP (could be a short jump
                # forward). 0x2040 == OperandType.ADDRESS | OperandType.CODE.
                jmp_addr = instr.getOpObjects(0)[0] if instr.getOperandType(0) == 0x2040 else None
                if jmp_addr and jmp_addr > instr.getMinAddress() and jmp_addr < end_addr:
                    instr = getInstructionAt(jmp_addr)
                    continue
        instr = instr.getNext()
    return changed
# Visit functions in descending address order (forward=False). Only in-memory
# functions below the helper block are candidates; the helper thunk and the
# import stubs live at or above helper_addr.
for fn in currentProgram.getListing().getFunctions(False):
    if not fn.isExternal() and fn.getEntryPoint() < helper_addr:
        if process_function(fn, helper_addr):
            print('Function %s was patched' % fn.getName())
# =====================================================================================================================================================
# Stage 2: Resolve WinAPI imports and rewrite indirect calls with normal ones
# =====================================================================================================================================================
def name_hash(name):
    """Return the 32-bit hash the sample uses to identify API names.

    Starting from 0x40, each character folds in as
    ``acc = (ord(ch) - acc * 0x45523f21) mod 2**32``.
    """
    multiplier = 0x45523f21
    mask = 0xffffffff
    acc = 0x40
    for ch in name:
        acc = (ord(ch) - acc * multiplier) & mask
    return acc
def get_data_manager(name):
    """Return the open data type manager called `name`, or None.

    Searches the tool's DataTypeManagerService. Returns None when no manager
    with that name is open. It also returns None when there is no tool
    (presumably the headless case) or the tool provides no such service.
    Previously those cases failed with an AttributeError on None.
    """
    tool = state.getTool()
    if tool is None:
        return None
    service = tool.getService(ghidra.app.services.DataTypeManagerService)
    if service is None:
        return None
    for dm in service.getDataTypeManagers():
        if dm.getName() == name:
            return dm
    return None
def get_operand(fn, val):
    """Fold a decompiler varnode into an integer constant.

    Handles constants directly, and unique varnodes defined by CAST (which
    looks through to the input) or PTRSUB (which sums both inputs).

    :param fn: the HighFunction the varnode belongs to (passed along
        unchanged on recursion).
    :param val: the varnode to evaluate.
    :raises Exception: for any other varnode or defining operation. A varnode
        that is neither constant nor unique used to raise UnboundLocalError
        here, because ``op`` was never assigned.
    """
    if val.isConstant():
        return val.getOffset()
    op = val.getDef() if val.isUnique() else None
    if op is not None:
        mnemonic = op.getMnemonic()
        if mnemonic == 'CAST':
            return get_operand(fn, op.getInput(0))
        if mnemonic == 'PTRSUB':
            return get_operand(fn, op.getInput(0)) + get_operand(fn, op.getInput(1))
        raise Exception('Unhandled operand type: %s' % op)
    raise Exception('Unhandled operand type: %s' % val)
def rewrite_call(addr, proto, is_helper_call):
    """Patch the resolver call at `addr` so it targets a stub named after `proto`.

    Looks up a function named ``proto.getName()``. If none exists, it creates
    one at the module-level ``import_addr``: a 4-byte ``ret; nop; nop; nop``
    body given the prototype from `proto` with __stdcall. ``import_addr`` then
    advances by one pointer.

    - is_helper_call=True: the 7-byte ``call resolve_api_hlp; call eax`` pair
      written in Stage 1 becomes ``call <stub>; nop; nop``.
    - is_helper_call=False: the call to resolve_api at `addr` (assumed to be a
      5-byte ``call rel32``) becomes ``mov eax, <stub>``. Presumably an
      indirect call through EAX follows it.

    :raises Exception: if the stub function cannot be created.
    """
    global import_addr
    print(proto)
    func = getFunction(proto.getName())
    if func is None:
        # no such function is defined yet
        setBytes(import_addr, struct.pack('<BBBB', 0xc3, 0x90, 0x90, 0x90))
        args = [create_argument(arg.getName(), arg.getDataType(), 0) for arg in proto.getArguments()]
        func = createFunction(import_addr, proto.getName())
        if func is None:
            raise Exception('Failed to create import stub %s at %s' % (proto.getName(), import_addr))
        prepare_function(func, proto.getReturnType(), '__stdcall', proto.getName(), *args)
        import_addr = import_addr.add(import_addr.getPointerSize())
    if is_helper_call:
        # call rel32 (relative to the end of the 5-byte call); nop; nop
        code = struct.pack('<BiBB', 0xe8, func.getEntryPoint().subtract(addr) - 5, 0x90, 0x90)
    else:
        # mov eax, imm32
        code = struct.pack('<BI', 0xb8, func.getEntryPoint().getOffset())
    # clearListing's end address is inclusive: clear exactly len(code) bytes so
    # the instruction that follows the patched range is preserved.
    clearListing(addr, addr.add(len(code) - 1))
    setBytes(addr, code)
    disassemble(addr)
def rewrite_calls(decompiler, fn, is_helper_call):
    """Resolve every call to `fn` and patch it to call the matching API stub.

    For each unconditional call that references `fn`'s entry point, the
    calling function is decompiled and the CALL p-code op at the call site is
    read. Its function-hash argument is looked up in the module-level
    ``func_hashes`` table; on a hit, rewrite_call() patches the site. Call
    sites that cannot be analysed are reported and skipped.

    :param decompiler: an opened DecompInterface.
    :param fn: the resolver function whose callers are rewritten.
    :param is_helper_call: forwarded to rewrite_call().
    """
    for ref in getReferencesTo(fn.getEntryPoint()):
        if ref.getReferenceType() != RefType.UNCONDITIONAL_CALL:
            continue
        ref_addr = ref.getFromAddress()
        # Prefer the function whose body contains the call. Code re-disassembled
        # in Stage 1 may lie outside every function body, so fall back to the
        # nearest preceding function.
        ref_fn = getFunctionContaining(ref_addr) or getFunctionBefore(ref_addr)
        if not ref_fn:
            continue
        result = decompiler.decompileFunction(ref_fn, 30, monitor)
        if not result.decompileCompleted():
            continue
        high_fn = result.getHighFunction()
        # getPcodeOps returns a java.util.Iterator; test it explicitly rather
        # than swallowing every exception with a bare except.
        ops = high_fn.getPcodeOps(ref_addr)
        if not ops.hasNext():
            print('%s No function call found' % ref_addr)
            continue
        call_op = ops.next()
        if call_op.getMnemonic() != 'CALL':
            print('%s Not a function call %s' % (ref_addr, call_op))
            continue
        inputs = call_op.getInputs()
        if len(inputs) != 3:
            print('%s Unexpected number of call inputs %s' % (ref_addr, call_op))
            continue
        _, lib_hash, func_hash = inputs
        if not func_hash.isConstant():
            print('%s Non-constant operand %s %s' % (ref_addr, lib_hash, func_hash))
            continue
        func_hash = get_operand(high_fn, func_hash)
        # The library hash is only reported, so failing to fold it is not fatal.
        try:
            lib_hash_text = '0x%x' % get_operand(high_fn, lib_hash)
        except Exception:
            lib_hash_text = str(lib_hash)
        print('%s %s %s 0x%x' % (ref_addr, fn.getEntryPoint(), lib_hash_text, func_hash))
        proto = func_hashes.get(func_hash)
        if not proto:
            print('%s Prototype not found for hash 0x%x' % (ref_addr, func_hash))
            continue
        rewrite_call(ref_addr, proto, is_helper_call)
dm = get_data_manager('windows_vs12_32')
if dm:
    # Map name_hash(prototype name) -> FunctionDefinition for every function
    # prototype in the data type manager.
    func_hashes = {}
    for dt in dm.getAllDataTypes():
        if not isinstance(dt, FunctionDefinition):
            continue
        func_hashes[name_hash(dt.getName())] = dt
    decompiler = ghidra.app.decompiler.DecompInterface()
    try:
        if decompiler.openProgram(currentProgram):
            # Direct resolve_api calls first, then the helper calls from Stage 1.
            rewrite_calls(decompiler, resolve_func, False)
            rewrite_calls(decompiler, helper_func, True)
        else:
            print('Decompiler failed to open program: %s' % decompiler.getLastMessage())
    finally:
        # Release the decompiler process even if a rewrite fails.
        decompiler.dispose()
else:
    print('Data type manager windows_vs12_32 is not open; Stage 2 skipped')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment