Skip to content

Instantly share code, notes, and snippets.

@ntrrgc
Created June 30, 2018 15:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ntrrgc/859476b1670b196f5e0606f092276da3 to your computer and use it in GitHub Desktop.
Save ntrrgc/859476b1670b196f5e0606f092276da3 to your computer and use it in GitHub Desktop.
Proof of concept for page-based watchpoints
#define __USE_POSIX199309
#include <malloc.h>
#include <signal.h>
#include <sys/mman.h>
#include <sys/ptrace.h>
#include <unistd.h>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <mutex>
#include <new>
#include <set>
using namespace std;
// Page size in bytes as reported by sysconf(_SC_PAGESIZE), queried once
// during static initialization.
const size_t PAGESIZE = static_cast<size_t>(sysconf(_SC_PAGESIZE));
// A half-open address interval [start, end) paired with the callback to run
// when the interval is accessed.
class PointerRange {
public:
    void* start;                     // First byte of the interval.
    void* end;                       // One past the last byte of the interval.
    std::function<void()> onAccess;  // Invoked by whoever detects an access.

    // Length of the interval in bytes.
    size_t size() const {
        const char* const first = static_cast<const char*>(start);
        const char* const last = static_cast<const char*>(end);
        return last - first;
    }

    // Ranges are ordered by start address only; `end` and `onAccess` do not
    // take part in the comparison.
    bool operator<(const PointerRange& other) const {
        return start < other.start;
    }
};
class PointerRangeList: public std::set<PointerRange> {
public:
iterator findContainingPointer(void* pointer) {
for (iterator i = begin(); i != end(); ++i) {
const PointerRange& range = *i;
if (range.start <= pointer && pointer < range.end)
return i;
}
return end();
}
bool contains(void* pointer) const {
for (const PointerRange& range : *this) {
if (range.start <= pointer && pointer < range.end)
return true;
}
return false;
}
};
class MemoryUsageWatcher {
public:
static void initialize() {
assert(!s_instance);
s_instance = new MemoryUsageWatcher();
}
static MemoryUsageWatcher& instance() {
assert(s_instance);
return *s_instance;
}
// Called only from patrol thread.
void watchRange(void* _start, size_t size, function<void()> onAccess) {
char* start = (char*) _start;
// TODO ensure page-aligned
{
lock_guard<mutex> lock(mutex);
// No intersections:
assert(!m_watchedPages.contains(start));
assert(!m_watchedPages.contains(start + size));
PointerRange allocationRange { start, start + size, onAccess };
m_watchedPages.insert(allocationRange);
if (0 != mprotect(start, size, PROT_NONE)) {
perror("watchRange: ");
abort();
}
}
}
// Called only from patrol thread.
void removeWatch(void* start) {
lock_guard<mutex> lock(mutex);
PointerRangeList::iterator rangeIter = m_watchedPages.findContainingPointer(start);
if (rangeIter != m_watchedPages.end()) {
if (0 != mprotect(rangeIter->start, rangeIter->size(), PROT_READ | PROT_WRITE)) {
perror("removeWatch: ");
abort();
}
m_watchedPages.erase(rangeIter);
}
}
private:
static MemoryUsageWatcher* s_instance;
mutex m_mutex;
PointerRangeList m_watchedPages;
struct sigaction oldSigAction;
MemoryUsageWatcher() {
struct sigaction newSigAction;
newSigAction.sa_sigaction = MemoryUsageWatcher::segfaultHandlerWrapper;
newSigAction.sa_flags = SA_SIGINFO | SA_NODEFER;
// If our segfault handler has a bug, we want to catch it as usual,
// but otherwise we want no signals to interrupt the signal handler.
sigfillset(&newSigAction.sa_mask);
sigdelset(&newSigAction.sa_mask, SIGILL);
sigdelset(&newSigAction.sa_mask, SIGBUS);
sigdelset(&newSigAction.sa_mask, SIGFPE);
sigdelset(&newSigAction.sa_mask, SIGSEGV);
sigdelset(&newSigAction.sa_mask, SIGPIPE);
sigdelset(&newSigAction.sa_mask, SIGSTKFLT);
if (0 != sigaction(SIGSEGV, &newSigAction, &oldSigAction)) {
perror("MemoryUsageWatcher: could not set up signal handler: ");
abort();
}
}
public:
void segfaultHandler(void* accessedAddress) {
static thread_local bool insideSegfaultHandler = false;
static thread_local void* accessedAddressParentHandler = nullptr;
if (insideSegfaultHandler) {
// Segfault on segfault, this either is caused by:
if (accessedAddress != accessedAddressParentHandler) {
// a) A bug in this signal handler, who accessed an invalid pointer accidentally.
static const char msg[] = "MemoryUsageWatcher: Internal segmentation fault\n";
write(STDERR_FILENO, msg, sizeof(msg));
} else {
// b) A bug in the application, who accessed an invalid pointer accidentally,
// reaching this handler, who after ensuring that the pointer was not covered
// but a watched, mprotect()'ed range, decided to access it to check whether
// it was because the region was already unprotected by another thread that
// got the lock just before us or it was in fact invalid memory and turned out
// to be the latter.
}
// Either way, we have to abort for real.
sigaction(SIGSEGV, &oldSigAction, nullptr);
raise(SIGSEGV);
}
insideSegfaultHandler = true;
accessedAddressParentHandler = accessedAddress;
// TODO Check if we inside of real malloc in this thread. If that's the
// case, abort immediately before more damage is done (e.g. by the code
// following, which may use malloc()/free().
// The watched pages table can't be read and modified at the same time.
// Also, if two threads access the same page simultaneously, this
// ensures that only one executes the `onAccess()` callback.
lock_guard<mutex> lock(mutex);
PointerRangeList::iterator rangeIter = m_watchedPages.findContainingPointer(accessedAddress);
if (rangeIter != m_watchedPages.end()) {
if (0 != mprotect(rangeIter->start, rangeIter->size(), PROT_READ | PROT_WRITE)) {
perror("segfaultHandler: ");
abort();
}
rangeIter->onAccess();
m_watchedPages.erase(rangeIter);
} else {
// The address is not in the table of watched ranges. It may have
// been removed by another thread that acquired the lock before us,
// or maybe it's just a buggy pointer from the application.
// How can we know? Just access the pointer. In the former case, it
// will do nothing, in the latter, it will segfault again.
volatile char *pointer = (char*) accessedAddress;
*pointer;
}
insideSegfaultHandler = false;
}
static void segfaultHandlerWrapper(int signum, siginfo_t* siginfo, void*) {
instance().segfaultHandler(siginfo->si_addr);
}
};
MemoryUsageWatcher* MemoryUsageWatcher::s_instance = nullptr;
// Small demo payload; the field values are arbitrary but distinct so that
// the reads in main() are easy to tell apart.
struct Potato {
    int x{10};
    char y{1};
    char z{5};
};
int main(int argc, char** argv) {
printf("Page size: %zu\n", PAGESIZE);
MemoryUsageWatcher::initialize();
Potato* p = (Potato*) pvalloc(sizeof(Potato));
printf("Allocated %p\n", p);
new(p) Potato;
MemoryUsageWatcher::instance().watchRange(p, PAGESIZE, []() {
printf("Potato access detected!\n");
});
printf("p->y = %d\n", p->y);
printf("p->x = %d\n", p->x);
MemoryUsageWatcher::instance().watchRange(p, PAGESIZE, []() {
printf("Second potato access detected!\n");
});
printf("p->z = %d\n", p->z);
printf("p->y = %d\n", p->y);
p->~Potato();
free(p);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment