Skip to content

Instantly share code, notes, and snippets.

@JohnScience
Created April 4, 2024 06:33
Show Gist options
  • Save JohnScience/cbb5a59453c2b784c9e3b6d31691068b to your computer and use it in GitHub Desktop.
Save JohnScience/cbb5a59453c2b784c9e3b6d31691068b to your computer and use it in GitHub Desktop.
An implementation of a reflective PE loader, with some values hard-coded for an example DLL. It is meant to be turned into a library soon.
use core::panic;
use core::ptr;
use goblin::pe::data_directories::DataDirectory;
use goblin::pe::section_table::{IMAGE_SCN_MEM_EXECUTE, IMAGE_SCN_MEM_READ, IMAGE_SCN_MEM_WRITE};
use goblin::pe::PE;
use winapi::ctypes::c_void;
use winapi::shared::minwindef::HINSTANCE;
use winapi::um::memoryapi::VirtualAlloc;
use winapi::um::winnt::{
DLL_THREAD_ATTACH, MEM_COMMIT, MEM_RESERVE, PAGE_EXECUTE, PAGE_EXECUTE_READ,
PAGE_EXECUTE_READWRITE, PAGE_EXECUTE_WRITECOPY, PAGE_NOACCESS, PAGE_READONLY, PAGE_READWRITE,
};
mod windows {
use core::mem::MaybeUninit;
use winapi::ctypes::c_void;
use winapi::shared::minwindef::{BOOL, DWORD, HINSTANCE, LPVOID};
use winapi::um::winnt::MEMORY_BASIC_INFORMATION;
pub(crate) type DllEntryProc =
unsafe extern "system" fn(hinstDLL: HINSTANCE, fdwReason: DWORD, lpReserved: LPVOID) -> BOOL;
pub(crate) const IMAGE_SIZEOF_BASE_RELOCATION: usize = 8;
#[derive(Debug, thiserror::Error)]
#[error("Unsupported relocation type: {0}")]
pub(crate) struct UnsupportedRelocationType(u8);
#[derive(Debug, thiserror::Error)]
pub(crate) enum VirtualQueryError {
#[error("VirtualQuery failed: {0}")]
IoError(#[from] std::io::Error),
#[error("Written unexpected number of bytes: {0}")]
UnexpectedBytesCountWritten(usize),
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum BaseRelocationType {
#[doc(alias = "IMAGE_REL_BASED_ABSOLUTE")]
Absolute = 0,
#[doc(alias = "IMAGE_REL_BASED_HIGHLOW")]
HighLow = 3,
#[doc(alias = "IMAGE_REL_BASED_DIR64")]
Dir64 = 10,
}
#[repr(transparent)]
pub(crate) struct BaseRelocationEntry(u16);
#[doc(alias = "IMAGE_BASE_RELOCATION")]
#[derive(Debug)]
#[repr(C)]
pub(crate) struct BaseRelocationBlock {
pub(crate) virtual_address: u32,
pub(crate) size_of_block: u32,
}
impl std::fmt::Debug for BaseRelocationEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(stringify!(BaseRelocationEntry))
.field(
"base_relocation_type_nibble",
&self.base_relocation_type_nibble(),
)
// .field("base_relocation_type", &self.base_relocation_type())
.field("va_offset", &self.va_offset())
.finish()
}
}
impl BaseRelocationEntry {
fn base_relocation_type_nibble(&self) -> u8 {
(self.0 >> 12) as u8
}
pub(crate) fn base_relocation_type(
&self,
) -> Result<BaseRelocationType, UnsupportedRelocationType> {
let base_relocation_type_nibble = self.base_relocation_type_nibble();
let base_relocation_type = match base_relocation_type_nibble {
0 => BaseRelocationType::Absolute,
3 => BaseRelocationType::HighLow,
10 => BaseRelocationType::Dir64,
_ => return Err(UnsupportedRelocationType(base_relocation_type_nibble)),
};
Ok(base_relocation_type)
}
// The offset from the virtal address of the IMAGE_BASE_RELOCATION structure
pub(crate) fn va_offset(&self) -> u16 {
self.0 & 0x0FFF
}
pub(crate) fn perform_single_relocation(&self, dest: *mut c_void, delta: isize) {
let relocation_type = self.base_relocation_type().unwrap();
let offset = self.va_offset();
match relocation_type {
BaseRelocationType::Absolute => {
// Skip
}
BaseRelocationType::HighLow => {
let dest = unsafe { dest.byte_add(offset as usize) } as *mut u32;
let value = unsafe { dest.read() };
unsafe {
dest.write(value.wrapping_add(delta as u32));
}
}
BaseRelocationType::Dir64 => {
let dest = unsafe { dest.byte_add(offset as usize) } as *mut u64;
let value = unsafe { dest.read() };
unsafe {
dest.write(value.wrapping_add(delta as u64));
}
}
}
}
}
/// Converts a relative virtual address (RVA) to a virtual address (VA) in the image.
pub(crate) fn rva_to_va(image_base: *mut c_void, rva: u32) -> *mut c_void {
(image_base as usize + rva as usize) as *mut c_void
}
pub(crate) fn section_size(scn: &goblin::pe::section_table::SectionTable) -> usize {
scn.size_of_raw_data as usize
}
pub(crate) fn section_file_ptr_range(
scn: &goblin::pe::section_table::SectionTable,
) -> std::ops::Range<usize> {
scn.pointer_to_raw_data as usize..scn.pointer_to_raw_data as usize + section_size(scn) as usize
}
pub(crate) fn section_va_range(
image_base: *mut c_void,
scn: &goblin::pe::section_table::SectionTable,
) -> std::ops::Range<usize> {
rva_to_va(image_base, scn.virtual_address) as usize
..rva_to_va(image_base, scn.virtual_address) as usize + section_size(scn)
}
pub(crate) fn virtual_protect(
address: *mut c_void,
size: usize,
perms: u32, // old perms is not used
) -> Result<(), std::io::Error> {
let mut old_perms = 0;
let result =
unsafe { winapi::um::memoryapi::VirtualProtect(address, size, perms, &mut old_perms) };
if result == 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(())
}
}
pub(crate) fn virtual_query(
ptr: *const c_void,
) -> Result<MEMORY_BASIC_INFORMATION, VirtualQueryError> {
let mut info: MaybeUninit<MEMORY_BASIC_INFORMATION> = MaybeUninit::uninit();
let bytes_written = unsafe {
winapi::um::memoryapi::VirtualQuery(
ptr,
info.as_mut_ptr(),
core::mem::size_of::<MEMORY_BASIC_INFORMATION>(),
)
};
if bytes_written == 0 {
return Err(VirtualQueryError::IoError(std::io::Error::last_os_error()));
};
// technically, this should be an error too
if bytes_written != core::mem::size_of::<MEMORY_BASIC_INFORMATION>() {
return Err(VirtualQueryError::UnexpectedBytesCountWritten(
bytes_written,
));
}
let info = unsafe { info.assume_init() };
Ok(info)
}
}
use windows::*;
/// Expands to a `&String` label of the form `fn_ptr#<n>` for a function pointer with no
/// known name. `<n>` is the current value of the counter variable passed in, and the
/// counter is incremented afterwards so the next label differs.
///
/// The label is a temporary, so use the expansion directly as an argument (for example,
/// for a `&str` parameter via deref coercion).
macro_rules! anonymous_fn_ptr {
    ($counter:ident) => {
        &{
            let label = format!("fn_ptr#{}", $counter);
            $counter += 1;
            label
        }
    };
}
// struct PeDll {
// image_base: *mut c_void,
// }
// #[derive(Debug, thiserror::Error)]
// enum PeDllLoadError {
// #[error("PE parsing error: {0}")]
// PeParsingError(#[from] goblin::error::Error),
// #[error("The file is in PE format but is not a DLL")]
// PeButNotDll,
// #[error("Memory allocation error. {0}")]
// MemoryAllocationError(MemoryAllocationError),
// }
// #[derive(Debug, thiserror::Error)]
// #[error("Failed to allocate memory for {dll_name}: {inner}.")]
// struct MemoryAllocationError {
// inner: std::io::Error,
// dll_name: String,
// }
// impl PeDll {
// // Allocate memory for the image
// fn allocate_memory(pe: &PE) -> Result<ptr::NonNull<c_void>, MemoryAllocationError> {
// let image_size: usize = pe
// .header
// .optional_header
// .unwrap()
// .windows_fields
// .size_of_image as usize;
// let preferred_base: *mut c_void = {
// let preferred_base: u64 = pe.header.optional_header.unwrap().windows_fields.image_base;
// preferred_base as *mut c_void
// };
// let image_base: *mut c_void =
// unsafe { VirtualAlloc(preferred_base, image_size, MEM_RESERVE, PAGE_READWRITE) };
// match ptr::NonNull::new(image_base) {
// Some(image_base) => Ok(image_base),
// None => Err(MemoryAllocationError {
// inner: std::io::Error::last_os_error(),
// dll_name: pe.name.unwrap_or_else(|| "the DLL").to_string(),
// }),
// }
// }
// pub fn new(bytes: &[u8]) -> Result<Self, PeDllLoadError> {
// let pe = PE::parse(&bytes)?;
// if !pe.is_lib {
// return Err(PeDllLoadError::PeButNotDll);
// }
// todo!()
// }
// }
// The code below is written based on <https://www.joachim-bauch.de/tutorials/loading-a-dll-from-memory/>
// Note (JohnScience): the code assumes that the DLL is already in memory, so why not change the memory
// protections appropriately and modify the DLL in memory?
// See https://github.com/HotKeyIt/ahkdll/blob/818386f5af7e6000d945801838d4e80a9e530c0d/source/MemoryModule.cpp#L476
/// Applies the image's base relocations, so that the absolute addresses embedded in the
/// mapped sections point into the image as actually loaded at `image_base`.
///
/// `delta` is `actual base - preferred base`. When it is zero the image sits at its
/// preferred base and nothing needs patching.
///
/// The relocation table is walked block by block, bounded by the size recorded in the
/// data directory. A zero-sized block header also ends the walk.
///
/// # Panics
/// - if the image has no optional header;
/// - if `delta != 0` but the image has no (or an empty) base relocation table;
/// - if a relocation block is malformed, or an entry has an unsupported type.
fn perform_base_relocation(pe: &PE, image_base: *mut c_void, delta: isize) {
    if delta == 0 {
        return;
    }
    let optional_header = pe
        .header
        .optional_header
        .expect("a loadable image has an optional header");
    let (table_rva, table_size) = match optional_header
        .data_directories
        .get_base_relocation_table()
    {
        Some(directory) => {
            let DataDirectory {
                virtual_address,
                size,
            } = *directory;
            (virtual_address, size as usize)
        }
        None => (0, 0),
    };
    assert!(
        table_rva != 0 && table_size != 0,
        "The image must be relocated (delta {delta:x}) but has no base relocation table"
    );

    let table = rva_to_va(image_base, table_rva) as *const u8;
    let mut offset = 0usize;
    while offset + IMAGE_SIZEOF_BASE_RELOCATION <= table_size {
        // SAFETY: the header lies within the relocation directory, which is inside the
        // image whose sections have already been copied into place.
        let block = unsafe { (table.add(offset) as *const BaseRelocationBlock).read_unaligned() };
        let block_size = block.size_of_block as usize;
        if block_size == 0 {
            // A zeroed header is treated as the end of the table.
            break;
        }
        assert!(
            block_size >= IMAGE_SIZEOF_BASE_RELOCATION && offset + block_size <= table_size,
            "Malformed base relocation block at table offset {offset:#x}: {block:?}"
        );
        let dest = rva_to_va(image_base, block.virtual_address);
        let entry_count = (block_size - IMAGE_SIZEOF_BASE_RELOCATION)
            / core::mem::size_of::<BaseRelocationEntry>();
        // SAFETY: the entries follow the header and end within the block (checked above).
        let entries = unsafe { table.add(offset + IMAGE_SIZEOF_BASE_RELOCATION) }
            as *const BaseRelocationEntry;
        for i in 0..entry_count {
            // SAFETY: `i < entry_count`, so the entry lies inside the current block.
            let entry = unsafe { entries.add(i).read_unaligned() };
            entry.perform_single_relocation(dest, delta);
        }
        // Q: do we really need to flush the instruction cache here as done in
        // https://github.com/HotKeyIt/ahkdll/blob/818386f5af7e6000d945801838d4e80a9e530c0d/source/MemoryModule.cpp#L527
        // FlushInstructionCache(GetCurrentProcess(), dest, module->pageSize);
        offset += block_size;
    }
}
/// Commits memory for every section inside the image reserved at `image_base`, and copies
/// each section's raw data from the file bytes `bytes`.
///
/// A section occupies `max(VirtualSize, SizeOfRawData)` bytes in memory. Only the raw data
/// is copied. The remainder (e.g. `.bss`, or the uninitialized tail of `.data`) reads as
/// zero, because `VirtualAlloc` zero-initializes the memory it commits. Committed pages
/// start out as `PAGE_READWRITE`.
///
/// # Panics
/// - if committing a section's memory fails or lands at an unexpected address;
/// - if a section's raw data lies outside `bytes`;
/// - if a section name is not valid UTF-8.
fn copy_sections(pe: &PE, image_base: *mut c_void, bytes: &[u8]) {
    for section in pe.sections.iter() {
        let raw_size: usize = section_size(section);
        let memory_size = raw_size.max(section.virtual_size as usize);
        if memory_size == 0 {
            continue;
        }
        let name = section.name().unwrap();
        let section_start = section_va_range(image_base, section).start;
        // SAFETY: the range lies inside the image reservation made by the caller; this call
        // commits already reserved memory.
        let section_dest: *mut u8 = unsafe {
            VirtualAlloc(
                section_start as *mut c_void,
                memory_size,
                MEM_COMMIT,
                PAGE_READWRITE,
            )
        } as *mut u8;
        assert!(
            !section_dest.is_null(),
            "Failed to commit memory for section {name:?}: {}",
            std::io::Error::last_os_error()
        );
        assert!(section_dest as usize == section_start);
        let file_ptr_range = section_file_ptr_range(section);
        let section_data: &[u8] = bytes.get(file_ptr_range.clone()).unwrap_or_else(|| {
            panic!(
                "Raw data of section {name:?} ({:x}h..{:x}h) lies outside the file ({:x}h bytes)",
                file_ptr_range.start,
                file_ptr_range.end,
                bytes.len()
            )
        });
        println!(
            "Copying section {:?} (size in memory: {:x}h, in file: {:x}h) to {:p} through {:p} (file offsets {:x}h to {:x}h)",
            name,
            memory_size,
            raw_size,
            section_dest,
            section_dest.wrapping_add(memory_size),
            file_ptr_range.start,
            file_ptr_range.end
        );
        // SAFETY: `section_dest` has `memory_size >= section_data.len()` committed, writable
        // bytes, and the source is a separate heap buffer.
        unsafe {
            core::ptr::copy_nonoverlapping(section_data.as_ptr(), section_dest, section_data.len());
        }
    }
}
// TODO: learn more about implicit and explicit linking, especially about delayed-loaded DLL option for implicit linking
/// Resolves the image's imports and fills in its Import Address Table (IAT).
///
/// For each import descriptor:
/// - the dependency DLL is loaded with `LoadLibraryA`;
/// - every entry of its Import Lookup Table (ILT) is resolved with `GetProcAddress` on
///   that DLL, by name or by ordinal;
/// - the resulting address is written into the matching IAT slot.
///
/// If the ILT RVA is zero, the IAT itself is used as the lookup table.
///
/// ILT/IAT entries are pointer-sized (4 bytes in PE32, 8 in PE32+). They are handled here
/// as `usize`, so the image's bitness must match the current process.
///
/// # Panics
/// - if the image's bitness differs from the current process;
/// - if a dependency cannot be loaded, or a symbol cannot be resolved;
/// - if a DLL name contains a NUL byte.
// TODO: learn more about implicit and explicit linking, especially about delayed-loaded DLL option for implicit linking
fn resolve_imports(pe: &PE, image_base: *mut c_void) {
    assert_eq!(
        pe.is_64,
        core::mem::size_of::<usize>() == 8,
        "the image's bitness does not match the current process"
    );
    let Some(import_data) = pe.import_data.as_ref() else {
        println!("The DLL has no imports");
        return;
    };
    // Top bit of a lookup entry: import by ordinal (IMAGE_ORDINAL_FLAG32 / IMAGE_ORDINAL_FLAG64).
    const ORDINAL_FLAG: usize = 1 << (usize::BITS - 1);
    for synthetic_import_directory_entry in import_data.import_data.iter() {
        let import_directory_entry = &synthetic_import_directory_entry.import_directory_entry;
        let dll_name = synthetic_import_directory_entry.name;
        println!("Importing: {}", dll_name);
        let dll_name_c = std::ffi::CString::new(dll_name)
            .unwrap_or_else(|_| panic!("DLL name {dll_name:?} contains a NUL byte"));
        // SAFETY: `dll_name_c` is a NUL-terminated string that outlives the call.
        let module = unsafe { winapi::um::libloaderapi::LoadLibraryA(dll_name_c.as_ptr()) };
        assert!(
            !module.is_null(),
            "Failed to load {dll_name}: {}",
            std::io::Error::last_os_error()
        );
        let lookup_table_rva = match import_directory_entry.import_lookup_table_rva {
            0 => import_directory_entry.import_address_table_rva,
            rva => rva,
        };
        let import_lookup_table = rva_to_va(image_base, lookup_table_rva) as *const usize;
        let import_address_table: *mut *mut c_void =
            rva_to_va(image_base, import_directory_entry.import_address_table_rva)
                as *mut *mut c_void;
        for i in 0.. {
            // SAFETY: the lookup table lies in the mapped image and ends with a zero entry.
            let import_lookup_entry = unsafe { import_lookup_table.add(i).read() };
            if import_lookup_entry == 0 {
                break;
            }
            let (import_address, symbol) = if import_lookup_entry & ORDINAL_FLAG != 0 {
                let ordinal = (import_lookup_entry & 0xFFFF) as u16;
                // The ordinal is passed in place of the name pointer (MAKEINTRESOURCEA).
                let address = unsafe {
                    winapi::um::libloaderapi::GetProcAddress(
                        module,
                        usize::from(ordinal) as *const i8,
                    )
                };
                (address, format!("ordinal #{ordinal}"))
            } else {
                // Bits 30..0 hold the RVA of an IMAGE_IMPORT_BY_NAME: a u16 hint followed by
                // the NUL-terminated name.
                let hint_name_rva = (import_lookup_entry & 0x7FFF_FFFF) as u32;
                // SAFETY: the name lies in the mapped image and is NUL-terminated.
                let import_name = unsafe {
                    core::ffi::CStr::from_ptr(rva_to_va(image_base, hint_name_rva + 2) as *const i8)
                };
                let address = unsafe {
                    winapi::um::libloaderapi::GetProcAddress(module, import_name.as_ptr())
                };
                (address, import_name.to_string_lossy().into_owned())
            };
            println!("  Importing: {}", symbol);
            assert!(
                !import_address.is_null(),
                "Failed to resolve {symbol} from {dll_name}: {}",
                std::io::Error::last_os_error()
            );
            // SAFETY: the IAT runs parallel to the lookup table, so slot `i` exists.
            unsafe {
                import_address_table.add(i).write(import_address as *mut c_void);
            }
        }
    }
}
fn protect_memory(pe: &PE, image_base: *mut c_void) {
// We sort the sections by their privileges to avoid depriving pages of their privileges.
// Even though sections themselves do not overlap, their pages might.
let mut sections: Vec<_> = pe
.sections
.iter()
.filter(|section| section_size(section) > 0)
.collect();
sections.sort_by(|a, b| {
// pv stands for "privilege value"
let [pv_a, pv_b] = [a, b]
.map(|scn| {
let r: bool = scn.characteristics & IMAGE_SCN_MEM_READ != 0;
let w: bool = scn.characteristics & IMAGE_SCN_MEM_WRITE != 0;
let e: bool = scn.characteristics & IMAGE_SCN_MEM_EXECUTE != 0;
[r, w, e]
})
.map(|privileges| {
privileges
.iter()
.enumerate()
.map(|(i, &p)| if p { 1 << (i + 1) } else { 0 })
.sum::<u8>()
});
pv_a.cmp(&pv_b)
});
for section in sections.iter() {
let r: bool = section.characteristics & IMAGE_SCN_MEM_READ != 0;
let w: bool = section.characteristics & IMAGE_SCN_MEM_WRITE != 0;
let e: bool = section.characteristics & IMAGE_SCN_MEM_EXECUTE != 0;
let section_size = section_size(section);
let scn_name = section.name().unwrap();
println!("Characteristics: {:x}", section.characteristics);
println!(
"Protecting section {:?} (size: {:x}h) with Read: {}, Write: {}, Execute: {}",
scn_name, section_size, r, w, e
);
// TODO: account for other section characteristics
let perms = match (r, w, e) {
(false, false, false) => PAGE_NOACCESS,
(false, false, true) => PAGE_EXECUTE,
(false, true, false) => panic!("Invalid section permissions"),
(false, true, true) => PAGE_EXECUTE_WRITECOPY,
(true, false, false) => PAGE_READONLY,
(true, false, true) => PAGE_EXECUTE_READ,
(true, true, false) => PAGE_READWRITE,
(true, true, true) => PAGE_EXECUTE_READWRITE,
};
let section_va_range = section_va_range(image_base, section);
let section_va = rva_to_va(image_base, section.virtual_address);
windows::virtual_protect(section_va, section_size, perms).unwrap();
// TODO: remove once the code is stable
if cfg!(debug_assertions) && scn_name == ".text" {
check_memory_protection(section_va as *const c_void);
check_memory_protection(0x1800170c0 as *const c_void);
check_memory_protection(section_va_range.end as *const c_void);
}
}
}
fn check_memory_protection(ptr: *const c_void) {
let info = virtual_query(ptr).unwrap();
let protection = match info.Protect {
PAGE_EXECUTE_READ => stringify!(PAGE_EXECUTE_READ),
PAGE_EXECUTE_READWRITE => stringify!(PAGE_EXECUTE_READWRITE),
PAGE_EXECUTE_WRITECOPY => stringify!(PAGE_EXECUTE_WRITECOPY),
PAGE_READONLY => stringify!(PAGE_READONLY),
PAGE_READWRITE => stringify!(PAGE_READWRITE),
_ => panic!("Unknown memory protection"),
};
print!(
"Memory protection at {ptr:p}: {protection} ({:x}). Base address: {:p}. Size: {:x}h\n\n",
info.Protect, info.BaseAddress, info.RegionSize,
);
}
/// Prints an x86-64 disassembly of the code at `fn_ptr`. It stops after the first `ret`,
/// or when the containing section's raw data runs out.
///
/// * `name` — label printed in the header line.
/// * `comments` — `(byte offset from fn_ptr, text)` pairs. Each is printed as a `; text`
///   line just before the instruction that starts at that offset.
///
/// The decoder runs with IP 0, so iced reports branch targets and RIP-relative addresses
/// relative to `fn_ptr`. The `; Target:` lines rebase them onto `fn_ptr`; the instruction
/// text itself still shows the relative values.
///
/// # Panics
/// - if `fn_ptr` is not inside any section's raw-data range;
/// - if the section name is not valid UTF-8;
/// - if an invalid instruction is decoded.
fn print_fn_ptr_disassembly(
    pe: &PE,
    image_base: *mut c_void,
    fn_ptr: *const c_void,
    name: &str,
    comments: &'static [(usize, &'static str)],
) {
    let scn = pe
        .sections
        .iter()
        .find(|scn| section_va_range(image_base, scn).contains(&(fn_ptr as usize)))
        .unwrap();
    println!("Disassembly for memory at {fn_ptr:p} ({name}) function pointer located in {scn_name} section:",
        scn_name = scn.name().unwrap(),
    );
    let scn_va_range = section_va_range(image_base, scn);
    let possible_machine_code_range_size = scn_va_range.end - fn_ptr as usize;
    // SAFETY: `[fn_ptr, section end)` lies within the section's committed raw-data range.
    let machine_code = unsafe {
        std::slice::from_raw_parts(fn_ptr as *const u8, possible_machine_code_range_size)
    };
    let mut decoder = iced_x86::Decoder::new(64, machine_code, iced_x86::DecoderOptions::NONE);
    let base = fn_ptr as u64;
    let mut offset = 0;
    while decoder.can_decode() {
        let instr = decoder.decode();
        if instr.is_invalid() {
            panic!("Invalid instruction");
        }
        let ptr = machine_code.as_ptr().wrapping_add(offset);
        if let Some(comment) = comments
            .iter()
            .find(|(comment_offset, _)| *comment_offset == offset)
            .map(|(_, comment)| *comment)
        {
            println!("\t; {comment}");
        }
        for op_kind in instr.op_kinds() {
            if matches!(
                op_kind,
                iced_x86::OpKind::NearBranch16
                    | iced_x86::OpKind::NearBranch32
                    | iced_x86::OpKind::NearBranch64
            ) {
                let target = base.wrapping_add(instr.near_branch_target());
                println!("\t; Target: {:#x}", target);
            }
        }
        // An indirect `call r/m64` only has a statically known address when its memory
        // operand is RIP-relative (e.g. `call [rip+disp]` through an IAT slot). The printed
        // address is the pointer slot the call reads. For `call rax` or `call [rax+8]`
        // nothing can be computed here.
        if instr.code() == iced_x86::Code::Call_rm64 && instr.is_ip_rel_memory_operand() {
            let target = base.wrapping_add(instr.ip_rel_memory_address());
            println!("\t; Target: {:#x}", target);
        }
        println!("\t{ptr:p} (offset {offset}): {instr}");
        offset += instr.len();
        if instr.mnemonic() == iced_x86::Mnemonic::Ret {
            break;
        }
    }
    println!();
}
/// Calls the image's entry point (`DllMain`) after printing debugging information about it.
///
/// If `AddressOfEntryPoint` is 0, the image has no entry point: a message is printed and
/// the function returns without calling anything.
///
/// Before the call it prints:
/// - the entry point's page protection and disassembly;
/// - the disassembly of one more function at a hard-coded address.
///
/// That address, and the asserts below, are specific to `hello_world_lib.dll` mapped at
/// `0x180000000`.
///
/// # Panics
/// - if the image has no optional header;
/// - if the image is not `hello_world_lib.dll` mapped at `0x180000000` (hard-coded asserts);
/// - if one of the debugging helpers it calls panics.
fn notify_dll(pe: &PE, image_base: *mut c_void) {
let entrypoint: DllEntryProc = {
let address_of_entrypoint = pe
.header
.optional_header
.unwrap()
.standard_fields
.address_of_entry_point;
if address_of_entrypoint == 0 {
println!("The DLL has no entrypoint");
return;
}
// The entry point is an RVA; add it to the actual image base.
let entrypoint_va: *const c_void =
(image_base as usize + address_of_entrypoint as usize) as *const c_void;
println!("Entrypoint VA: {:p}", entrypoint_va);
// Reinterprets the address as a DllEntryProc. This assumes the entry point really has
// the DllMain signature.
unsafe { core::mem::transmute(entrypoint_va) }
};
check_memory_protection(entrypoint as *const c_void);
// The offsets in the comment tables are byte offsets from the start of the function;
// the notes describe what the author observed while debugging the example DLL.
print_fn_ptr_disassembly(
pe,
image_base,
entrypoint as *const c_void,
"Entrypoint",
&[(26, "The jump doesn't happen")],
);
// Hard-coded debugging asserts: the addresses below are only meaningful for this DLL at this base.
assert!(image_base as usize == 0x0000000180000000);
assert!(pe.name == Some("hello_world_lib.dll"));
let mut fn_ptr_i = 0;
// TODO: follow calls to the function pointers automatically
print_fn_ptr_disassembly(
pe,
image_base,
0x180014e8c as *const c_void,
anonymous_fn_ptr!(fn_ptr_i),
&[(33, "The jump doesn't happen")],
);
// print_fn_ptr_disassembly(
// pe,
// image_base,
// 0x1800170c0 as *const c_void,
// anonymous_fn_ptr!(fn_ptr_i),
// &[],
// );
println!("Calling entrypoint at {:p}", entrypoint);
// https://learn.microsoft.com/en-us/windows/win32/dlls/dllmain#parameters
// https://learn.microsoft.com/en-us/windows/win32/dlls/dllmain
// DLL_PROCESS_ATTACH doesn't actually work
// NOTE(review): the entry point is called with DLL_THREAD_ATTACH rather than
// DLL_PROCESS_ATTACH, and its BOOL result is ignored. If the DLL's process-attach
// initialization calls imported functions, the failure may come from import resolution
// rather than from this call. Verify before switching the reason code.
unsafe {
entrypoint(
image_base as HINSTANCE,
DLL_THREAD_ATTACH,
std::ptr::null_mut(),
);
}
}
/// Loads the example DLL from memory and calls its exported `add` function.
///
/// The DLL is read from the hard-coded build path, or from a path typed on stdin if that
/// read fails. It is then mapped by hand:
/// 1. reserve the image;
/// 2. copy sections;
/// 3. apply base relocations;
/// 4. resolve imports;
/// 5. set page protections;
/// 6. call the entry point.
///
/// Finally, all exports are listed and `add(1, 2)` is called.
///
/// # Panics
/// On any I/O, parsing or loading failure (this is a demonstration driver).
fn main() {
    let bytes = match std::fs::read("../hello_world_lib/target/release/hello_world_lib.dll") {
        Ok(bytes) => bytes,
        Err(_) => {
            println!("Enter the path to the DLL file: ");
            let mut input: String = String::new();
            std::io::stdin().read_line(&mut input).unwrap();
            // `read_line` keeps the line terminator ("\r\n" on Windows), which is not part
            // of the path.
            std::fs::read(input.trim()).unwrap()
        }
    };
    let pe = PE::parse(&bytes).unwrap();
    let optional_header = pe
        .header
        .optional_header
        .expect("a loadable image has an optional header");
    let image_size: usize = optional_header.windows_fields.size_of_image as usize;
    let preferred_base: u64 = optional_header.windows_fields.image_base;
    // Reserve address space for the whole image, preferably at its preferred base (then no
    // relocation is needed). If that range is unavailable, let the system choose an address;
    // base relocation then fixes up the image.
    let mut image_base: *mut c_void = unsafe {
        VirtualAlloc(
            preferred_base as *mut c_void,
            image_size,
            MEM_RESERVE,
            PAGE_READWRITE,
        )
    };
    if image_base.is_null() {
        println!("Preferred base {preferred_base:#x} is unavailable; letting the system choose");
        image_base = unsafe {
            VirtualAlloc(
                std::ptr::null_mut(),
                image_size,
                MEM_RESERVE,
                PAGE_READWRITE,
            )
        };
    }
    assert!(
        !image_base.is_null(),
        "Failed to reserve {image_size:#x} bytes for the image: {}",
        std::io::Error::last_os_error()
    );
    println!();
    copy_sections(&pe, image_base, &bytes);
    println!();
    let delta: isize = image_base as isize - preferred_base as isize;
    perform_base_relocation(&pe, image_base, delta);
    println!("Relocated image to {image_base:p} with delta {delta:x}");
    print!("\nEntering import resolution phase\n");
    resolve_imports(&pe, image_base);
    println!();
    protect_memory(&pe, image_base);
    println!();
    // call entrypoint (= DLLMain) to notify the DLL that it has been loaded
    notify_dll(&pe, image_base);
    println!("DLL loaded successfully");
    for export in pe.exports.iter() {
        println!("{:?}", export);
    }
    let add = pe
        .exports
        .iter()
        .find(|export| export.name == Some("add"))
        .unwrap();
    let add = (image_base as usize + add.rva as usize) as *const c_void;
    // NOTE(review): the signature is assumed from the example DLL; nothing here verifies it.
    let add: extern "C" fn(i32, i32) -> i32 = unsafe { core::mem::transmute(add) };
    println!("add(1, 2) = {}", add(1, 2));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment