Skip to content

Instantly share code, notes, and snippets.

@kprotty
Last active April 11, 2022 03:48
Show Gist options
Save kprotty/3042436aa55620d8ebcddf2bf25668bc to your computer and use it in GitHub Desktop.
// state: [value:u30, writers_parked:u1, readers_parked:u1]:u32
// if value == 0: RwLock is unlocked
// if value == max(u30): RwLock is writer locked
// else: RwLock has ${value} readers
unlocked = 0
// Low bit: set while one or more readers are blocked in Futex.wait().
readers_parked = 1 << 0
// Next bit: set while one or more writers are blocked in Futex.wait().
writers_parked = 1 << 1
// Each reader contributes one `value` unit; the two low bits are the parked flags.
value = 1 << 2
// Selects the value portion of the state (everything above the parked bits).
mask = ~(value - 1)
reader = value
// All value bits set (max(u30) in the value field) encodes "writer locked".
writer = mask
// Futex-based reader/writer lock: readers and writers sleep on two
// separate futex words so they can be woken independently.
type RwLock:
state: u32 = 0 // readers sleep on value, and woken up on writer unlock
epoch: u32 = 0 // writers sleep on counter, and woken up via increment
// Try to acquire exclusively: succeeds only from the fully-unlocked state
// (no readers, no writer, and both parked bits clear).
try_write():
return Atomic.cas(&state, unlocked, writer, Acquire) == None
write():
// fast path
s = Atomic.cas(&state, unlocked, writer, Acquire) orelse return
// Extra bits to OR in when acquiring; becomes writers_parked after the first sleep.
acquire_with = 0
loop:
// try to acquire the RwLock even if there's other parked threads (barging).
while s & mask == unlocked:
new_s = s | writer | acquire_with
s = Atomic.cas(&state, s, new_s, Acquire) orelse return
// Make sure the parked bit is set before sleeping
if s & writers_parked == 0:
new_s = s | writers_parked
s = Atomic.cas(&state, s, new_s, Relaxed) orelse break :if
continue
// When we wake up, we must acquire with the parked bit set.
// This ensures that future write_unlock() will wake up the
// writers sleeping who didn't get to see the state change.
// Unfortunately ends up in an extra writer wake after last write_unlock()
// but this is better than having to wake all writers on write_unlock().
acquire_with = writers_parked
loop:
// Load the epoch that we'll be sleeping on
// before checking if we still need to sleep (Acquire barrier).
e = Atomic.load(&epoch, Acquire)
// Sleep only when we can't acquire the RwLock and the parked bit is set.
s = Atomic.load(&state, Relaxed)
if (s & mask == unlocked) or (s & writers_parked == 0):
break
// Returns early if write_unlock() incremented the epoch since we loaded ${e}.
Futex.wait(&epoch, e)
write_unlock():
// fast path
s = Atomic.cas(&state, writer, unlocked, Release) orelse return
// Get the parked bits, as there must be some if the cas failed above
p = s & (readers_parked | writers_parked)
assert(s & mask == writer)
assert(p != 0)
// If either readers or writers are waiting (but not both)
// try to unlock the RwLock while unsetting the waiting type parked bit.
// Failing this cas means that both types of threads are waiting
// and that the state is at its max value so we know other threads can't change it.
if p != (readers_parked | writers_parked):
// On CAS failure Some(s) is the max state value; on success we fall through
// with ${p} still naming the single waiting type to wake below.
if let Some(s) = Atomic.cas(&state, s, unlocked, Release):
assert(s == writer | readers_parked | writers_parked) // max state value
p = s & (readers_parked | writers_parked)
// Both type of threads waiting means state at max value so we can change with non-rmw.
// Choose to unset the readers_parked when unlocking instead of writers_parked.
// Due to there being a phony writers_parked after the last writer,
// choosing to wake the writers could not wake anything
// and leave the readers waiting while RwLock is unlocked until next write_unlock().
if p == (readers_parked | writers_parked):
Atomic.store(&state, writers_parked, Release)
p = readers_parked
// At this point we've unlocked the RwLock and unset whatever parked bit is in ${p}.
// We never wake both, so wake either readers here or writers below.
if p == readers_parked:
Futex.wake(&state, max(u32))
return
assert(p == writers_parked)
// Bump the epoch so writers re-checking in write() observe a change, then wake one.
Atomic.fetch_add(&epoch, 1, Release)
Futex.wake(&epoch, 1)
try_read():
// Only acquire a reader if the count is unlocked (0)
// or wouldn't overflow into a writer-lock if a reader was added (writer - 1)
// NOTE(review): s & mask is always a multiple of `value`, so the largest count
// passing this check is (writer - reader); adding a reader then yields exactly
// `writer`, which encodes "writer locked". The bound presumably should be
// (writer - reader) — confirm.
s = Atomic.load(&state, Relaxed)
while s & mask < (writer - 1):
s = Atomic.cas(&state, s, s + reader, Acquire) orelse return true
return false
read():
// fast path
s = Atomic.load(&state, Relaxed)
if s & mask < (writer - 1):
s = Atomic.cas(&state, s, s + reader, Acquire) orelse return
loop:
// Same thing as try_read()
// NOTE(review): see the overflow-bound note in try_read() — the (writer - 1)
// comparisons here look off by one for mask-aligned counts.
while s & mask < (writer - 1):
s = Atomic.cas(&state, s, s + reader, Acquire) orelse return
// Panic if the reader count would overflow instead of sleeping.
// NOTE(review): s & mask is a multiple of `value` and can never equal
// (writer - 1) exactly; presumably (writer - reader) was intended — confirm.
if s & mask == (writer - 1):
unreachable("reader count would overflow into a writer")
// Make sure the park bit is set before sleeping.
if s & readers_parked == 0:
new_s = s | readers_parked
s = Atomic.cas(&state, s, new_s, Relaxed) orelse break :if
continue
// Readers sleep on the state directly as it will only
// change here if the writer is unlocked.
Futex.wait(&state, s | readers_parked)
s = Atomic.load(&state, Relaxed)
read_unlock():
// fast path
// fetch_sub returns the previous value: there must have been at least one
// reader and the lock cannot have been writer-held.
s = Atomic.fetch_sub(&state, reader, Release)
assert(s & mask >= reader)
assert(s & mask != writer)
// If we're the last reader and there's a writer waiting, then wake it up.
// Only wake it up if we can clear the writers_parked bit atomically.
// Failing to clear it means either a reader or writer acquired the RwLock,
// in which case, we should leave wake-up to them.
if s & (mask | writers_parked) == (reader | writers_parked):
// Per the `orelse` rewrite rule, this block runs on CAS *success*:
// we cleared the parked bit, so bump the epoch and wake one writer.
_ = Atomic.cas(&state, writers_parked, unlocked, Relaxed) orelse:
Atomic.fetch_add(&epoch, 1, Release)
Futex.wake(&epoch, 1)
// Assumed futex API (Linux futex(2)-style wait/wake semantics).
assume type Futex:
// Waits until either:
// ${ptr} != ${expect}
// signaled by wake()
// spurious wakeup
wait(ptr: &u32, expect: u32):
// Wake's at most ${max_wake} waiters on the same ${ptr}.
// ${ptr} is allowed to not point to valid memory.
wake(ptr: *const u32, max_wake: u32):
// Assumed atomic-operations API.
assume type Atomic(T):
// assumes Relaxed for failure Ordering
// returns None if success
// returns Some(T) if err with updated T
cas(ptr: &T, cmp: T, xchg: T, success: Ordering): Option(T)
// fetch_add/fetch_sub return the previous value.
fetch_add(ptr: &T, value: T, ordering: Ordering): T
fetch_sub(ptr: &T, value: T, ordering: Ordering): T
load(ptr: &T, ordering: Ordering): T
store(ptr: &T, value: T, ordering: Ordering)
// Desugaring for the `orelse` construct used throughout this file:
// `$x orelse $y` evaluates the Option $x; Some(X) yields X, None runs $y.
// Hence `cas(...) orelse return` means "return when the CAS succeeds".
rewrite syntax:
"$x:Option(T) orelse $y:tt*": (
match $x:
Some(X): X
None: $y
)
// ---- Second design: pointer-queue rwlock (Windows SRWLOCK style) ----
// state: [mask:uN-3, queue_locked:u1, queued:u1, locked:u1]:usize
// if queued:
// mask points to the top Waiter of the queue
// if queue_locked:
// queued bit is set, and the thread which set this bit is updating/unparking-from the queue
// if locked:
// the RwLock is owned by a writer if mask == 0
// the RwLock is owned by a reader if mask > 0
unlocked = 0
locked = 1 << 0
queued = 1 << 1
queue_locked = 1 << 2
// With no queue, the mask bits hold the reader count in `reader` units;
// once queued, the mask is reused as a *Waiter pointer (hence Waiter's @align(8)).
reader = 1 << 3
mask = ~(reader - 1)
// One-shot binary event built on a futex word.
// state: 0 = initial, 1 = a thread is (about to be) waiting, 2 = set.
type Event:
state: u32 = 0
wait():
// 0 -> 1: set() hasn't fired yet, so sleep until state leaves 1.
// If SWAP returns 2, the event was already set and we return immediately.
if SWAP(&state, 1, Acquire) == 0:
s = 1
while s == 1:
WAIT(&state, 1)
s = LOAD(&state, Acquire)
set():
// 1 -> 2 means a waiter is (or will be) sleeping: wake it.
if SWAP(&state, 2, Release) == 1:
WAKE(&state, 1)
// Queue node allocated on the waiting thread's stack.
// @align(8) keeps the low 3 bits of its address free so a *Waiter
// fits into the state's mask alongside the three flag bits.
@align(8)
type Waiter:
prev: ?*Waiter // filled in lazily while walking the queue (get_and_link_queue)
next: ?*Waiter // link toward the tail; set when this node is pushed as new head
tail: ?*Waiter // cached at the head once discovered, for O(1) tail lookup
is_writer: bool
readers: usize // at a parked writer tail: reader count moved out of the state mask
event: Event // the waiting thread sleeps on this until unparked
// Word-sized rwlock whose state doubles as an intrusive waiter-queue pointer.
type SRWLock:
state: usize = unlocked
try_write():
// Exclusive acquire only from the fully-unlocked state (single CAS).
return CAS(&state, unlocked, locked, Acquire) == null
write():
// fast path
s = CAS(&state, unlocked, locked, Acquire) orelse return
loop:
// similar to try_write()
// (barging: grabs the locked bit even if other threads are queued)
while s & locked == 0:
s = CAS(&state, s, s | locked, Acquire) orelse return
// Couldn't acquire: enqueue ourselves and sleep; park() returns a fresh state.
s = park(s, is_writer: true)
unlock_write():
// fast path: no queue, just clear the locked bit.
s = CAS(&state, locked, unlocked, Release) orelse return
// CAS failure means the queued bit is set: release and wake a waiter.
unlock_and_unpark(s)
try_read():
// Can't acquire reader if queued, as mask is now the head *Waiter
// Can't acquire reader if write-locked (locked bit with zero mask).
// Returns false on reader count overflow since try_read().
s = LOAD(&state, Relaxed)
while (s & queued == 0) and (s & (locked | mask) != locked):
// `add` is a checked add: reader-count overflow fails the try.
new_s = add(s, reader) catch return false
s = CAS(&state, s, new_s | locked, Acquire) orelse return true
return false
read():
// fast path
// NOTE(review): this acquires with only `locked` set (mask == 0), which the
// state encoding above describes as writer-owned; try_read()/unlock_read()
// encode a single reader as `reader | locked` — presumably this fast path
// should CAS to reader | locked as well. Confirm.
s = CAS(&state, unlocked, locked, Acquire) orelse return
loop:
// same as try_read(), but panics on reader count overflow
while (s & queued == 0) and (s & (locked | mask) != locked):
new_s = add(s, reader) catch panic("reader count overflowed")
s = CAS(&state, s, new_s | locked, Acquire) orelse return
s = park(s, is_writer: false)
unlock_read():
// fast path
s = CAS(&state, reader|locked, unlocked, Release) orelse return
// If there's no queue, the mask contains the reader count
while s & queued == 0:
assert(s & locked != 0)
assert(s & mask >= reader)
// Remove one reader, and the last reader unsets the locked bit for writers.
new_s = s - reader
if new_s & mask == 0:
new_s &= ~locked
s = CAS(&state, s, new_s, Release) orelse return
// There's queued threads now, so the reader count was moved to the tail.
// Acquire barrier to ensure that the head writes published in park() happen before we access the head.
assert(s & locked != 0)
assert(s & queued != 0)
FENCE(Acquire)
// Find the tail (note: we aren't the queue_locked holder)
head = *Waiter(s & mask)
tail = loop:
if head.tail |t| break t
head = head.next orelse unreachable
// Remove the reader from the tail
assert(tail.is_writer)
// SUB returns the previous value (fetch_sub semantics).
readers = SUB(&tail.readers, reader, Release)
// Other readers exit reader
assert(readers >= reader)
if readers > reader:
return
// Last reader removed the locked bit and unparks one thread (the writer).
s = LOAD(&state, Relaxed)
unlock_and_unpark(s)
unlock_and_unpark(s):
loop:
assert(s & locked != 0)
assert(s & queued != 0)
// Unset the locked bit and try to grab the queue_locked bit
// since we know that there's a queue there to wake from.
new_s = s & ~locked
new_s |= queue_locked
// The `orelse` block runs on CAS success: if the old state had queue_locked
// clear, we are the new queue_locked holder and must do the unpark; otherwise
// the existing holder will notice locked dropped and unpark for us.
// On CAS failure, retry with the updated state.
s = CAS(&state, s, new_s, Release) orelse:
if s & queue_locked == 0:
unpark(new_s)
return
park(s, is_writer):
// Stack-allocated queue node; we sleep on w.event until unpark() sets it.
w = Waiter{}
w.prev = null
w.is_writer = is_writer
// Prepare our waiter as the new head of the queue.
//
// If we're the first, then set the tail to ourselves so get_and_link_queue() will eventually find it.
// Also, if we're a writer, we should also note down the current readers since we're repurposing the mask.
//
// If we're not the first, we know the mask points to the previous *Waiter head.
// Also, we should try to acquire the queued_locked bit in order to find the tail and cache it at the head.
new_s = (s & ~mask) | usize(&w) | queued
if s & queued == 0:
switch (is_writer):
true: w.readers = s & mask
// NOTE(review): `state` here presumably means the snapshot `s` — a reader
// only parks while a writer holds the lock, i.e. s == locked. Confirm.
false: assert(state == locked)
w.tail = &w
w.next = null
else:
w.next = ?*Waiter(s & mask)
new_s |= queue_locked
// Release barrier ensures those which observe the head *Waiter see our w.* writes above.
// On CAS failure we return the updated state to the caller's retry loop without sleeping.
if CAS(&state, s, new_s, Release) |updated|:
return updated
// We grabbed the queue_locked bit, so update the queue and release the queue_locked bit.
if s & (queued | queue_locked) == queued:
link_queue_or_unpark(new_s)
w.event.wait()
return LOAD(&state, Relaxed)
link_queue_or_unpark(s):
loop:
// If the state became unlocked during this, we must do a wakeup in the locked holder's place.
assert(s & queue_locked != 0)
if s & locked == 0:
return unpark(s)
// Update the queue and cache the tail at the head (Acquire barrier)
assert(s & queued != 0)
_ = get_and_link_queue(s)
// Release the queued_locked bit.
// Release barrier ensures the writes in get_and_link_queue() happen before the next queue_locked holder.
new_s = s & ~queue_locked
// CAS failure reloads s and retries (the lock may have been released meanwhile).
s = CAS(&state, s, new_s, Release) orelse return
// Precondition: the caller holds the queue_locked bit and the queue is non-empty.
unpark(s):
tail = loop:
// Get the head and tail, caching the tail at the head for future lookups.
assert(s & queue_locked != 0)
assert(s & queued != 0)
head, tail = get_and_link_queue(s)
// If the lock is held, then release the queued_locked bit
// for the lock holder to end up doing the unpark() instead.
// Release barrier ensures the updates in get_and_link_queue() happen-before the next queue_locked holder
if s & locked != 0:
s = CAS(&state, s, s & ~queue_locked, Release) orelse return
continue
// Find the last in [head, last, tail]:queue for which we wake Waiters last to tail.
// Wanted to experiment with different last selection policies, but this is the one used by windows SRWLOCK.
wake_policy = .first_writer_or_all
last = select_last_waiter(head, tail, wake_policy)
// If there's a Waiter "after" the one we selected to wake-until,
// then we can simply update the head's tail to point to that and release the queue_locked bit.
if last.prev |new_tail|:
head.tail = new_tail
AND(&state, ~queue_locked, Release)
// Detach the to-be-woken sublist [last..tail] from the remaining queue.
last.prev = null
break tail
// If there's no Waiter "after" the one we selected to wake-until,
// then we're waking all of them so we need to clear the mask and the queued bits while releasing the queue_locked bit.
// Release barrier ensures the updates we did above happen-before the next queue_locked holder.
new_s = s & ~(mask | queued | queue_locked)
s = CAS(&state, s, new_s, Release) orelse break tail
// Wake from the tail backwards until the prev chain ends (cut at `last` above).
// NOTE(review): Event declares set(), not wake() — presumably event.set() is meant.
loop:
prev = tail.prev
tail.event.wake()
tail = prev orelse break
// Choose how far back from the tail to wake; waking proceeds tail -> last.
select_last_waiter(head, tail, wake_policy):
switch (wake_policy):
// This is the policy used by windows SRWLock:
// If the tail is a writer, it wakes only that one.
// If not, it wakes everything with the optimistic assumption that they're all readers.
.first_writer_or_all:
if tail.is_writer: return tail
return head
// This is the policy used by parking_lot RwLock:
// Wakes all readers + the first writer it finds while walking the queue.
// (walks tail -> head via the .prev links set up in get_and_link_queue)
.all_readers_and_first_writer:
last = tail
while not last.is_writer:
last = last.prev orelse break
return last
// Resolve the queue's head and tail from the state, back-filling .prev links.
// Caller must hold queue_locked so the walk can't race other link updates.
get_and_link_queue(s):
// Acquire barrier to ensure the head writes in park() happen-before we access it here.
assert(s & queue_locked != 0)
assert(s & queued != 0)
FENCE(Acquire)
// Find the head and tail, checking if the tail is cached at the head first
head = *Waiter(s & mask)
tail = head.tail orelse blk:
// Look for the tail, setting the .prev links in the process
// (the first-queued Waiter set tail = &itself in park(), ending the walk).
current = head
loop:
next = current.next orelse unreachable
next.prev = current
current = next
// Found the tail, cache it at the head
if current.tail |tail|:
head.tail = tail
break :blk tail
return head, tail
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment