Skip to content

Instantly share code, notes, and snippets.

@48ca
Created March 6, 2024 00:03

Revisions

  1. 48ca created this gist Mar 6, 2024.
    124 changes: 124 additions & 0 deletions uffdio-continue-wmb-test.c
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,124 @@
    #define _GNU_SOURCE

    #include <sys/mman.h>
    #include <sys/ioctl.h>
    #include <fcntl.h> /* Definition of O_* constants */
    #include <sys/syscall.h> /* Definition of SYS_* constants */
    #include <linux/userfaultfd.h> /* Definition of UFFD_* constants */
    #include <unistd.h>
    #include <pthread.h>
    #include <errno.h>
    #include <stdio.h>
    #include <linux/memfd.h>
    #include <signal.h>

    #define FAIL(err, msg) \
    do { \
    perror(msg); \
    return err; \
    } while(0)

    #define PAGE_SIZE 4096L

    const int num_threads = 128;
    const size_t length = PAGE_SIZE * 512 * 512 * 8;
    char *primary_map;

    void *thread_func(void *)
    {
    long failed = 0;
    size_t off = 0;
    for (off = 0; off < length; off += PAGE_SIZE) {
    char a = primary_map[off];
    if (a != 'A') {
    printf("Observed %c instead of A", a);
    failed = 1;
    }
    }
    return (void *)failed;
    }

    void sighand() {
    }

    int main()
    {
    // open shmem memfd
    // mmap twice, same file
    // open uffd with UFFD_FEATURE_SIGBUS, ignore sigbus
    // spawn thread(s), continuously check userfault region
    // in main thread, do a small write, then immediately UFFDIO_CONTINUE, move on to the next page.

    int fd = memfd_create("test", 0);
    if (fd < 0)
    FAIL(fd, "memfd_create failed");

    int ret = ftruncate(fd, length);
    if (ret < 0)
    FAIL(ret, "ftruncate failed");

    ret = fallocate(fd, FALLOC_FL_KEEP_SIZE, 0, length);
    if (ret < 0)
    FAIL(ret, "fallocate failed");

    primary_map = (char *)mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    if (primary_map == MAP_FAILED)
    FAIL(-1, "primary map failed");

    char *alias_map = (char *)mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    if (alias_map == MAP_FAILED)
    FAIL(-1, "alias map failed");

    int uffd = syscall(__NR_userfaultfd, UFFD_USER_MODE_ONLY);
    if (uffd < 0)
    FAIL(uffd, "uffd creation failed");

    struct uffdio_api uffd_api = {
    .api = UFFD_API,
    .features = UFFD_FEATURE_SIGBUS | UFFD_FEATURE_MINOR_SHMEM,
    .ioctls = 0,
    };
    ret = ioctl(uffd, UFFDIO_API, &uffd_api);
    if (ret < 0)
    FAIL(ret, "UFFDIO_API failed");

    struct uffdio_register uffd_register = {
    .range = {.start = (unsigned long long)primary_map,
    .len = length},
    .mode = UFFDIO_REGISTER_MODE_MINOR,
    .ioctls = 0,
    };
    ret = ioctl(uffd, UFFDIO_REGISTER, &uffd_register);
    if (ret < 0)
    FAIL(ret, "UFFDIO_REGISTER failed");

    if (signal(SIGBUS, sighand) == SIG_ERR)
    FAIL(-1, "setting signal handler failed");

    pthread_t thds[num_threads];
    for (int i = 0; i < num_threads; ++i)
    pthread_create(&thds[i], NULL, &thread_func, NULL);

    size_t off;
    struct uffdio_continue uffd_continue;
    uffd_continue.range.len = PAGE_SIZE;
    uffd_continue.mode = UFFDIO_CONTINUE_MODE_DONTWAKE;
    for (off = 0; off < length; off += PAGE_SIZE) {
    uffd_continue.range.start = (unsigned long long)primary_map + off;
    alias_map[off] = 'A';
    ret = ioctl(uffd, UFFDIO_CONTINUE, &uffd_continue);
    if (ret < 0)
    FAIL(ret, "UFFDIO_CONTINUE failed");
    }

    long num_failed = 0;
    for (int i = 0; i < num_threads; ++i) {
    long failed;
    pthread_join(thds[i], (void *)&failed);
    if (failed && !num_failed)
    printf("Actually observed memory ordering issues.");
    num_failed += failed;
    }

    return num_failed ? 1 : 0;
    }