mod futex {
    use std::{sync::atomic::AtomicU32, time::Duration};

    /// Blocks the calling thread while `*_ptr == _cmp`, until woken or until
    /// the timeout expires. Expected to return `false` only on timeout.
    pub fn wait(_ptr: &AtomicU32, _cmp: u32, _timeout: Option<Duration>) -> bool {
        unimplemented!("TODO")
    }

    /// Wakes up to `_max_wake` threads blocked in `wait` on `_ptr`.
    pub fn wake(_ptr: *const AtomicU32, _max_wake: u32) {
        unimplemented!("TODO")
    }
}
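
// A minimal sketch of how the hooks above could be implemented on Linux via
// the raw futex syscall, assuming the `libc` crate as a dependency. This
// `futex_linux` module is illustrative, not part of the original gist; other
// platforms would use e.g. WaitOnAddress (Windows) or __ulock_wait (macOS).
#[cfg(target_os = "linux")]
mod futex_linux {
    use std::{sync::atomic::AtomicU32, time::Duration};

    pub fn wait(atom: &AtomicU32, cmp: u32, timeout: Option<Duration>) -> bool {
        // FUTEX_WAIT takes a *relative* timeout; None means wait forever.
        let ts = timeout.map(|t| libc::timespec {
            tv_sec: t.as_secs() as _,
            tv_nsec: t.subsec_nanos() as _,
        });
        let ts_ptr = match &ts {
            Some(ts) => ts as *const libc::timespec,
            None => std::ptr::null(),
        };
        let rc = unsafe {
            libc::syscall(
                libc::SYS_futex,
                atom as *const AtomicU32,
                libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG,
                cmp,
                ts_ptr,
            )
        };
        // Report false only when the kernel timed the wait out; EAGAIN (the
        // value already changed) and EINTR both count as being woken.
        !(rc == -1 && std::io::Error::last_os_error().raw_os_error() == Some(libc::ETIMEDOUT))
    }

    pub fn wake(atom: *const AtomicU32, max_wake: u32) {
        unsafe {
            libc::syscall(
                libc::SYS_futex,
                atom,
                libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG,
                max_wake,
            );
        }
    }
}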
mod spin {
    /// Bounded spinning helper used by the lock slow paths below.
    #[derive(Default)]
    pub struct SpinWait {
        counter: usize,
    }

    impl SpinWait {
        pub fn reset(&mut self) {
            self.counter = 0;
        }

        /// Spins once and returns `true`, up to a budget of 100 iterations;
        /// returns `false` once the budget is exhausted.
        pub fn yield_now(&mut self) -> bool {
            self.counter < 100 && {
                std::hint::spin_loop();
                self.counter += 1;
                true
            }
        }

        /// Unconditional exponential backoff, capped at 2^4 = 16 spins.
        pub fn force_yield(&mut self) {
            self.counter = (self.counter + 1).min(4);
            for _ in 0..(1 << self.counter) {
                std::hint::spin_loop();
            }
        }
    }
}
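
// Illustrative usage (not from the gist): SpinWait drives the spin phase of
// an acquire loop, and the caller falls back to parking once yield_now()
// starts returning false. The bare flag here is a stand-in for a real lock.
#[cfg(test)]
mod spin_example {
    use super::spin::SpinWait;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Try to set `flag`, spinning within SpinWait's budget; a real mutex
    // would park on a futex instead of returning false.
    fn acquire_or_give_up(flag: &AtomicBool) -> bool {
        let mut spin = SpinWait::default();
        loop {
            if flag
                .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                return true;
            }
            if !spin.yield_now() {
                return false;
            }
        }
    }

    #[test]
    fn uncontended_acquire_succeeds() {
        let flag = AtomicBool::new(false);
        assert!(acquire_or_give_up(&flag));
    }
}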
pub mod mutex {
    use super::{futex, spin::SpinWait};
    use std::sync::atomic::{AtomicU32, Ordering};

    const UNLOCKED: u32 = 0;
    const LOCKED: u32 = 1;
    const CONTENDED: u32 = 2; // locked, with a waiter possibly parked

    pub struct RawMutex {
        state: AtomicU32,
    }

    impl RawMutex {
        pub const fn new() -> Self {
            Self {
                state: AtomicU32::new(UNLOCKED),
            }
        }

        #[inline]
        pub fn try_lock(&self) -> bool {
            self.state
                .compare_exchange(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        }

        #[inline]
        pub fn lock(&self) {
            if let Err(state) = self.state.compare_exchange_weak(
                UNLOCKED,
                LOCKED,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                self.lock_slow(state);
            }
        }

        #[cold]
        fn lock_slow(&self, mut state: u32) {
            // Spin briefly while the lock looks merely LOCKED, retrying the
            // acquire whenever it appears UNLOCKED.
            let mut spin = SpinWait::default();
            loop {
                if state == UNLOCKED {
                    match self.state.compare_exchange(
                        state,
                        LOCKED,
                        Ordering::Acquire,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => return,
                        Err(e) => state = e,
                    }
                }
                if state == LOCKED && spin.yield_now() {
                    state = self.state.load(Ordering::Relaxed);
                    continue;
                }
                break;
            }
            // Park on the futex, acquiring by swapping in CONTENDED so the
            // eventual unlock knows it must wake a waiter.
            if state == CONTENDED {
                futex::wait(&self.state, CONTENDED, None);
            }
            while self.state.swap(CONTENDED, Ordering::Acquire) != UNLOCKED {
                futex::wait(&self.state, CONTENDED, None);
            }
        }

        /// Safety: must only be called by the thread holding the lock.
        #[inline]
        pub unsafe fn unlock(&self) {
            if self.state.swap(UNLOCKED, Ordering::Release) == CONTENDED {
                self.unlock_slow();
            }
        }

        #[cold]
        fn unlock_slow(&self) {
            futex::wake(&self.state, 1);
        }
    }
}
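
// Illustrative only (not part of the gist): RawMutex wrapped into a typed
// Mutex<T> with an RAII guard, mirroring how std and parking_lot layer a
// typed API over a raw lock. The names Mutex/MutexGuard are hypothetical.
pub mod mutex_example {
    use super::mutex::RawMutex;
    use std::cell::UnsafeCell;
    use std::ops::{Deref, DerefMut};

    pub struct Mutex<T> {
        raw: RawMutex,
        value: UnsafeCell<T>,
    }

    // SAFETY: RawMutex serializes all access to `value`.
    unsafe impl<T: Send> Send for Mutex<T> {}
    unsafe impl<T: Send> Sync for Mutex<T> {}

    pub struct MutexGuard<'a, T> {
        lock: &'a Mutex<T>,
    }

    impl<T> Mutex<T> {
        pub const fn new(value: T) -> Self {
            Self {
                raw: RawMutex::new(),
                value: UnsafeCell::new(value),
            }
        }

        pub fn lock(&self) -> MutexGuard<'_, T> {
            self.raw.lock();
            MutexGuard { lock: self }
        }
    }

    impl<T> Deref for MutexGuard<'_, T> {
        type Target = T;
        fn deref(&self) -> &T {
            // SAFETY: the guard's existence proves the lock is held.
            unsafe { &*self.lock.value.get() }
        }
    }

    impl<T> DerefMut for MutexGuard<'_, T> {
        fn deref_mut(&mut self) -> &mut T {
            // SAFETY: as above, plus &mut self gives unique access.
            unsafe { &mut *self.lock.value.get() }
        }
    }

    impl<T> Drop for MutexGuard<'_, T> {
        fn drop(&mut self) {
            // SAFETY: the guard acquired the lock in `lock()`.
            unsafe { self.lock.raw.unlock() };
        }
    }
}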
pub mod condvar {
    use super::{futex, mutex::RawMutex};
    use std::{
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };

    // `state` packs two u16 counters: waiters in the low half and pending
    // signals in the high half. `epoch` is the futex word waiters sleep on.
    const WAITER_SHIFT: u32 = 0;
    const SIGNAL_SHIFT: u32 = 16;

    pub struct RawCondvar {
        state: AtomicU32,
        epoch: AtomicU32,
    }

    impl RawCondvar {
        pub const fn new() -> Self {
            Self {
                state: AtomicU32::new(0),
                epoch: AtomicU32::new(0),
            }
        }

        /// Returns `false` if the wait timed out. The caller must hold
        /// `raw_mutex`, which is reacquired before returning.
        pub unsafe fn wait(&self, raw_mutex: &RawMutex, timeout: Option<Duration>) -> bool {
            // Register as a waiter and sample the epoch before unlocking, so
            // a notify issued between the unlock and the sleep bumps the
            // epoch and the futex wait returns immediately instead of
            // missing the wakeup.
            let state = self.state.fetch_add(1 << WAITER_SHIFT, Ordering::Relaxed);
            assert_ne!((state >> WAITER_SHIFT) as u16, u16::MAX, "too many waiters");

            let epoch = self.epoch.load(Ordering::Relaxed);
            raw_mutex.unlock();
            let wait_result = futex::wait(&self.epoch, epoch, timeout);
            raw_mutex.lock();

            // Deregister as a waiter, consuming one pending signal if any.
            let _ = self
                .state
                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |state| {
                    let mut waiters = (state >> WAITER_SHIFT) as u16;
                    waiters = waiters.checked_sub(1).expect("invalid waiter count");
                    let mut signals = (state >> SIGNAL_SHIFT) as u16;
                    signals = signals.saturating_sub(1);
                    let mut new_state = (waiters as u32) << WAITER_SHIFT;
                    new_state |= (signals as u32) << SIGNAL_SHIFT;
                    Some(new_state)
                });
            wait_result
        }

        pub fn notify_one(&self) {
            self.wake(1)
        }

        pub fn notify_all(&self) {
            self.wake(u16::MAX)
        }

        fn wake(&self, max_wake: u16) {
            // Record up to `max_wake` signals, but only if there are
            // unsignaled waiters; otherwise the notify is a no-op and the
            // fetch_update bails out with None.
            let _ = self
                .state
                .fetch_update(Ordering::Release, Ordering::Relaxed, |state| {
                    let waiters = (state >> WAITER_SHIFT) as u16;
                    let signals = (state >> SIGNAL_SHIFT) as u16;
                    match (waiters - signals).min(max_wake) {
                        0 => None,
                        to_wake => Some(state + ((to_wake as u32) << SIGNAL_SHIFT)),
                    }
                })
                .map(|_| {
                    self.epoch.fetch_add(1, Ordering::Release);
                    futex::wake(&self.epoch, max_wake as u32);
                });
        }
    }
}
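
// Illustrative only (not from the gist): the intended RawCondvar wait
// pattern. As with any condition variable, the predicate must be rechecked
// in a loop under the mutex, since wakeups may be spurious or stolen.
mod condvar_example {
    use super::{condvar::RawCondvar, mutex::RawMutex};

    struct Channel {
        lock: RawMutex,
        not_empty: RawCondvar,
        // imagine an UnsafeCell<VecDeque<T>> guarded by `lock` here
    }

    // SAFETY contract: the caller must not already hold `chan.lock`.
    unsafe fn wait_until(chan: &Channel, has_data: impl Fn() -> bool) {
        chan.lock.lock();
        while !has_data() {
            // Releases the mutex while sleeping and reacquires it after;
            // the epoch protocol above keeps a concurrent notify from
            // being lost between the unlock and the sleep.
            chan.not_empty.wait(&chan.lock, None);
        }
        // ...consume the data here, then release the lock.
        chan.lock.unlock();
    }
}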
pub mod once {
    use super::{futex, spin::SpinWait};
    use std::sync::atomic::{AtomicU32, Ordering};

    const UNINIT: u32 = 0;
    const CALLING: u32 = 1; // a thread is running the init closure
    const WAITING: u32 = 2; // ...and at least one other thread is parked
    const CALLED: u32 = 3;

    pub struct RawOnce {
        state: AtomicU32,
    }

    impl RawOnce {
        pub const fn new() -> Self {
            Self {
                state: AtomicU32::new(UNINIT),
            }
        }

        #[inline]
        pub fn call_once(&self, f: impl FnOnce()) {
            match self.state.load(Ordering::Acquire) {
                CALLED => {}
                state => self.call_once_slow(state, f),
            }
        }

        #[cold]
        fn call_once_slow(&self, mut state: u32, f: impl FnOnce()) {
            // Race to claim the init; the winner runs the closure, then
            // wakes any losers that parked while it was running.
            if state == UNINIT {
                match self.state.compare_exchange(
                    UNINIT,
                    CALLING,
                    Ordering::Acquire,
                    Ordering::Acquire,
                ) {
                    Err(e) => state = e,
                    Ok(_) => {
                        f();
                        return match self.state.swap(CALLED, Ordering::Release) {
                            WAITING => futex::wake(&self.state, u32::MAX),
                            _ => {}
                        };
                    }
                }
            }

            // Losers spin briefly, then advertise themselves by upgrading
            // CALLING to WAITING and park until the state becomes CALLED.
            let mut spin = SpinWait::default();
            while state == CALLING && spin.yield_now() {
                state = self.state.load(Ordering::Acquire);
            }
            if state == CALLING {
                state = match self.state.compare_exchange(
                    CALLING,
                    WAITING,
                    Ordering::Acquire,
                    Ordering::Acquire,
                ) {
                    Ok(_) => WAITING,
                    Err(e) => e,
                };
            }
            while state == WAITING {
                futex::wait(&self.state, WAITING, None);
                state = self.state.load(Ordering::Acquire);
            }
        }
    }
}
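
// Illustrative only (not from the gist): RawOnce guarding one-shot
// initialization of a global, mirroring std::sync::Once usage. The Release
// swap to CALLED in call_once synchronizes with the Acquire fast-path load,
// so the Relaxed read below always observes the initialized value.
mod once_example {
    use super::once::RawOnce;
    use std::sync::atomic::{AtomicU32, Ordering};

    static INIT: RawOnce = RawOnce::new();
    static VALUE: AtomicU32 = AtomicU32::new(0);

    fn get() -> u32 {
        // Racing callers block until the winning closure completes.
        INIT.call_once(|| VALUE.store(compute(), Ordering::Relaxed));
        VALUE.load(Ordering::Relaxed)
    }

    fn compute() -> u32 {
        42 // stand-in for some expensive, run-once computation
    }
}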
pub mod rwlock {
    use super::{futex, spin::SpinWait};
    use std::sync::atomic::{AtomicU32, Ordering};

    // The two low bits are park flags; the remaining bits count readers in
    // units of VALUE, with the all-ones pattern (MASK) meaning write-locked.
    const UNLOCKED: u32 = 0;
    const READERS_PARKED: u32 = 1;
    const WRITERS_PARKED: u32 = 2;
    const VALUE: u32 = 4;
    const MASK: u32 = !(VALUE - 1);
    const READER: u32 = VALUE;
    const WRITER: u32 = MASK;

    pub struct RawRwLock {
        state: AtomicU32,
        epoch: AtomicU32, // futex word parked writers sleep on
    }

    impl RawRwLock {
        pub const fn new() -> Self {
            Self {
                state: AtomicU32::new(UNLOCKED),
                epoch: AtomicU32::new(0),
            }
        }

        #[inline]
        pub fn try_write(&self) -> bool {
            self.state
                .compare_exchange(UNLOCKED, WRITER, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        }

        #[inline]
        pub fn write(&self) {
            if let Err(state) = self.state.compare_exchange_weak(
                UNLOCKED,
                WRITER,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                self.write_slow(state);
            }
        }

        #[cold]
        fn write_slow(&self, mut state: u32) {
            let mut did_wait = false;
            let mut spin = SpinWait::default();
            loop {
                // Claim the lock whenever it looks unheld, preserving the
                // WRITERS_PARKED flag if we previously slept (other writers
                // may still be parked behind us).
                while state & MASK == UNLOCKED {
                    let mut new_state = state | WRITER;
                    if did_wait {
                        new_state |= WRITERS_PARKED;
                    }
                    match self.state.compare_exchange_weak(
                        state,
                        new_state,
                        Ordering::Acquire,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => return,
                        Err(e) => state = e,
                    }
                }

                // Spin a little, then set WRITERS_PARKED before sleeping.
                if state & WRITERS_PARKED == 0 {
                    if spin.yield_now() {
                        state = self.state.load(Ordering::Relaxed);
                        continue;
                    }
                    if let Err(e) = self.state.compare_exchange_weak(
                        state,
                        state | WRITERS_PARKED,
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        state = e;
                        continue;
                    }
                }

                // Sleep on the epoch until an unlock bumps it, rechecking
                // the lock state after every wakeup.
                loop {
                    let epoch = self.epoch.load(Ordering::Acquire);
                    state = self.state.load(Ordering::Relaxed);
                    if (state & MASK == UNLOCKED) || (state & WRITERS_PARKED == 0) {
                        spin.reset();
                        break;
                    }
                    futex::wait(&self.epoch, epoch, None);
                    did_wait = true;
                }
            }
        }

        #[inline]
        pub unsafe fn unlock_write(&self) {
            let state = self.state.swap(UNLOCKED, Ordering::Release);
            debug_assert_eq!(state & MASK, WRITER);
            if state & (READERS_PARKED | WRITERS_PARKED) != 0 {
                self.unlock_write_slow(state);
            }
        }

        #[cold]
        fn unlock_write_slow(&self, state: u32) {
            // Wake all parked readers, and hand off to one parked writer.
            if state & READERS_PARKED != 0 {
                futex::wake(&self.state, u32::MAX);
            }
            if state & WRITERS_PARKED != 0 {
                self.epoch.fetch_add(1, Ordering::Release);
                futex::wake(&self.epoch, 1);
            }
        }

        #[inline]
        pub fn try_read(&self) -> bool {
            match self.state.compare_exchange_weak(
                UNLOCKED,
                READER,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => true,
                Err(state) => self.try_read_slow(state),
            }
        }

        #[cold]
        fn try_read_slow(&self, mut state: u32) -> bool {
            // Keep bumping the reader count as long as no writer holds the lock.
            while state & MASK < (WRITER - 1) {
                match self.state.compare_exchange_weak(
                    state,
                    state + READER,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return true,
                    Err(e) => state = e,
                }
            }
            false
        }

        #[inline]
        pub fn read(&self) {
            if let Err(state) = self.state.compare_exchange_weak(
                UNLOCKED,
                READER,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                self.read_slow(state);
            }
        }

        #[cold]
        fn read_slow(&self, mut state: u32) {
            let mut spin = SpinWait::default();
            loop {
                // Acquire only while no writer holds the lock and no writer
                // is parked; deferring to parked writers avoids starving them.
                let mut backoff = SpinWait::default();
                while (state & MASK < WRITER) && (state & WRITERS_PARKED == 0) {
                    let new_state = state + READER;
                    if new_state & MASK == WRITER {
                        unreachable!("too many readers");
                    }
                    if let Ok(_) = self.state.compare_exchange_weak(
                        state,
                        new_state,
                        Ordering::Acquire,
                        Ordering::Relaxed,
                    ) {
                        return;
                    }
                    backoff.force_yield();
                    state = self.state.load(Ordering::Relaxed);
                }

                // Spin a little, then set READERS_PARKED before sleeping.
                if state & READERS_PARKED == 0 {
                    if spin.yield_now() {
                        state = self.state.load(Ordering::Relaxed);
                        continue;
                    }
                    if let Err(e) = self.state.compare_exchange_weak(
                        state,
                        state | READERS_PARKED,
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        state = e;
                        continue;
                    }
                }

                futex::wait(&self.state, state | READERS_PARKED, None);
                state = self.state.load(Ordering::Relaxed);
                spin.reset();
            }
        }
        #[inline]
        pub unsafe fn unlock_read(&self) {
            let state = self.state.fetch_sub(READER, Ordering::Release);
            debug_assert_eq!(state & READERS_PARKED, 0);
            debug_assert_ne!(state & MASK, UNLOCKED);
            debug_assert_ne!(state & MASK, WRITER);
            // The last reader out is responsible for waking a parked writer.
            if state & (MASK | WRITERS_PARKED) == (READER | WRITERS_PARKED) {
                self.unlock_read_slow(state);
            }
        }

        #[cold]
        fn unlock_read_slow(&self, mut state: u32) {
            state -= READER;
            // Consume the WRITERS_PARKED flag, then hand off to one writer.
            while state & (MASK | WRITERS_PARKED) == (UNLOCKED | WRITERS_PARKED) {
                if let Err(e) = self.state.compare_exchange_weak(
                    state,
                    state & !WRITERS_PARKED,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    state = e;
                    continue;
                }
                self.epoch.fetch_add(1, Ordering::Release);
                futex::wake(&self.epoch, 1);
                return;
            }
        }
    }
}
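
// Illustrative only (not from the gist): the raw lock/unlock pairing for
// RawRwLock. A typed wrapper would enforce the pairing with RAII guards,
// the same way the Mutex<T> sketch above does for RawMutex.
mod rwlock_example {
    use super::rwlock::RawRwLock;
    use std::cell::UnsafeCell;

    fn demo(lock: &RawRwLock, shared: &UnsafeCell<u32>) {
        lock.write();
        // SAFETY: write() grants exclusive access until unlock_write().
        unsafe {
            *shared.get() += 1;
            lock.unlock_write();
        }

        lock.read();
        // SAFETY: read() grants shared access; other readers may coexist.
        unsafe {
            let _snapshot = *shared.get();
            lock.unlock_read();
        }
    }
}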