Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Quick and dirty implementation of Lamport's bakery algorithm in Rust, as an example. Untested.
//! Implementation of [Lamport's bakery algorithm][bakery]. This is somewhat
//! interesting, because it's a mutex which can be implemented even if the
//! target only has atomic load/store, and no CAS (in principle it can be even
//! more general than this).
//!
//! [bakery]: https://en.wikipedia.org/wiki/Lamport%27s_bakery_algorithm
//!
//! Major caveat: This is not tested, and this algorithm is no longer
//! appropriate for modern code. Some variations of it can be useful in a
//! thread pool, but it's mostly useful for understanding the concepts.
//!
//! ## Notes:
//!
//! Debug asserts are there for easy-to-check things, but do not (and cannot)
//! detect all possible misuse.
//!
//! Thread index could be stored in a thread local, but it's probably better to
//! just pass it into the thread on construction. In practice, for modern
//! applications this is quite inconvenient, and so this algorithm is hard to
//! use, even in situations where it could be appropriate.
//!
//! P.S. Reducing from SeqCst is left as an exercise for the reader.
use core::sync::atomic::{*, Ordering::*};
/// 1-based index identifying a participating thread. For a
/// [`BakeryMutex<MAX_THREADS>`](BakeryMutex), valid values are
/// `1..=MAX_THREADS`, and each participating thread must use its own.
pub type ThreadIndex = core::num::NonZeroUsize;

/// A spin lock implementing Lamport's bakery algorithm for up to
/// `MAX_THREADS` threads, built only from atomic loads and stores.
///
/// The mutex guards no data itself: callers bracket their own critical
/// section with [`BakeryMutex::lock`] and [`BakeryMutex::unlock`].
#[derive(Debug)]
pub struct BakeryMutex<const MAX_THREADS: usize> {
    /// `entering[i]` is `true` while thread `i + 1` is choosing its ticket.
    entering: [AtomicBool; MAX_THREADS],
    /// `threads[i]` is thread `i + 1`'s ticket, or `0` when that thread is
    /// neither waiting for nor holding the lock.
    threads: [AtomicUsize; MAX_THREADS],
}
impl<const MAX_THREADS: usize> BakeryMutex<MAX_THREADS> {
    // The per-slot initializers are `const` items rather than plain
    // expressions because `[expr; N]` needs either a `Copy` operand or a
    // constant item, and atomics are not `Copy`.
    const INIT: Self = {
        const FALSE: AtomicBool = AtomicBool::new(false);
        const ZERO: AtomicUsize = AtomicUsize::new(0);
        Self {
            entering: [FALSE; MAX_THREADS],
            threads: [ZERO; MAX_THREADS],
        }
    };

    /// Creates an unlocked mutex: no thread is choosing a ticket and every
    /// ticket is `0`.
    #[inline]
    pub const fn new() -> Self {
        Self::INIT
    }

    /// Converts a 1-based [`ThreadIndex`] into the 0-based slot it owns in
    /// `entering` / `threads`.
    ///
    /// The range check is only a debug assertion. In release builds an
    /// out-of-range index panics on the array access that follows instead.
    #[inline]
    fn slot(thread_index: ThreadIndex) -> usize {
        debug_assert!(
            thread_index.get() <= MAX_THREADS,
            "out of range 1..={MAX_THREADS:?}: {thread_index:?}",
        );
        thread_index.get() - 1
    }

    /// Acquires the lock on behalf of `thread_index`, spinning until it is
    /// this thread's turn.
    ///
    /// # Panics
    ///
    /// Panics if `MAX_THREADS == 0`, if `thread_index > MAX_THREADS`, or if
    /// the next ticket number would overflow `usize`.
    ///
    /// # Safety
    ///
    /// `thread_index` must be in range `1..=MAX_THREADS` and unique to the
    /// calling thread. The calling thread must not already hold the lock.
    pub unsafe fn lock(&self, thread_index: ThreadIndex) {
        assert!(MAX_THREADS != 0);
        let me = Self::slot(thread_index);
        debug_assert!(!self.entering[me].load(Relaxed));
        // A thread that neither waits for nor holds the lock has ticket 0.
        debug_assert_eq!(self.threads[me].load(Relaxed), 0);

        // Doorway: announce that we are choosing, take a ticket one larger
        // than any currently visible ticket, publish it, then stop choosing.
        self.entering[me].store(true, SeqCst);
        // Note: Panicking here will deadlock everybody (our `entering` flag
        // stays set), which is probably not desirable. Overflow is not
        // allowed, though. In practice, it's unlikely for this to get above
        // MAX_THREADS by much, so this is probably a theoretical concern.
        let ticket = self
            .threads
            .iter()
            .map(|t| t.load(SeqCst))
            .max()
            .unwrap_or_default()
            .checked_add(1)
            .expect("bakery ticket overflowed usize");
        self.threads[me].store(ticket, SeqCst);
        self.entering[me].store(false, SeqCst);

        for (other, (entering, other_ticket)) in
            self.entering.iter().zip(&self.threads).enumerate()
        {
            // Wait until `other` has finished choosing, so that its ticket
            // (if any) is visible before we compare against it.
            while entering.load(SeqCst) {
                core::hint::spin_loop();
            }
            // Wait for our turn: spin while `other` holds a ticket ordered
            // before ours. Equal tickets are possible, so ties are broken in
            // favor of the lower index. The ticket is loaded once per check
            // so the zero-test and the ordering test see the same value. Our
            // own slot compares equal and therefore never waits.
            loop {
                let theirs = other_ticket.load(SeqCst);
                if theirs == 0 || (theirs, other) >= (ticket, me) {
                    break;
                }
                core::hint::spin_loop();
            }
        }
    }

    /// Releases the lock held by `thread_index` by resetting its ticket to
    /// `0`.
    ///
    /// # Safety
    ///
    /// `thread_index` must be in range `1..=MAX_THREADS` and unique to the
    /// calling thread, and the calling thread must currently hold the lock.
    pub unsafe fn unlock(&self, thread_index: ThreadIndex) {
        let me = Self::slot(thread_index);
        debug_assert!(!self.entering[me].load(Relaxed));
        // A lock holder always has a nonzero ticket.
        debug_assert_ne!(self.threads[me].load(Relaxed), 0);
        self.threads[me].store(0, SeqCst);
    }
}

impl<const MAX_THREADS: usize> Default for BakeryMutex<MAX_THREADS> {
    /// Equivalent to [`BakeryMutex::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment