Skip to content

Instantly share code, notes, and snippets.

@ttsuki
Last active June 11, 2024 05:28
Show Gist options
  • Save ttsuki/a55f32d0358c4dc55d345eec994a037c to your computer and use it in GitHub Desktop.
Save ttsuki/a55f32d0358c4dc55d345eec994a037c to your computer and use it in GitHub Desktop.
com_trace.h PoC
/// @file
/// @brief xtw::com_trace - object tracer with vtable injection
/// @author (C) 2024 ttsuki
/// Distributed under the Boost Software License, Version 1.0.
#pragma once
#include <Windows.h>
#include <combaseapi.h>
#include <cassert>
#include <type_traits>
#include <memory>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <utility>
namespace xtw::com_trace
{
namespace detail
{
struct IUnknown_vtable
{
HRESULT (STDMETHODCALLTYPE*QueryInterface)(void* this_, REFIID riid, void** ppvObject);
ULONG (STDMETHODCALLTYPE*AddRef)(void* this_);
ULONG (STDMETHODCALLTYPE*Release)(void* this_);
};
/// Reads the vtable pointer stored in the first pointer-sized slot of a COM interface pointer.
/// @param object Interface pointer; must be non-null because it is dereferenced.
/// @return The object's vtable, viewed as its leading IUnknown slots.
static inline const IUnknown_vtable* get_vtable_from_object(IUnknown* object)
{
    IUnknown_vtable* const* vtable_slot = reinterpret_cast<IUnknown_vtable* const*>(object);
    const IUnknown_vtable* vtable = *vtable_slot;
    return vtable;
}
/// Overwrites the three IUnknown slots of `target` in place with those of `replace_with`.
/// @param target       Vtable to patch (normally in read-only memory, hence the VirtualProtect dance).
/// @param replace_with New QueryInterface/AddRef/Release entries.
/// @return true if the page protection could be lifted and the slots were written; false otherwise
///         (in which case nothing was modified).
static inline bool rewrite_vtable(IUnknown_vtable* target, IUnknown_vtable replace_with)
{
    DWORD previous_protection = {};
    // PAGE_EXECUTE_READWRITE rather than PAGE_READWRITE: the vtable may share a page with code
    // that is executing concurrently, so the page must stay executable while it is writable.
    if (!::VirtualProtect(target, sizeof(IUnknown_vtable), PAGE_EXECUTE_READWRITE, &previous_protection))
        return false;

    // Slot-by-slot stores (not a struct copy) so each pointer is written with a single store.
    target->QueryInterface = replace_with.QueryInterface;
    target->AddRef = replace_with.AddRef;
    target->Release = replace_with.Release;

    ::VirtualProtect(target, sizeof(IUnknown_vtable), previous_protection, &previous_protection);
    ::FlushInstructionCache(::GetCurrentProcess(), nullptr, 0);
    return true;
}
/// Payload attached to a traced COM object (shared by all of its traced interface pointers).
struct linked_data
{
    /// Signature of the notification fired when the traced object is released for the last time.
    using callback_type = std::function<void()>;

    /// Invoked, if non-empty, after the traced object's Release returns 0.
    callback_type release_callback_function;
};
template <class instance_tag>
struct hook_table
{
static void trace_object(IUnknown* obj, linked_data data)
{
init();
install_hook(obj, std::make_shared<linked_data>(std::move(data)));
}
private:
static inline std::recursive_mutex* mutex_{};
static inline std::unordered_map<const IUnknown_vtable*, IUnknown_vtable>* original_vtable_{};
static inline std::unordered_map<IUnknown*, std::shared_ptr<linked_data>>* link_data_forward_{};
static inline std::unordered_multimap<linked_data*, IUnknown*>* link_data_reverse_{};
static void init()
{
static auto i = []
{
mutex_ = new std::remove_pointer_t<decltype(mutex_)>();
original_vtable_ = new std::remove_pointer_t<decltype(original_vtable_)>();
link_data_forward_ = new std::remove_pointer_t<decltype(link_data_forward_)>();
link_data_reverse_ = new std::remove_pointer_t<decltype(link_data_reverse_)>();
return 1;
}();
}
static void install_hook(IUnknown* target, std::shared_ptr<linked_data> data)
{
std::lock_guard lock(*mutex_);
// Link data to IUnknown pointer.
if (target && data && !get_linked_data(target))
{
link_data_forward_->emplace(target, data);
link_data_reverse_->emplace(data.get(), target);
}
// Install IUnknown vtable hook.
auto vp = get_vtable_from_object(target);
auto it = original_vtable_->find(vp);
if (it == original_vtable_->end())
{
original_vtable_->insert({vp, *vp});
detail::rewrite_vtable(
const_cast<IUnknown_vtable*>(vp),
{
handle_QueryInterface,
handle_AddRef,
handle_Release,
});
}
}
static IUnknown_vtable get_original_vtable(IUnknown* target)
{
std::lock_guard lock(*mutex_);
auto vp = get_vtable_from_object(target);
auto it = original_vtable_->find(vp);
if (it != original_vtable_->end())
return it->second;
return {};
}
static std::shared_ptr<linked_data> get_linked_data(IUnknown* target)
{
assert(target);
std::lock_guard lock(*mutex_);
auto it = link_data_forward_->find(target);
if (it != link_data_forward_->end())
return it->second;
return nullptr;
}
static void erase_linked_data(std::shared_ptr<linked_data> data)
{
std::lock_guard lock(*mutex_);
auto [beg, end] = link_data_reverse_->equal_range(data.get());
for (auto it = beg; it != end; ++it) { link_data_forward_->erase(it->second); }
link_data_reverse_->erase(beg, end);
}
static HRESULT STDMETHODCALLTYPE handle_QueryInterface(void* self, const IID& riid, void** ppvObject)
{
if (!ppvObject) return E_POINTER;
IUnknown* this_ = reinterpret_cast<IUnknown*>(self);
IUnknown_vtable orig = get_original_vtable(this_);
if (orig.QueryInterface)
{
HRESULT hr = orig.QueryInterface(this_, riid, ppvObject);
if (SUCCEEDED(hr))
{
IUnknown* that_ = static_cast<IUnknown*>(*ppvObject);
if (that_ && that_ != this_)
{
install_hook(that_, get_linked_data(this_));
}
}
return hr;
}
else
{
assert(false); // critical error
return E_FAIL;
}
}
static ULONG STDMETHODCALLTYPE handle_AddRef(void* self)
{
IUnknown* this_ = reinterpret_cast<IUnknown*>(self);
IUnknown_vtable orig = get_original_vtable(this_);
if (orig.AddRef)
{
ULONG r = orig.AddRef(this_);
return r;
}
else
{
assert(false); // critical error
return 0;
}
}
static ULONG STDMETHODCALLTYPE handle_Release(void* self)
{
IUnknown* this_ = reinterpret_cast<IUnknown*>(self);
IUnknown_vtable orig = get_original_vtable(this_);
if (orig.Release)
{
ULONG r = orig.Release(this_);
if (r == 0)
{
if (std::shared_ptr<linked_data> data = get_linked_data(this_))
{
erase_linked_data(data);
if (data->release_callback_function)
data->release_callback_function();
}
}
return r;
}
else
{
assert(false); // critical error
return 0;
}
}
};
}
/// Registers `on_release` to run once `obj` has been released for the last time. Interface
/// pointers later obtained from `obj` via QueryInterface are traced as well.
///
/// @tparam instance_tag Selects an independent hook registry; distinct tags do not share state.
/// @param obj           COM object to trace.
/// @param on_release    Callback invoked after the final Release (the one returning 0).
/// NOTE(review): if `obj` is already traced under the same tag, the existing link is kept and
/// this callback is dropped.
template <class instance_tag = void>
static inline void register_release_callback(
    IUnknown* obj,
    std::function<void()> on_release)
{
    detail::linked_data payload{};
    payload.release_callback_function = std::move(on_release);
    detail::hook_table<instance_tag>::trace_object(obj, std::move(payload));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment