/*
 * Example of SIGBUS caused by a race in handling userfaults.
 *
 * Source: GitHub gist bgaff/3a8b2a890ae4771be22456e014c2e5aa,
 * created February 14, 2020 22:34.
 */
#define _GNU_SOURCE
#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <fcntl.h>
#include <poll.h>
#include <pthread.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/types.h>

#include <linux/userfaultfd.h>
/* Start of the anonymous mapping that is registered with userfaultfd. */
static void *base;
/* Number of pages in the mapping. */
static unsigned int nr_pages;
/* Size of one page in bytes. */
static unsigned int page_size;
/* Number of cpus; one counting thread and one fault thread run per cpu. */
static unsigned int nr_cpus;
/*
 * Number of pages currently populated in the region. Only accessed with
 * the __sync atomic builtins: raised as pages are filled in, lowered by
 * nr_pages when the region is zapped.
 */
static int current_pages_in_region;
/* The userfaultfd descriptor shared by all fault-handling threads. */
int uffd;
/* | |
* Counting threads run one per cpu. | |
* | |
* These threads will spin doing nothing but adding one to the | |
* first unsigned long bytes of a randomly choosen page. This | |
* thread will be the source of faults and ultimately the source | |
* of a SIGBUS | |
*/ | |
static void *counting_thread(void *arg) | |
{ | |
unsigned long page_nr = 0; | |
unsigned int seed = (unsigned int)arg + (unsigned int)time(NULL); | |
void *addr; | |
for (;;) { | |
page_nr = rand_r(&seed) % nr_pages; | |
addr = base + (page_nr * page_size); | |
__sync_fetch_and_add((unsigned long *)addr, 1); | |
} | |
return NULL; | |
} | |
/* | |
* There is a single zapping thread. | |
* | |
* When all pages have faulted in this thread will zap them | |
* away again with MADV_DONTNEED. | |
*/ | |
static void *zapping_thread(void *arg) | |
{ | |
for (;;) { | |
/* After all pages are faulted in this thread will zap the range again. */ | |
if (__sync_fetch_and_add(¤t_pages_in_region, 0) != | |
nr_pages) { | |
usleep(100); | |
continue; | |
} | |
/* Zap away all the pages */ | |
madvise(base, nr_pages * page_size, MADV_DONTNEED); | |
__sync_sub_and_fetch(¤t_pages_in_region, nr_pages); | |
} | |
return NULL; | |
} | |
/* | |
* handle_fault just zero fills the page with the fault and then increments | |
* the number of faulted pages | |
**/ | |
static int handle_fault(struct uffd_msg *msg) | |
{ | |
struct uffdio_zeropage zp; | |
int ret; | |
assert(msg->event == UFFD_EVENT_PAGEFAULT); | |
zp.range.start = | |
msg->arg.pagefault.address & ~((unsigned long)page_size - 1); | |
zp.range.len = page_size; | |
zp.zeropage = 0; | |
ret = ioctl(uffd, UFFDIO_ZEROPAGE, &zp); | |
if (ret < 0) { | |
assert(zp.zeropage == -EEXIST); | |
return 0; | |
} | |
assert(zp.zeropage == page_size); | |
__sync_add_and_fetch(¤t_pages_in_region, 1); | |
return 0; | |
} | |
/* | |
* Fault handling threads run one per cpu. | |
* | |
* These threads just handle the fault events generated by | |
* the counting threads. | |
*/ | |
static void *uffd_thread(void *arg) | |
{ | |
struct uffd_msg msg; | |
int ret; | |
for (;;) { | |
ret = read(uffd, &msg, sizeof(msg)); | |
if (ret == sizeof(msg)) { | |
handle_fault(&msg); | |
} else if (ret < 0 && errno == EINTR) { | |
continue; | |
} else { | |
fprintf(stderr, "Unexpected return from read %d %d\n", | |
ret, errno); | |
exit(1); | |
} | |
} | |
return NULL; | |
} | |
static int test_userfaultfd_zapping() | |
{ | |
unsigned long cpu; | |
pthread_t counting_threads[nr_cpus]; | |
pthread_t zap_thread; | |
pthread_t uffd_threads[nr_cpus]; | |
struct uffdio_register uffdio_register; | |
struct uffdio_api uffdio_api; | |
base = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE, | |
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); | |
assert(base != MAP_FAILED); | |
/* Make sure the pages are already faulted in */ | |
memset(base, 0, nr_pages * page_size); | |
__sync_add_and_fetch(¤t_pages_in_region, nr_pages); | |
uffd = syscall(__NR_userfaultfd, 0); | |
assert(uffd != -1); | |
uffdio_api.api = UFFD_API; | |
uffdio_api.features = 0; | |
assert(ioctl(uffd, UFFDIO_API, &uffdio_api) != -1); | |
/* register */ | |
uffdio_register.range.start = (unsigned long)base; | |
uffdio_register.range.len = nr_pages * page_size; | |
uffdio_register.mode = UFFDIO_REGISTER_MODE_MISSING; | |
assert(ioctl(uffd, UFFDIO_REGISTER, &uffdio_register) != -1); | |
for (cpu = 0; cpu < nr_cpus; cpu++) { | |
assert(pthread_create(&counting_threads[cpu], NULL, | |
counting_thread, (void *)cpu) == 0); | |
assert(pthread_create(&uffd_threads[cpu], NULL, uffd_thread, | |
(void *)NULL) == 0); | |
} | |
assert(pthread_create(&zap_thread, NULL, | |
zapping_thread, (void *)NULL) == 0); | |
/* Run until we cause a SIGBUS */ | |
pthread_join(zap_thread, NULL); | |
} | |
int main(int argc, char **argv) | |
{ | |
nr_cpus = sysconf(_SC_NPROCESSORS_ONLN); | |
page_size = sysconf(_SC_PAGESIZE); | |
nr_pages = 10 * nr_cpus; | |
printf("Running test with %d page size and %d cpus\n", page_size, | |
nr_cpus); | |
return test_userfaultfd_zapping(); | |
} |
/*
 * Sign up for free to join this conversation on GitHub.
 * Already have an account? Sign in to comment.
 */