Skip to content

Instantly share code, notes, and snippets.

@skeeto
Last active March 17, 2025 22:34
Robust Lock?
// Robust lock: if a task returns while still holding the lock, the lock is
// not released, but the next lock() that finds the owner's task has ended
// takes it over.
// linux $ cc -pthread example.c
// w64dk $ cc example.c -lntdll
// Ref: https://old.reddit.com/r/C_Programming/comments/1jd82ux
// Ref: https://github.com/cozis/timestamp_lock
// This is free and unencumbered software released into the public domain.
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Short fixed-width aliases used throughout the file.
typedef int32_t   i32;
typedef int32_t   b32;   // boolean
typedef int64_t   i64;
typedef ptrdiff_t iz;    // signed size
typedef size_t    uz;    // unsigned size
typedef char      byte;

// new: n zeroed objects of type t from arena a; lenof: element count of an array.
#define new(a, n, t)  (t *)alloc(a, n, sizeof(t), _Alignof(t))
#define lenof(a)      (iz)(sizeof(a) / sizeof(*(a)))

// Memory region [beg, end) handed out by alloc().
typedef struct { byte *beg, *end; } Arena;
// Carve `count` zero-filled objects of `size` bytes off the top of the arena,
// aligned to `align` (assumed a power of two, as _Alignof yields).
// Exhaustion is only caught by assert().
static byte *alloc(Arena *a, iz count, iz size, iz align)
{
    iz skew = (iz)((uz)a->end & (uz)(align - 1));  // bytes to drop to realign end
    assert(count < (a->end - a->beg - skew)/size);
    iz bytes = count * size;
    a->end -= skew + bytes;
    return memset(a->end, 0, bytes);
}
// Platform API, implemented at the bottom of the file per compiler/OS.
static uz   newthread(uz (*entry)(void *), void *arg);  // start a thread running entry(arg)
static uz   load(uz *p);                                // atomic acquire load
static void store(uz *p, uz v);                         // atomic release store
static b32  cas(uz *p, uz *expect, uz desired);         // on failure *expect gets current value
static void wait(uz *p, uz expect, i32 timeout_ms);     // sleep while *p == expect; 0 = no timeout
static void wake(uz *p);                                // wake one thread waiting on p
// A unit of work handed to a pool worker, invoked as func(taskid, arg).
typedef struct Thunk {
    void (*func)(uz, void *);
    void  *arg;
} Thunk;
// One pool slot, owned by a single worker thread.
typedef struct Thread {
    uz handle;  // value returned by newthread()
    uz thunk;   // (Thunk *) awaiting/under execution, or 0 when idle; also a wait word
    uz taskid;  // ID of the slot's current task; advanced by 1<<POOL_EXP per finished task
} Thread;
// The pool runs 1<<POOL_EXP workers. Task IDs step by the pool size, so each
// worker's IDs keep a fixed residue modulo the pool size.
enum { POOL_EXP = 4 };
typedef struct Pool {
    Thread threads[1 << POOL_EXP];
} Pool;
// Worker loop for one pool slot: sleep until enqueue() posts a thunk, run it
// under the slot's current task ID, then retire that ID and mark the slot idle.
// Retiring the ID is what makes a lock still tagged with it look abandoned.
// The null check loops because wait() may return without the value having
// changed: WaitOnAddress documents spurious returns, and futex(2) can return
// early (e.g. EINTR).
static uz worker(void *arg)
{
    Thread *self = arg;
    for (;;) {
        Thunk *task;
        while (!(task = (Thunk *)load(&self->thunk))) {
            wait(&self->thunk, 0, 0);
        }
        uz id = self->taskid;  // only this thread writes taskid after startup
        task->func(id, task->arg);
        store(&self->taskid, id + (1<<POOL_EXP));  // retire this task ID
        store(&self->thunk, 0);                    // slot is idle again
    }
    return 0;  // not reached
}
// Allocate a pool from the arena and start one worker per slot. Slot i's
// first task ID is i+1, so no task ID is 0, the value unlock() writes to a
// free lock.
static Pool *newpool(Arena *a)
{
    Pool *pool = new(a, 1, Pool);
    iz nslots = lenof(pool->threads);
    for (iz slot = 0; slot < nslots; slot++) {
        Thread *t = &pool->threads[slot];
        t->taskid = (uz)slot + 1;
        t->handle = newthread(worker, t);
    }
    return pool;
}
// Hand thunk `t` to the first idle worker (thunk slot == 0) and wake it.
// Scans slots round-robin from slot 0 and spins while every worker is busy.
static void enqueue(Pool *pool, Thunk *t)
{
    // TODO: wait for non-full
    i32 nslots = (i32)lenof(pool->threads);
    for (i32 i = 0;; i = (i + 1) % nslots) {
        Thread *slot = &pool->threads[i];
        uz idle = 0;
        if (cas(&slot->thunk, &idle, (uz)t)) {
            wake(&slot->thunk);
            return;
        }
    }
}
// Report whether the task identified by `id` is still running.
// newpool() gives slot i the task ID i+1, and worker() adds 1<<POOL_EXP
// after each finished task, so (id - 1) modulo the pool size recovers the
// slot. The task is alive iff that slot still carries exactly this ID.
// (Decoding as (id & mask) - 1 would map slot 15's IDs 16, 32, ... to
// index -1, reading outside threads[].)
static b32 isalive(Pool *pool, uz id)
{
    uz mask = ((uz)1 << POOL_EXP) - 1;
    uz slot = (id - 1) & mask;  // always in [0, 1<<POOL_EXP)
    return load(&pool->threads[slot].taskid) == id;
}
// Acquire the robust lock *l for task `id`. *l holds the owner's task ID,
// or 0 when free. If isalive() reports the recorded owner's task has ended,
// the lock is taken over. A dead owner never calls unlock(), so waiters
// re-check every 200 ms instead of sleeping indefinitely.
static void lock(uz *l, Pool *pool, uz id)
{
    for (;;) {
        uz owner = 0;
        if (cas(l, &owner, id)) {
            return;
        }
        // The failed cas left the current holder's ID in `owner`.
        if (!isalive(pool, owner) && cas(l, &owner, id)) {
            return;  // owner is dead, steal it
        }
        wait(l, owner, 200);  // wait up to 200ms
    }
}
// Release the lock *l (0 means free) and wake one thread blocked on it.
static void unlock(uz *l)
{
    store(l, 0);
    wake(l);  // TODO: only wake if there are waiters
}
// Per-job state for the demo; one Job per enqueued Thunk.
typedef struct Job {
    uz   *lock;    // shared lock word (0 = free, else owner's task ID)
    Pool *pool;    // passed to lock() to check whether the owner is alive
    uz    done;    // set to 1 and woken when compute() finishes
    b32   fail;    // if nonzero, compute() returns without unlocking
    i32   value;   // input
    i32   result;  // output: -value
} Job;
// Task body: store -value into result while holding the shared lock.
// Jobs flagged `fail` deliberately return with the lock still held,
// simulating a task that dies inside its critical section.
static void compute(uz id, void *arg)
{
    Job *job = arg;
    lock(job->lock, job->pool, id);
    job->result = -job->value;
    if (!job->fail) {
        unlock(job->lock);
    }  // else: leak the lock on purpose
    store(&job->done, 1);
    wake(&job->done);
}
// Demo: run 100 jobs contending on one lock; every fourth job leaks the lock,
// and later jobs recover it. Prints each job's result in job order.
int main(void)
{
    iz cap = (iz)1<<24;
    byte *mem = malloc(cap);
    if (!mem) {
        fputs("out of memory\n", stderr);
        return 1;
    }
    Arena a = {mem, mem+cap};
    Pool *pool = newpool(&a);
    uz mtx = 0;  // shared lock word; 0 = free
    i32 njobs = 100;
    Thunk *thunks = new(&a, njobs, Thunk);
    Job *jobs = new(&a, njobs, Job);
    for (i32 i = 0; i < njobs; i++) {
        jobs[i].lock = &mtx;
        jobs[i].pool = pool;
        jobs[i].fail = !(i % 4);
        jobs[i].value = i + 1;
        thunks[i].func = compute;
        thunks[i].arg = jobs + i;
        enqueue(pool, thunks+i);
    }
    for (i32 i = 0; i < njobs; i++) {
        // wait() may return spuriously. Also, `result` is only safe to read
        // after an acquire load observes done == 1.
        while (!load(&jobs[i].done)) {
            wait(&jobs[i].done, 0, 0);
        }
        printf("%d\n", jobs[i].result);
    }
    // The arena is deliberately not freed: the workers still sleep on
    // wait words that live inside it.
    return 0;
}
#if __GNUC__
// Atomic store with release ordering.
static void store(uz *p, uz v)
{
    __atomic_store_n(p, v, __ATOMIC_RELEASE);
}
// Atomic load with acquire ordering.
static uz load(uz *p)
{
    return __atomic_load_n(p, __ATOMIC_ACQUIRE);
}
// Strong compare-and-swap: if *p == *old, store `new` and return true;
// otherwise copy the current *p into *old and return false.
// The success order must be acquire+release. lock() needs acquire to observe
// the previous holder's writes, and enqueue() needs release to publish the
// thunk it installs. GCC also requires the failure order to be no stronger
// than the success order.
static b32 cas(uz *p, uz *old, uz new)
{
    return __atomic_compare_exchange_n(p, old, new, 0,
                                       __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE);
}
#endif
#if _WIN32
#define W32(r) __declspec(dllimport) r __stdcall
W32(uz) CreateThread(uz, iz, uz, void *, i32, uz);
W32(i32) RtlWaitOnAddress(void *, void *, uz, i64 *);
W32(i32) RtlWakeAddressSingle(void *);
// NOTE(review): func is neither __stdcall nor DWORD-returning. This matches
// the x64 calling convention; on 32-bit x86 it would only be safe because
// worker() never returns. Confirm if targeting 32-bit.
static uz newthread(uz (*func)(void *), void *arg)
{
    return CreateThread(0, 0, (uz)func, arg, 0, 0);
}
// Block while *p == v, for at most timeout_ms (0 = no timeout). May return
// spuriously, so callers must re-check *p.
static void wait(uz *p, uz v, i32 timeout_ms)
{
    i64 timeout = (i64)timeout_ms * -10000;  // negative = relative, 100 ns units
    RtlWaitOnAddress(p, &v, sizeof(*p), timeout_ms ? &timeout : 0);
}
// Wake one thread blocked in wait() on p.
static void wake(uz *p)
{
    RtlWakeAddressSingle(p);
}
#elif __linux
#include <linux/futex.h>
#include <pthread.h>
#include <sys/syscall.h>
#include <unistd.h>
static uz newthread(uz (*func)(void *), void *arg)
{
    pthread_t r;
    pthread_create(&r, 0, (void *(*)(void *))(uz)func, arg);
    return r;
}
// futex(2) operates on a 32-bit word. Select the half of *p that holds its
// least significant bits, so the kernel compares against (i32)v on both
// little- and big-endian targets. Only those low 32 bits are compared.
static int *futexword(uz *p)
{
    int *w = (int *)p;
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    w += sizeof(uz)/sizeof(int) - 1;
#endif
    return w;
}
// Block while the low 32 bits of *p equal those of v, for at most
// timeout_ms (0 = no timeout). May return early (e.g. EINTR), so callers
// must re-check *p.
static void wait(uz *p, uz v, i32 timeout_ms)
{
    // tv_nsec must stay below 1e9, and the multiply must not overflow i32.
    struct timespec ts = {
        .tv_sec  = timeout_ms / 1000,
        .tv_nsec = (long)(timeout_ms % 1000) * 1000000L,
    };
    syscall(SYS_futex, futexword(p), FUTEX_WAIT, (i32)v, timeout_ms?&ts:0, 0, 0);
}
// Wake one thread blocked in wait() on p.
static void wake(uz *p)
{
    syscall(SYS_futex, futexword(p), FUTEX_WAKE, 1, 0, 0, 0);
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment