Skip to content

Instantly share code, notes, and snippets.

@Donaldduck8
Created March 17, 2024 21:52
Show Gist options
  • Save Donaldduck8/4970011e52a6646f9b0362b4cc6cb156 to your computer and use it in GitHub Desktop.
Save Donaldduck8/4970011e52a6646f9b0362b4cc6cb156 to your computer and use it in GitHub Desktop.
String decryptor for CargoBay malware
import base64
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Generator, List, Tuple
import pefile
from capstone import CS_ARCH_X86, CS_MODE_64, Cs, CsInsn
def get_section(pe: pefile.PE, section_name: str) -> pefile.SectionStructure | None:
    """Return the section of *pe* that matches *section_name*.

    An exact match on the section name (ignoring the trailing NUL padding of
    the raw name field) is preferred, so that asking for ``.text`` does not
    return an earlier section such as ``.textbss``. If no section matches
    exactly, the first section whose raw name merely contains *section_name*
    is returned, which keeps the original substring behaviour.

    Args:
        pe: Parsed PE file whose ``sections`` are searched in order.
        section_name: Name to look for; encoded as latin-1 for the comparison.

    Returns:
        The matching section, or ``None`` when nothing matches.
    """
    wanted = section_name.encode("latin-1")
    substring_match = None
    for section in pe.sections:
        if section.Name.rstrip(b"\x00") == wanted:
            return section
        # Remember only the first loose match; it is used if no exact match exists.
        if substring_match is None and wanted in section.Name:
            substring_match = section
    return substring_match
def disassemble_blob(cs: Cs, bytes_blob: bytes) -> list[CsInsn]:
    """Disassemble *bytes_blob* completely, stepping over undecodable bytes.

    ``cs.skipdata`` decided to stop working for the original author, so this
    reproduces similar behaviour by hand: whenever the disassembler stops
    before the end of the blob, decoding is retried one byte further on until
    it produces instructions again or the blob is exhausted.

    Args:
        cs: Disassembler exposing ``disasm(code, offset)``.
        bytes_blob: Raw bytes to disassemble. Instruction addresses are
            offsets into this blob because decoding starts at address 0.

    Returns:
        All decoded instructions in address order. The list is empty when
        nothing could be decoded (including an empty blob).
    """
    instructions: list[CsInsn] = []
    blob_length = len(bytes_blob)
    cursor = 0
    while cursor < blob_length:
        decoded = list(cs.disasm(bytes_blob[cursor:], cursor))
        if not decoded:
            # Nothing decodes at this offset: skip a single byte and retry.
            cursor += 1
            continue
        instructions += decoded
        last = decoded[-1]
        # Decoding stopped right after `last`, so the byte at that position is
        # known to be undecodable; resume one byte beyond it.
        cursor = last.address + len(last.bytes) + 1
    return instructions
def find_pattern(
    instructions: List[CsInsn], patterns: List[str]
) -> Generator[int | tuple, None, None]:
    """Yield every position where consecutive instructions match *patterns*.

    Each instruction is rendered as ``"<mnemonic> <op_str>"`` and matched with
    ``re.match`` (anchored at the start) against the pattern at the same
    offset. A position is reported only when all patterns match in sequence.

    Args:
        instructions: Instructions to scan.
        patterns: One regular expression per consecutive instruction.

    Yields:
        The start index ``i`` when no pattern captured a group, otherwise the
        tuple ``(i, group1, group2, ...)`` with all captured groups in order.
    """
    # Compile once instead of re-resolving each pattern for every instruction.
    compiled = [re.compile(pattern) for pattern in patterns]
    # A window that would run past the end of the list can never match, so
    # stop early instead of indexing out of range. An empty pattern list
    # still reports every index, as before.
    window = max(len(compiled), 1)
    for start in range(len(instructions) - window + 1):
        captured_groups = []
        for offset, regex in enumerate(compiled):
            insn = instructions[start + offset]
            match = regex.match(insn.mnemonic + " " + insn.op_str)
            if match is None:
                break
            captured_groups.extend(match.groups())
        else:
            # Reached only when no pattern in the window failed to match.
            if captured_groups:
                yield (start,) + tuple(captured_groups)
            else:
                yield start
class StringDecryptor(ABC):
    """Base class for string decryptors that analyse one section of a PE file.

    The constructor parses the PE image, looks up the requested section and
    disassembles its contents. Disassembly starts at address 0, so every
    instruction address is an offset from the start of that section; the
    ``convert_*`` helpers turn such offsets into file offsets or into virtual
    addresses that include the image base.
    """

    data: bytearray
    pe: pefile.PE
    section_name: str
    section: pefile.SectionStructure
    cs: Cs
    instructions: List[CsInsn]

    # RIP-relative memory operand as it appears in an operand string, e.g.
    # "[rip + 0x1f]", "[rip - 0x20]" or a bare "[rip]".
    _RIP_OPERAND_RE = re.compile(r"\[rip(?: ([+-]) ([0-9a-fA-Fx]+))?\]")

    def __init__(
        self,
        data: bytearray,
        section_name: str,
        arch: int = CS_ARCH_X86,
        mode: int = CS_MODE_64,
    ):
        """Parse *data* as a PE file and disassemble the named section.

        Args:
            data: Raw bytes of the PE file.
            section_name: Name of the section to disassemble.
            arch: Capstone architecture constant.
            mode: Capstone mode constant.

        Raises:
            ValueError: If the PE file has no section matching *section_name*.
        """
        self.data = data
        self.pe = pefile.PE(data=data)
        self.section_name = section_name
        self.section = get_section(self.pe, self.section_name)
        if self.section is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"Section {section_name!r} not found in the PE file")
        self.cs = Cs(arch, mode)
        self.cs.skipdata = True
        self.cs.detail = False
        self.instructions = disassemble_blob(self.cs, self.section.get_data())

    @abstractmethod
    def find_points_of_interest(self) -> List[int]:
        """Return the indices into ``self.instructions`` worth decrypting."""

    # NOTE(review): CargoBayStringDecryptor in this file overrides this method
    # as ``decrypt_string(self, point_of_interest)``; the signatures disagree.
    @abstractmethod
    def decrypt_string(self, instructions: List[CsInsn], point_of_interest: int) -> str:
        """Decrypt the string associated with *point_of_interest*."""

    def _rip_relative_target(self, insn: CsInsn) -> int:
        """Return the section-relative address referenced by a RIP-relative operand.

        RIP points at the next instruction, so the target is the instruction
        address plus its length plus the signed displacement.

        Raises:
            ValueError: If the operand string has no RIP-relative operand.
        """
        match = self._RIP_OPERAND_RE.search(insn.op_str)
        if match is None:
            raise ValueError(
                f"Instruction has no RIP-relative operand: {insn.mnemonic} {insn.op_str}"
            )
        sign, magnitude = match.groups()
        displacement = int(magnitude, 16) if magnitude else 0
        if sign == "-":
            displacement = -displacement
        return insn.address + len(insn.bytes) + displacement

    def convert_rva_to_ida(self, addr: int) -> int:
        """Convert a section-relative address to a virtual address with image base."""
        return addr + self.section.VirtualAddress + self.pe.OPTIONAL_HEADER.ImageBase

    def convert_rip_plus_offset_to_absolute(self, insn: CsInsn) -> int:
        """Return the file offset of the data referenced by a RIP-relative operand.

        The section-relative target is first turned into a real RVA by adding
        the section's virtual address, and only then translated to a file
        offset, so the translation uses the section that actually contains
        the target.
        """
        target_rva = self._rip_relative_target(insn) + self.section.VirtualAddress
        return self.pe.get_offset_from_rva(target_rva)

    def convert_rip_plus_offset_to_ida(self, insn: CsInsn) -> int:
        """Return the virtual address (with image base) a RIP-relative operand refers to."""
        return self.convert_rva_to_ida(self._rip_relative_target(insn))
class CargoBayStringDecryptor(StringDecryptor):
    """String decryptor for CargoBay samples.

    Strings are stored base64-encoded and XOR-ed with a base64-encoded key.
    The key and the decryption routine are located heuristically from
    instruction patterns; every direct call to that routine is then treated
    as a point of interest whose arguments describe one encrypted string.
    """

    def find_points_of_interest(self) -> List[int]:
        """Return the indices of all direct calls to the decryption function.

        Runs ``find_decryption_function`` first, so this may raise the same
        ``RuntimeError`` when the key or the function cannot be located.
        """
        self.find_decryption_function()
        points_of_interest = []
        for i, insn in enumerate(self.instructions):
            # Only direct calls carry an immediate target in the operand string.
            if insn.mnemonic != "call" or not insn.op_str.startswith("0x"):
                continue
            func_addr = self.convert_rva_to_ida(int(insn.op_str, 16))
            if func_addr != self.decryption_function_addr:
                continue
            insn_addr = self.convert_rva_to_ida(insn.address)
            print(
                f"Instruction {i}: {insn.mnemonic} {insn.op_str} at {hex(insn_addr)} calls {hex(func_addr)}"
            )
            points_of_interest.append(i)
        return points_of_interest

    def find_decryption_function(self) -> None:
        """Locate the encryption key and the address of the decryption function.

        Sets ``self.encryption_key`` (the key as base64 text) and
        ``self.decryption_function_addr`` (virtual address including the
        image base).

        Raises:
            RuntimeError: If the key, its global location or the decryption
                function cannot be found.
        """
        encryption_key_patterns = [
            "lea rdx",
            "lea rsi",
            "mov r8d, (0x.*?)$",
            "mov",
            "call"
        ]
        EncryptionKeyInstance = namedtuple("EncryptionKeyInstance", ["index", "string_length"])
        pattern_instances = [
            EncryptionKeyInstance(index=x, string_length=int(y, 16))
            for x, y in find_pattern(self.instructions, encryption_key_patterns)
        ]
        if not pattern_instances:
            raise RuntimeError("Could not find the encryption key")

        # The key is assumed to be the longest string copied with this
        # pattern. A stable sort keeps the last candidate among equal lengths.
        largest_b64_string_instance = sorted(
            pattern_instances, key=lambda x: x.string_length
        )[-1]
        encryption_key_lea_insn = self.instructions[largest_b64_string_instance.index]
        encryption_key_absolute_addr = self.convert_rip_plus_offset_to_absolute(
            encryption_key_lea_insn
        )
        encryption_key = self.data[
            encryption_key_absolute_addr : encryption_key_absolute_addr
            + largest_b64_string_instance.string_length
        ]
        if not encryption_key:
            raise RuntimeError("Could not find the encryption key")
        encryption_key = encryption_key.decode(encoding="ascii", errors="ignore")
        print(encryption_key)
        self.encryption_key = encryption_key

        # Find global memory location of the encryption key
        decryption_key_global_patterns = [
            r"lea rsi, \[rip \+ 0x.*?\]",
        ]
        lookforward_buffer = self.instructions[
            largest_b64_string_instance.index : largest_b64_string_instance.index + 1024
        ]
        decryption_key_addr = None
        for i in find_pattern(lookforward_buffer, decryption_key_global_patterns):
            decryption_key_addr = self.convert_rip_plus_offset_to_ida(lookforward_buffer[i])
            print(hex(decryption_key_addr))
            break
        if decryption_key_addr is None:
            raise RuntimeError("Could not find the global location of the encryption key")

        # Find the decryption function
        decryption_function_patterns = [
            r"mov r9, qword ptr \[rip \+ 0x.*?\]",
        ]
        decryption_function_addr = None
        for i in find_pattern(self.instructions, decryption_function_patterns):
            if self.convert_rip_plus_offset_to_ida(self.instructions[i]) != decryption_key_addr:
                continue
            # Next call should be the decryption function
            next_call = next(
                (insn for insn in self.instructions[i:] if insn.mnemonic == "call"),
                None,
            )
            # An indirect call (register or memory operand) has no immediate
            # target to parse, so such a site is skipped instead of crashing.
            if next_call is not None and next_call.op_str.startswith("0x"):
                decryption_function_addr = self.convert_rva_to_ida(int(next_call.op_str, 16))
        if not decryption_function_addr:
            raise RuntimeError("Could not find the decryption function address")
        self.decryption_function_addr = decryption_function_addr

    def decrypt_string(self, point_of_interest: int) -> Tuple[int | None, str | None]:
        """Decrypt the string passed to the decryption call at *point_of_interest*.

        Args:
            point_of_interest: Index into ``self.instructions`` of a call to
                the decryption function.

        Returns:
            ``(address, plaintext)``. ``address`` is the virtual address the
            decrypted string is stored at, or ``None`` when no such location
            was found after the call. Both values are ``None`` when the
            encrypted string could not be located or decoded.
        """
        # Clamp the start: a negative slice start would count from the end of
        # the instruction list and produce a wrong (usually empty) window.
        lookback_start = max(0, point_of_interest - 1024)
        lookback_buffer = self.instructions[lookback_start:point_of_interest][::-1]
        # Instruction whose address is reported in diagnostics.
        report_insn = lookback_buffer[0] if lookback_buffer else self.instructions[point_of_interest]

        encrypted_string_length = None
        encrypted_string_addr = None
        decrypted_string_addr = None
        length_index = None

        # Walk backwards to the most recent mov r8d, imm instruction
        # Alternatively, a pop r8 and a push imm instruction
        for i, insn in enumerate(lookback_buffer):
            if insn.mnemonic == "mov" and insn.op_str.startswith("r8d, 0x"):
                encrypted_string_length = int(insn.op_str.split(" ")[1], 16)
                length_index = i
                break
            if (
                insn.mnemonic == "pop"
                and insn.op_str.startswith("r8")
                and i + 1 < len(lookback_buffer)
                and lookback_buffer[i + 1].mnemonic == "push"
            ):
                try:
                    encrypted_string_length = int(lookback_buffer[i + 1].op_str, 16)
                except ValueError:
                    # push of a register or memory operand, not an immediate
                    continue
                length_index = i
                break

        # Walk backwards to the string offset
        if length_index is not None:
            for insn in lookback_buffer[length_index : length_index + 10]:
                if insn.mnemonic == "lea" and insn.op_str.startswith("rdx, [rip + "):
                    encrypted_string_addr = self.convert_rip_plus_offset_to_absolute(insn)
                    break

        if not encrypted_string_addr or not encrypted_string_length:
            print(
                "Could not find the encrypted string address or length",
                hex(self.convert_rva_to_ida(report_insn.address)),
            )
            return None, None

        # Read the encrypted string from the binary
        encrypted_string = self.data[
            encrypted_string_addr : encrypted_string_addr + encrypted_string_length
        ].decode(encoding="ascii", errors="ignore")
        if not encrypted_string.isascii() or not encrypted_string.isprintable():
            print("The encrypted string is not printable ASCII")
            return None, None

        # Decrypt the string using the hard-coded key
        try:
            key = base64.b64decode(self.encryption_key)
            encrypted_data = base64.b64decode(encrypted_string)
        except ValueError:
            # binascii.Error is a ValueError; one malformed candidate should
            # not abort the whole run.
            print("The encrypted string or key is not valid base64")
            return None, None
        decrypted_string = "".join([chr(x ^ y) for x, y in zip(encrypted_data, key)])

        # Find the global memory location of the decrypted string
        lookforward_buffer = self.instructions[point_of_interest : point_of_interest + 1024]
        # xor reg, reg
        # lea rsi, rip + offset
        # cmp reg, reg2
        for i in range(len(lookforward_buffer) - 2):
            following_lea = lookforward_buffer[i + 1]
            if (
                lookforward_buffer[i].mnemonic == "xor"
                and following_lea.mnemonic == "lea"
                and "[rip" in following_lea.op_str
                and lookforward_buffer[i + 2].mnemonic == "cmp"
            ):
                decrypted_string_addr = self.convert_rip_plus_offset_to_ida(following_lea)
                break

        # Include image base
        if decrypted_string_addr:
            print(
                f"Decrypted string at {hex(decrypted_string_addr)}: {decrypted_string}"
            )
            return decrypted_string_addr, decrypted_string
        print(
            f"Decrypted local string at {hex(self.convert_rva_to_ida(report_insn.address))}: {decrypted_string}"
        )
        return None, decrypted_string
# Driver: load the sample from disk, locate every call to the string
# decryption routine in .text and decrypt the string passed to each call.
# decrypt_string prints each result as it goes.
data = {}

with open(r"E:\2024-03-10 CargoBay\a963a8a8e1583081daa43638744eef6c410d1a410c11eb9413da15a26e802de5", "rb") as sample_f:
    sample_b = sample_f.read()

decryptor = CargoBayStringDecryptor(data=sample_b, section_name=".text")
points_of_interest = decryptor.find_points_of_interest()

for i in points_of_interest:
    # addr_optional is None when the decrypted string has no global location.
    addr_optional, value = decryptor.decrypt_string(i)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment