Last active
June 11, 2024 05:28
-
-
Save ttsuki/a55f32d0358c4dc55d345eec994a037c to your computer and use it in GitHub Desktop.
com_trace.h PoC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// @file | |
/// @brief xtw::com_trace - object tracer with vtable injection | |
/// @author (C) 2024 ttsuki | |
/// Distributed under the Boost Software License, Version 1.0. | |
#pragma once | |
#include <Windows.h> | |
#include <combaseapi.h> | |
#include <cassert> | |
#include <type_traits> | |
#include <memory> | |
#include <functional> | |
#include <mutex> | |
#include <unordered_map> | |
#include <utility> | |
namespace xtw::com_trace | |
{ | |
namespace detail | |
{ | |
struct IUnknown_vtable | |
{ | |
HRESULT (STDMETHODCALLTYPE*QueryInterface)(void* this_, REFIID riid, void** ppvObject); | |
ULONG (STDMETHODCALLTYPE*AddRef)(void* this_); | |
ULONG (STDMETHODCALLTYPE*Release)(void* this_); | |
}; | |
/// @brief Returns the vtable pointer stored in the first pointer-sized slot of a COM object.
/// @param object Interface pointer to inspect; must not be null (it is dereferenced).
/// @return Address of the object's vtable, viewed through its leading IUnknown slots.
static inline const IUnknown_vtable* get_vtable_from_object(IUnknown* object)
{
    IUnknown_vtable* const* vtable_slot = reinterpret_cast<IUnknown_vtable* const*>(object);
    return *vtable_slot;
}
/// @brief Overwrites the three IUnknown slots of a vtable in place.
/// @param target       Vtable to patch; its page protection is lifted temporarily.
/// @param replace_with New QueryInterface / AddRef / Release function pointers.
/// @return false if the page protection could not be changed (nothing is written then);
///         true once all three slots have been replaced.
static inline bool rewrite_vtable(IUnknown_vtable* target, IUnknown_vtable replace_with)
{
    DWORD saved_protection = {};
    const bool unlocked = ::VirtualProtect(target, sizeof(IUnknown_vtable), PAGE_EXECUTE_READWRITE, &saved_protection);
    if (!unlocked)
        return false;

    // Each slot is replaced individually; the handlers fall back to the saved
    // originals, so a caller observing a partially patched table still behaves correctly.
    target->QueryInterface = replace_with.QueryInterface;
    target->AddRef = replace_with.AddRef;
    target->Release = replace_with.Release;

    // Restore the previous protection (the result of this call is not checked).
    ::VirtualProtect(target, sizeof(IUnknown_vtable), saved_protection, &saved_protection);
    ::FlushInstructionCache(::GetCurrentProcess(), nullptr, 0);
    return true;
}
/// @brief Payload shared by every interface pointer linked to one traced object.
struct linked_data
{
    /// Called after a hooked Release() on a linked pointer returns 0; may be empty.
    std::function<void()> release_callback_function{};
};
template <class instance_tag> | |
struct hook_table | |
{ | |
static void trace_object(IUnknown* obj, linked_data data) | |
{ | |
init(); | |
install_hook(obj, std::make_shared<linked_data>(std::move(data))); | |
} | |
private: | |
static inline std::recursive_mutex* mutex_{}; | |
static inline std::unordered_map<const IUnknown_vtable*, IUnknown_vtable>* original_vtable_{}; | |
static inline std::unordered_map<IUnknown*, std::shared_ptr<linked_data>>* link_data_forward_{}; | |
static inline std::unordered_multimap<linked_data*, IUnknown*>* link_data_reverse_{}; | |
static void init() | |
{ | |
static auto i = [] | |
{ | |
mutex_ = new std::remove_pointer_t<decltype(mutex_)>(); | |
original_vtable_ = new std::remove_pointer_t<decltype(original_vtable_)>(); | |
link_data_forward_ = new std::remove_pointer_t<decltype(link_data_forward_)>(); | |
link_data_reverse_ = new std::remove_pointer_t<decltype(link_data_reverse_)>(); | |
return 1; | |
}(); | |
} | |
static void install_hook(IUnknown* target, std::shared_ptr<linked_data> data) | |
{ | |
std::lock_guard lock(*mutex_); | |
// Link data to IUnknown pointer. | |
if (target && data && !get_linked_data(target)) | |
{ | |
link_data_forward_->emplace(target, data); | |
link_data_reverse_->emplace(data.get(), target); | |
} | |
// Install IUnknown vtable hook. | |
auto vp = get_vtable_from_object(target); | |
auto it = original_vtable_->find(vp); | |
if (it == original_vtable_->end()) | |
{ | |
original_vtable_->insert({vp, *vp}); | |
detail::rewrite_vtable( | |
const_cast<IUnknown_vtable*>(vp), | |
{ | |
handle_QueryInterface, | |
handle_AddRef, | |
handle_Release, | |
}); | |
} | |
} | |
static IUnknown_vtable get_original_vtable(IUnknown* target) | |
{ | |
std::lock_guard lock(*mutex_); | |
auto vp = get_vtable_from_object(target); | |
auto it = original_vtable_->find(vp); | |
if (it != original_vtable_->end()) | |
return it->second; | |
return {}; | |
} | |
static std::shared_ptr<linked_data> get_linked_data(IUnknown* target) | |
{ | |
assert(target); | |
std::lock_guard lock(*mutex_); | |
auto it = link_data_forward_->find(target); | |
if (it != link_data_forward_->end()) | |
return it->second; | |
return nullptr; | |
} | |
static void erase_linked_data(std::shared_ptr<linked_data> data) | |
{ | |
std::lock_guard lock(*mutex_); | |
auto [beg, end] = link_data_reverse_->equal_range(data.get()); | |
for (auto it = beg; it != end; ++it) { link_data_forward_->erase(it->second); } | |
link_data_reverse_->erase(beg, end); | |
} | |
static HRESULT STDMETHODCALLTYPE handle_QueryInterface(void* self, const IID& riid, void** ppvObject) | |
{ | |
if (!ppvObject) return E_POINTER; | |
IUnknown* this_ = reinterpret_cast<IUnknown*>(self); | |
IUnknown_vtable orig = get_original_vtable(this_); | |
if (orig.QueryInterface) | |
{ | |
HRESULT hr = orig.QueryInterface(this_, riid, ppvObject); | |
if (SUCCEEDED(hr)) | |
{ | |
IUnknown* that_ = static_cast<IUnknown*>(*ppvObject); | |
if (that_ && that_ != this_) | |
{ | |
install_hook(that_, get_linked_data(this_)); | |
} | |
} | |
return hr; | |
} | |
else | |
{ | |
assert(false); // critical error | |
return E_FAIL; | |
} | |
} | |
static ULONG STDMETHODCALLTYPE handle_AddRef(void* self) | |
{ | |
IUnknown* this_ = reinterpret_cast<IUnknown*>(self); | |
IUnknown_vtable orig = get_original_vtable(this_); | |
if (orig.AddRef) | |
{ | |
ULONG r = orig.AddRef(this_); | |
return r; | |
} | |
else | |
{ | |
assert(false); // critical error | |
return 0; | |
} | |
} | |
static ULONG STDMETHODCALLTYPE handle_Release(void* self) | |
{ | |
IUnknown* this_ = reinterpret_cast<IUnknown*>(self); | |
IUnknown_vtable orig = get_original_vtable(this_); | |
if (orig.Release) | |
{ | |
ULONG r = orig.Release(this_); | |
if (r == 0) | |
{ | |
if (std::shared_ptr<linked_data> data = get_linked_data(this_)) | |
{ | |
erase_linked_data(data); | |
if (data->release_callback_function) | |
data->release_callback_function(); | |
} | |
} | |
return r; | |
} | |
else | |
{ | |
assert(false); // critical error | |
return 0; | |
} | |
} | |
}; | |
} | |
template <class instance_tag = void> | |
static inline void register_release_callback( | |
IUnknown* obj, | |
std::function<void()> on_release) | |
{ | |
detail::hook_table<instance_tag>::trace_object( | |
obj, | |
detail::linked_data{ | |
std::move(on_release) | |
}); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment