Skip to content

Instantly share code, notes, and snippets.

@shinkwhek
Created October 4, 2021 14:10
Show Gist options
  • Save shinkwhek/8f6b796716ac80bd59351eca17893c9f to your computer and use it in GitHub Desktop.
Save shinkwhek/8f6b796716ac80bd59351eca17893c9f to your computer and use it in GitHub Desktop.
// https://github.com/oreilly-japan/conc_ytakano/tree/main/chap3/3.9
use std::ptr::{read_volatile, write_volatile};
use std::sync::atomic::{fence, Ordering};
use std::thread;
/// Number of worker threads; each one owns one slot in the bakery lock's arrays.
const NUM_THREADS: usize = 4;
/// Number of lock-and-increment iterations each worker thread performs.
const NUM_LOOP: usize = 100_000;
/// Volatile load through `$addr` (a shared reference or `*const` pointer).
///
/// The compiler does not elide or merge volatile accesses; the spin loops in
/// this file rely on that to re-read shared state on every iteration.
/// The `unsafe` is hidden inside the macro, so every caller is responsible
/// for `$addr` being valid and properly aligned for reads.
macro_rules! read_mem {
    ($addr: expr) => {
        // Absolute path: `macro_rules!` paths are resolved at the call site,
        // so the expansion must not depend on the caller's `use` items.
        unsafe { ::std::ptr::read_volatile($addr) }
    };
}

/// Volatile store of `$val` through `$addr` (a mutable reference or `*mut`
/// pointer).
///
/// Like `read_mem!`, this hides an `unsafe` block: callers must ensure
/// `$addr` is valid and properly aligned for writes.
macro_rules! write_mem {
    ($addr: expr, $val: expr) => {
        unsafe { ::std::ptr::write_volatile($addr, $val) }
    };
}
/// Shared state of a bakery-style lock for `NUM_THREADS` threads, with one
/// slot per thread in each array (indexed by the thread's `idx`).
///
/// NOTE(review): every thread reaches this state through the same
/// `static mut`, so the fields are accessed concurrently without atomics or
/// `UnsafeCell`; the code relies on volatile accesses and fences instead.
struct BakeryLock {
    /// `entering[i]` is set to `true` by thread `i` before it picks a ticket
    /// and back to `false` once its ticket has been written.
    entering: [bool; NUM_THREADS],
    /// `tickets[i]` is thread `i`'s ticket number after it calls `lock`, and
    /// is reset to `None` when its `LockGuard` is dropped.
    tickets: [Option<u64>; NUM_THREADS],
}
impl BakeryLock {
fn entering_set(&mut self, idx: usize, b: bool) {
fence(Ordering::SeqCst);
write_mem!(&mut self.entering[idx], b);
fence(Ordering::SeqCst);
}
fn wait(&self, new_ticket: u64, idx: u64, entering: bool, ticket: &Option<u64>) {
while read_mem!(&entering) {}
loop {
match read_mem!(ticket) {
Some(t) => {
if new_ticket < t || (new_ticket == t && idx < 1) {
return;
}
}
None => return,
}
}
}
fn lock(&mut self, idx: usize) -> LockGuard {
self.entering_set(idx, true);
let max = self
.tickets
.iter()
.map(|ticket| ticket.unwrap_or(0))
.max()
.unwrap();
let new_ticket = max + 1;
write_mem!(&mut self.tickets[idx], Some(new_ticket));
self.entering_set(idx, false);
let target = self.entering.iter().zip(self.tickets).enumerate();
let _ = target.map(|(idx, (entering, ticket))| {
self.wait(new_ticket, idx as u64, *entering, &ticket);
});
fence(Ordering::SeqCst);
LockGuard {
idx,
bakerylock: self,
}
}
}
/// RAII guard for a `BakeryLock` acquired by thread `idx`.
///
/// Dropping the guard withdraws the thread's ticket, which releases the lock.
/// Bind it to a named variable (e.g. `let _guard = ...`): `let _ = ...` drops
/// it immediately, and `#[must_use]` flags calls whose guard is discarded.
#[must_use = "the bakery lock is released as soon as the guard is dropped"]
struct LockGuard<'a> {
    idx: usize,
    bakerylock: &'a mut BakeryLock,
}

impl Drop for LockGuard<'_> {
    /// Releases the lock by resetting this thread's ticket to `None`.
    fn drop(&mut self) {
        // Order the critical section's memory accesses before the store that
        // lets waiting threads proceed.
        fence(Ordering::SeqCst);
        write_mem!(&mut self.bakerylock.tickets[self.idx], None);
    }
}
/// Shared counter that every worker thread increments in `thread_ex`; `main`
/// prints its final value next to the expected `NUM_THREADS * NUM_LOOP`.
static mut COUNT: u64 = 0;
/// Bakery lock shared by all worker threads (`thread_ex` calls `lock(idx)`
/// on it). It starts with no thread choosing a ticket and no ticket held.
static mut LOCK: BakeryLock = BakeryLock {
    entering: [false; NUM_THREADS],
    tickets: [None; NUM_THREADS],
};
/// Builds the body of worker thread `idx`.
///
/// The returned closure runs `NUM_LOOP` iterations. Each iteration takes the
/// shared bakery lock and performs a non-atomic read-modify-write of `COUNT`
/// while holding it.
fn thread_ex(idx: usize) -> impl Fn() {
    move || {
        for _ in 0..NUM_LOOP {
            // The guard must live until the end of the iteration: `let _ = ...`
            // would drop it (releasing the lock) before the increment runs.
            //
            // SAFETY (review): `addr_of_mut!` avoids creating a reference to
            // the `static mut` directly, but every thread still calls `lock`
            // through an aliased `&mut LOCK`; the algorithm relies on volatile
            // accesses and fences inside `BakeryLock`.
            let _guard = unsafe { (*std::ptr::addr_of_mut!(LOCK)).lock(idx) };
            // SAFETY: `COUNT` is only accessed while `_guard` holds the lock,
            // and raw pointers avoid references to the `static mut`.
            unsafe {
                let count = std::ptr::addr_of_mut!(COUNT);
                let c = read_volatile(count);
                write_volatile(count, c + 1);
            }
        }
    }
}
/// Runs `NUM_THREADS` workers concurrently, waits for all of them, and prints
/// the final `COUNT` next to the expected total `NUM_THREADS * NUM_LOOP`.
fn main() {
    // Collect every handle before joining any of them. `map` is lazy, so
    // joining straight out of the iterator would spawn and join the threads
    // one at a time and they would never contend for the lock.
    let handles: Vec<_> = (0..NUM_THREADS)
        .map(|idx| thread::spawn(thread_ex(idx)))
        .collect();
    for th in handles {
        th.join().expect("worker thread panicked");
    }
    // All workers have been joined, so no other thread touches `COUNT` now.
    println!(
        "Count={} (expected={})",
        unsafe { COUNT },
        NUM_THREADS * NUM_LOOP
    );
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment