Last active
March 17, 2025 22:34
Robust Lock?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Robust Lock: if a "thread" exits while holding a lock, the lock is treated as unlocked
// linux $ cc -pthread example.c
// w64dk $ cc example.c -lntdll
// Ref: https://old.reddit.com/r/C_Programming/comments/1jd82ux
// Ref: https://github.com/cozis/timestamp_lock
// This is free and unencumbered software released into the public domain.
#include <assert.h> | |
#include <stddef.h> | |
#include <stdint.h> | |
#include <stdio.h> | |
#include <stdlib.h> | |
#include <string.h> | |
// Short aliases for the integer types used throughout this file.
typedef int32_t   b32;   // boolean stored in 32 bits
typedef int32_t   i32;
typedef int64_t   i64;
typedef ptrdiff_t iz;    // signed size / count
typedef size_t    uz;    // unsigned size; also carries pointers and handles
typedef char      byte;

// new(a, n, t): allocate n zero-filled objects of type t from Arena *a.
#define new(a, n, t)  (t *)alloc(a, n, sizeof(t), _Alignof(t))

// lenof(a): number of elements in a real array (not a pointer), as iz.
#define lenof(a)      (iz)(sizeof(a) / sizeof((a)[0]))

// Region of free memory [beg, end). alloc() carves objects off the end.
typedef struct {
    byte *beg;
    byte *end;
} Arena;
// Carve count objects of the given size off the high end of the arena and
// return them zero-filled.
//
//   a      arena to allocate from; a->end moves down by the amount taken
//   count  number of objects (>= 0)
//   size   size of one object in bytes (> 0)
//   align  required alignment; must be a power of two because it is used
//          as a bit mask
//
// The request is validated with assert(), so an oversized request aborts in
// debug builds; there is no failure return value.
// NOTE(review): the result is aligned only if size is a multiple of align,
// which holds for the sizeof/_Alignof pairs produced by new().
static byte *alloc(Arena *a, iz count, iz size, iz align)
{
    // Bytes to drop so that the (new) end of the arena is aligned.
    iz pad   = (iz)((uz)a->end & (uz)(align - 1));
    iz avail = a->end - a->beg - pad;

    assert(count >= 0 && size > 0);
    // Dividing instead of multiplying keeps the check free of overflow.
    // "<=" so that a request which fits exactly is accepted.
    assert(avail >= 0 && count <= avail/size);

    iz total = count * size;
    a->end -= pad + total;
    return memset(a->end, 0, (uz)total);
}
// Platform API | |
static b32 cas(uz *, uz *, uz); | |
static uz load(uz *); | |
static void wait(uz *, uz, i32); | |
static uz newthread(uz (*)(void *), void *); | |
static void store(uz *, uz); | |
static void wake(uz *); | |
typedef struct { | |
void (*func)(uz, void *); | |
void *arg; | |
} Thunk; | |
typedef struct { | |
uz handle; | |
uz thunk; | |
uz taskid; | |
} Thread; | |
enum { POOL_EXP = 4 }; | |
typedef struct { | |
Thread threads[1<<POOL_EXP]; | |
} Pool; | |
static uz worker(void *arg) | |
{ | |
Thread *thread = arg; | |
for (;;) { | |
Thunk *t = (Thunk *)load(&thread->thunk); | |
while (!t) { | |
// FIXME: Why does this sometimes wake with a null thunk? | |
wait(&thread->thunk, 0, 0); | |
t = (Thunk *)load(&thread->thunk); | |
} | |
uz taskid = thread->taskid; | |
t->func(taskid, t->arg); | |
store(&thread->taskid, taskid+(1<<POOL_EXP)); | |
store(&thread->thunk, 0); | |
} | |
return 0; | |
} | |
// Allocate a Pool from the arena and start one worker thread per slot.
// Slot i starts with task id i+1, so no task id is ever 0 (0 is the
// "unlocked" value used by lock()/unlock()).
static Pool *newpool(Arena *a)
{
    Pool   *pool = new(a, 1, Pool);
    Thread *slot = pool->threads;
    for (iz n = 0; n < lenof(pool->threads); n++, slot++) {
        // The id must be in place before the worker can observe the slot.
        slot->taskid = (uz)n + 1;
        slot->handle = newthread(worker, slot);
    }
    return pool;
}
// Hand thunk t to the first idle worker slot and wake that worker.
// Spins over the slots until one is free.
static void enqueue(Pool *pool, Thunk *t)
{
    // TODO: wait for non-full
    for (;;) {
        Thread *slot = pool->threads;
        Thread *stop = slot + lenof(pool->threads);
        for (; slot < stop; slot++) {
            uz idle = 0;
            if (!cas(&slot->thunk, &idle, (uz)t)) {
                continue;   // slot busy, try the next one
            }
            wake(&slot->thunk);
            return;
        }
    }
}
// Report whether the task identified by id is still the one its worker slot
// is running, i.e. whether the slot's current task id equals id.
//
// Task ids for slot i are i+1, i+1+N, i+1+2N, ... with N = 1<<POOL_EXP
// (see newpool() and worker()), so the slot index is (id - 1) mod N.
// Computing (id mod N) - 1 instead yields -1 for the last slot, whose ids
// are multiples of N, and would index before the start of the array.
//
// id == 0 (the "unlocked" value) maps to the last slot and reports "not
// alive", since no slot's task id is 0 short of uz wraparound.
static b32 isalive(Pool *pool, uz id)
{
    uz mask = ((uz)1<<POOL_EXP) - 1;
    uz slot = (id - 1) & mask;
    return load(&pool->threads[slot].taskid) == id;
}
// Acquire the robust lock *l on behalf of task id.
//
//   l     lock word: 0 when free, otherwise the task id of the owner
//   pool  pool whose slots are consulted to tell whether the owner is alive
//   id    task id of the caller (non-zero)
//
// If the recorded owner is no longer alive according to isalive(), the lock
// is taken over directly from it. Returns once the caller owns the lock.
static void lock(uz *l, Pool *pool, uz id)
{
    for (;;) {
        uz who = 0;
        if (cas(l, &who, id)) {
            return;   // lock was free
        }

        // The failed CAS left the current owner in who.
        if (!isalive(pool, who)) {
            if (cas(l, &who, id)) {
                return;   // owner is dead, steal it
            }
            // The lock word changed between the two CASes: it was released
            // or stolen by someone else. who now holds that newer value and
            // may be 0, in which case sleeping on it would stall on a free
            // lock until the timeout. Start over with a fresh attempt.
            continue;
        }

        // Owner is alive: sleep until the word changes. A dead owner never
        // calls unlock(), so the timeout bounds how long it takes to notice
        // an owner that exits while we sleep.
        wait(l, who, 200);   // wait up to 200ms
    }
}
// Release the lock word *l by resetting it to 0, then wake one thread that
// may be sleeping on it in lock().
static void unlock(uz *l)
{
    // TODO: only wake if there are waiters
    store(l, 0);
    wake(l);
}
typedef struct { | |
uz *lock; | |
Pool *pool; | |
uz done; | |
b32 fail; | |
i32 value; | |
i32 result; | |
} Job; | |
// Pool task used by the demo: under the shared lock, store the negated
// input in the job, then signal completion.
//   id   task id supplied by the worker, used as the lock owner id
//   arg  the Job to process
static void compute(uz id, void *arg)
{
    Job *j = arg;

    lock(j->lock, j->pool, id);
    j->result = -j->value;
    if (!j->fail) {
        unlock(j->lock);
    }
    // When j->fail is set the task deliberately returns with the lock still
    // held, simulating an owner that fails to unlock.

    store(&j->done, 1);
    wake(&j->done);
}
int main(void) | |
{ | |
iz cap = (iz)1<<24; | |
byte *mem = malloc(cap); | |
Arena a = {mem, mem+cap}; | |
Pool *pool = newpool(&a); | |
uz lock = 0; | |
i32 njobs = 100; | |
Thunk *thunks = new(&a, njobs, Thunk); | |
Job *jobs = new(&a, njobs, Job); | |
for (i32 i = 0; i < njobs; i++) { | |
jobs[i].lock = &lock; | |
jobs[i].pool = pool; | |
jobs[i].fail = !(i % 4); | |
jobs[i].value = i + 1; | |
thunks[i].func = compute; | |
thunks[i].arg = jobs + i; | |
enqueue(pool, thunks+i); | |
} | |
for (i32 i = 0; i < njobs; i++) { | |
wait(&jobs[i].done, 0, 0); | |
printf("%d\n", jobs[i].result); | |
} | |
} | |
#if __GNUC__ | |
// Atomically write v to *p with release ordering.
static void store(uz *p, uz v)
{
    __atomic_store(p, &v, __ATOMIC_RELEASE);
}
// Atomically read *p with acquire ordering and return the value.
static uz load(uz *p)
{
    uz v;
    __atomic_load(p, &v, __ATOMIC_ACQUIRE);
    return v;
}
// Strong compare-and-swap on *p.
//   p    word to update
//   old  in: expected value; out on failure: the value actually found in *p
//   new  value to store when *p equals *old
// Returns non-zero if the swap happened, 0 otherwise.
//
// Success uses acquire+release ordering: enqueue() relies on the release
// half to publish a thunk, and lock() relies on the acquire half so that
// accesses made while holding the lock are ordered after taking ownership.
// Release alone would not order the new owner's reads after the previous
// owner's writes.
static b32 cas(uz *p, uz *old, uz new)
{
    i32 pass = __ATOMIC_ACQ_REL;
    i32 fail = __ATOMIC_ACQUIRE;
    return __atomic_compare_exchange_n(p, old, new, 0, pass, fail);
}
#endif | |
#if _WIN32 | |
#define W32(r) __declspec(dllimport) r __stdcall | |
W32(uz) CreateThread(uz, iz, uz, void *, i32, uz); | |
W32(i32) RtlWaitOnAddress(void *, void *, uz, i64 *); | |
W32(i32) RtlWakeAddressSingle(void *); | |
// Start a thread running func(arg) with default attributes and stack size.
// Returns whatever CreateThread returns (the thread handle).
static uz newthread(uz (*func)(void *), void *arg)
{
    uz entry  = (uz)func;
    uz handle = CreateThread(0, 0, entry, arg, 0, 0);
    return handle;
}
// Sleep while *p still holds v. timeout_ms == 0 waits without a timeout;
// otherwise the wait gives up after timeout_ms milliseconds.
static void wait(uz *p, uz v, i32 timeout_ms)
{
    if (!timeout_ms) {
        RtlWaitOnAddress(p, &v, sizeof(*p), 0);
        return;
    }
    // Milliseconds to 100ns units; negative selects a relative timeout.
    i64 relative = (i64)timeout_ms * -10000;
    RtlWaitOnAddress(p, &v, sizeof(*p), &relative);
}
// Wake a single thread sleeping in wait() on address p, if there is one.
static void wake(uz *p)
{
    void *addr = p;
    RtlWakeAddressSingle(addr);
}
#elif __linux | |
#include <linux/futex.h> | |
#include <pthread.h> | |
#include <sys/syscall.h> | |
# include <unistd.h> | |
// Start a thread running func(arg) with default attributes.
// Returns the pthread_t as an integer handle, or 0 if the thread could not
// be created. pthread_create leaves the pthread_t undefined on failure, so
// it must not be returned in that case.
static uz newthread(uz (*func)(void *), void *arg)
{
    pthread_t r;
    // The cast adapts the uz-returning entry point to the pthread signature;
    // going through uz avoids a function-pointer conversion warning.
    // NOTE(review): this relies on uz and void * being returned the same way.
    if (pthread_create(&r, 0, (void *(*)(void *))(uz)func, arg) != 0) {
        return 0;
    }
    return (uz)r;
}
// Sleep while the futex word at p still holds v. timeout_ms == 0 waits
// without a timeout; otherwise the wait gives up after timeout_ms
// milliseconds.
//
// A futex word is 32 bits, so only the first four bytes of *p are compared
// against the truncated v (the low half of a 64-bit uz on little-endian
// targets).
static void wait(uz *p, uz v, i32 timeout_ms)
{
    // Split into whole seconds and the sub-second remainder: tv_nsec must
    // stay below one second, and widening before the multiply keeps the
    // arithmetic clear of i32 overflow.
    struct timespec ts = {
        .tv_sec  = timeout_ms / 1000,
        .tv_nsec = (long)(timeout_ms % 1000) * 1000000,
    };
    syscall(SYS_futex, (i32 *)p, FUTEX_WAIT, (i32)v, timeout_ms ? &ts : 0, 0, 0);
}
// Wake at most one thread sleeping in wait() on the futex word at p.
static void wake(uz *p)
{
    i32 *word  = (i32 *)p;
    i32  nwake = 1;
    syscall(SYS_futex, word, FUTEX_WAKE, nwake, 0, 0, 0);
}
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment