Skip to content

Instantly share code, notes, and snippets.

@bgaff
Created February 14, 2020 22:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bgaff/3a8b2a890ae4771be22456e014c2e5aa to your computer and use it in GitHub Desktop.
Save bgaff/3a8b2a890ae4771be22456e014c2e5aa to your computer and use it in GitHub Desktop.
Example of a SIGBUS caused by a race while handling userfaults
#define _GNU_SOURCE
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <assert.h>
#include <string.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <time.h>
#include <poll.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/ioctl.h>
#include <pthread.h>
#include <linux/userfaultfd.h>
/* Start of the userfaultfd-registered anonymous region. */
static void *base;
/* Number of pages in the region. */
static unsigned int nr_pages;
/* System page size in bytes. */
static unsigned int page_size;
/* Number of online cpus; one counting thread and one uffd thread per cpu. */
static unsigned int nr_cpus;
/*
 * Number of pages currently believed present in the region. Incremented by
 * the fault handlers and decreased by nr_pages after each zap; only accessed
 * through __sync builtins.
 */
static int current_pages_in_region;
/* The userfaultfd descriptor; file-local like the rest of the shared state. */
static int uffd;
/*
 * Counting threads run one per cpu.
 *
 * Each thread spins forever, atomically incrementing the first unsigned long
 * of a randomly chosen page in the region. These writes are the source of the
 * userfaults and, ultimately, of the SIGBUS this program reproduces.
 *
 * @arg: the thread's cpu index, passed by the creator as (void *)cpu with an
 *       unsigned long cpu. It is mixed into the rand_r() seed so each thread
 *       walks a different sequence of pages.
 *
 * Never returns.
 */
static void *counting_thread(void *arg)
{
	/* Undo the creator's unsigned long -> void * cast with the same type. */
	unsigned int seed = (unsigned int)(unsigned long)arg +
			    (unsigned int)time(NULL);
	unsigned long page_nr;
	unsigned long *counter;

	for (;;) {
		page_nr = rand_r(&seed) % nr_pages;
		/* Byte offsets via char *; arithmetic on void * is a GNU extension. */
		counter = (unsigned long *)((char *)base + page_nr * page_size);
		__sync_fetch_and_add(counter, 1);
	}
	return NULL;
}
/*
 * There is a single zapping thread.
 *
 * It polls current_pages_in_region every 100us. Once every page in the
 * region has been faulted in, it discards them all with MADV_DONTNEED and
 * subtracts nr_pages from the counter, so the counting threads start
 * faulting again.
 *
 * If madvise() fails, the process exits. Continuing would leave the counter
 * out of step with what is actually mapped, and this loop would stall.
 *
 * Never returns.
 */
static void *zapping_thread(void *arg)
{
	/* nr_pages and page_size are fixed before any thread starts. */
	size_t len = (size_t)nr_pages * page_size;

	(void)arg;

	for (;;) {
		/* Atomic read: adding 0 returns the current value. */
		if ((unsigned int)__sync_fetch_and_add(&current_pages_in_region, 0) !=
		    nr_pages) {
			usleep(100);
			continue;
		}

		/* Zap away all the pages. */
		if (madvise(base, len, MADV_DONTNEED) != 0) {
			perror("madvise(MADV_DONTNEED)");
			exit(1);
		}
		__sync_sub_and_fetch(&current_pages_in_region, nr_pages);
	}
	return NULL;
}
/*
 * handle_fault() resolves one userfault. It zero-fills the faulting page with
 * UFFDIO_ZEROPAGE and then counts the page in current_pages_in_region.
 *
 * The uffdio_zeropage struct is fully zero-initialised. Per
 * ioctl_userfaultfd(2), unknown bits in .mode make the ioctl fail with
 * EINVAL. A stray UFFDIO_ZEROPAGE_MODE_DONTWAKE bit would skip waking the
 * faulting thread.
 *
 * Several uffd threads can receive faults for the same page. Only one
 * UFFDIO_ZEROPAGE succeeds; the others report -EEXIST. That is expected and
 * ignored, because the winner already counted the page.
 *
 * Errors are checked explicitly rather than with assert(), so they are still
 * reported when built with -DNDEBUG.
 *
 * @msg: a message read from the userfaultfd.
 *
 * Returns 0. Exits the process on an unexpected event or ioctl failure.
 */
static int handle_fault(struct uffd_msg *msg)
{
	struct uffdio_zeropage zp = { .mode = 0 };

	if (msg->event != UFFD_EVENT_PAGEFAULT) {
		fprintf(stderr, "Unexpected uffd event %u\n",
			(unsigned int)msg->event);
		exit(1);
	}

	/* Round the faulting address down to the start of its page. */
	zp.range.start = msg->arg.pagefault.address & ~((__u64)page_size - 1);
	zp.range.len = page_size;

	if (ioctl(uffd, UFFDIO_ZEROPAGE, &zp) < 0) {
		/* Lost the race: another handler already filled this page. */
		if (zp.zeropage == -EEXIST)
			return 0;
		fprintf(stderr, "UFFDIO_ZEROPAGE failed: %s (zeropage=%lld)\n",
			strerror(errno), (long long)zp.zeropage);
		exit(1);
	}

	if (zp.zeropage != page_size) {
		fprintf(stderr, "UFFDIO_ZEROPAGE filled %lld bytes, expected %u\n",
			(long long)zp.zeropage, page_size);
		exit(1);
	}

	__sync_add_and_fetch(&current_pages_in_region, 1);
	return 0;
}
/*
 * Fault handling threads run one per cpu.
 *
 * Each thread reads messages one at a time from the shared userfaultfd and
 * passes each complete message to handle_fault(). Reads interrupted by a
 * signal (EINTR) are retried. Any other error, or a read of unexpected
 * size, is reported and terminates the process.
 *
 * Never returns.
 */
static void *uffd_thread(void *arg)
{
	struct uffd_msg msg;
	ssize_t ret; /* read() returns ssize_t, not int */

	(void)arg;

	for (;;) {
		ret = read(uffd, &msg, sizeof(msg));
		if (ret == (ssize_t)sizeof(msg)) {
			handle_fault(&msg);
		} else if (ret < 0 && errno == EINTR) {
			continue;
		} else {
			fprintf(stderr, "Unexpected return from read %zd %d\n",
				ret, errno);
			exit(1);
		}
	}
	return NULL;
}
/* Report a failed setup step with its error code and terminate the process. */
static void zapping_setup_fail(const char *what, int err)
{
	fprintf(stderr, "%s failed: %s\n", what, strerror(err));
	exit(1);
}

/*
 * Set up the test and run it:
 *
 * - map an anonymous region of nr_pages pages and fault it in;
 * - register the region with a new userfaultfd in MISSING mode;
 * - start one counting thread and one uffd thread per cpu, plus the single
 *   zapping thread;
 * - wait on the zapping thread.
 *
 * The calls with side effects (ioctl, pthread_create) are checked explicitly
 * instead of inside assert(). Under -DNDEBUG those asserts would compile
 * away the calls themselves.
 *
 * Returns 1 if there is nothing to test (no cpus or an empty region) and 0
 * if the zapping thread ever finishes. Exits the process on any setup
 * failure.
 */
static int test_userfaultfd_zapping(void)
{
	/* Compute in size_t so the region length cannot wrap in unsigned int. */
	size_t len = (size_t)nr_pages * page_size;
	unsigned long cpu;
	int err;

	/* The per-cpu thread arrays below are VLAs; a zero length is undefined. */
	if (nr_cpus == 0 || len == 0) {
		fprintf(stderr, "Nothing to test: %u cpus, %zu byte region\n",
			nr_cpus, len);
		return 1;
	}

	pthread_t counting_threads[nr_cpus];
	pthread_t uffd_threads[nr_cpus];
	pthread_t zap_thread;
	struct uffdio_register uffdio_register;
	struct uffdio_api uffdio_api;

	base = mmap(NULL, len, PROT_READ | PROT_WRITE,
		    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (base == MAP_FAILED)
		zapping_setup_fail("mmap", errno);

	/* Make sure the pages are already faulted in. */
	memset(base, 0, len);
	__sync_add_and_fetch(&current_pages_in_region, nr_pages);

	uffd = (int)syscall(__NR_userfaultfd, 0);
	if (uffd == -1)
		zapping_setup_fail("userfaultfd", errno);

	uffdio_api.api = UFFD_API;
	uffdio_api.features = 0;
	if (ioctl(uffd, UFFDIO_API, &uffdio_api) == -1)
		zapping_setup_fail("UFFDIO_API", errno);

	uffdio_register.range.start = (unsigned long)base;
	uffdio_register.range.len = len;
	uffdio_register.mode = UFFDIO_REGISTER_MODE_MISSING;
	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register) == -1)
		zapping_setup_fail("UFFDIO_REGISTER", errno);

	/* pthread_* functions return the error number; they do not set errno. */
	for (cpu = 0; cpu < nr_cpus; cpu++) {
		err = pthread_create(&counting_threads[cpu], NULL,
				     counting_thread, (void *)cpu);
		if (err)
			zapping_setup_fail("pthread_create(counting_thread)", err);

		err = pthread_create(&uffd_threads[cpu], NULL, uffd_thread, NULL);
		if (err)
			zapping_setup_fail("pthread_create(uffd_thread)", err);
	}

	err = pthread_create(&zap_thread, NULL, zapping_thread, NULL);
	if (err)
		zapping_setup_fail("pthread_create(zapping_thread)", err);

	/* Run until we cause a SIGBUS. */
	err = pthread_join(zap_thread, NULL);
	if (err)
		zapping_setup_fail("pthread_join(zapping_thread)", err);

	return 0;
}
int main(int argc, char **argv)
{
nr_cpus = sysconf(_SC_NPROCESSORS_ONLN);
page_size = sysconf(_SC_PAGESIZE);
nr_pages = 10 * nr_cpus;
printf("Running test with %d page size and %d cpus\n", page_size,
nr_cpus);
return test_userfaultfd_zapping();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment