Created
September 24, 2021 08:39
-
-
Save eira-fransham/0e28081aad36f91655d603089ce8ef31 to your computer and use it in GitHub Desktop.
Toy threadpool implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
any, mem, panic, | |
sync::atomic::{AtomicUsize, Ordering}, | |
thread, | |
}; | |
/// Result of running one queued task: tells a worker whether to keep pulling tasks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Progress {
    /// Keep processing tasks.
    Continue,
    /// The worker should exit its loop (sent once per worker when the pool shuts down).
    Stop,
}
pub struct Pool { | |
sender: crossbeam::channel::Sender<Box<dyn FnOnce() -> Progress + Send>>, | |
receiver: crossbeam::channel::Receiver<Box<dyn FnOnce() -> Progress + Send>>, | |
handles: Vec<thread::JoinHandle<Result<(), Box<dyn any::Any + Send>>>>, | |
live_handles: Box<AtomicUsize>, | |
spawn_count: usize, | |
} | |
impl Pool { | |
#[inline] | |
pub fn new(thread_count: usize) -> Self { | |
let (tx, rx) = crossbeam::channel::bounded(4 * thread_count); | |
let spawn_count = thread_count - 1; | |
Pool { | |
sender: tx, | |
receiver: rx, | |
handles: vec![], | |
// Exclude current thread from worker thread count. | |
spawn_count, | |
live_handles: AtomicUsize::new(spawn_count).into(), | |
} | |
} | |
#[inline] | |
fn join(&mut self) { | |
for _ in 0..self.spawn_count { | |
self.sender.send(Box::new(|| Progress::Stop)).unwrap(); | |
} | |
let mut err = None; | |
for handle in self.handles.drain(..) { | |
if let Err(e) = handle.join().unwrap() { | |
err = Some(e); | |
} | |
} | |
if let Some(e) = err { | |
panic::resume_unwind(e); | |
} | |
} | |
#[inline] | |
fn wait(&self, barrier: &AtomicUsize) { | |
while barrier.load(Ordering::SeqCst) != 0 { | |
if let Ok(task) = self.receiver.try_recv() { | |
task(); | |
} else { | |
thread::yield_now(); | |
} | |
} | |
assert_eq!( | |
self.live_handles.load(Ordering::SeqCst), | |
self.spawn_count, | |
"Worker thread panicked" | |
); | |
} | |
/// # Safety | |
/// | |
/// When submitting a task, it is the client's responsibility to pass in references which live long | |
/// enough and that the memory regions don't overlap at the time of access, as different tasks could | |
/// operate on them simultaneously. | |
#[inline] | |
unsafe fn submit_task<T>(&self, task: T, barrier: &AtomicUsize) | |
where | |
T: FnOnce() + Send, | |
{ | |
barrier.fetch_add(1, Ordering::SeqCst); | |
let task: Box<dyn FnOnce() -> Progress + Send> = Box::new(move || { | |
task(); | |
barrier.fetch_sub(1, Ordering::SeqCst); | |
Progress::Continue | |
}); | |
self.sender.send(mem::transmute(task)).unwrap(); | |
} | |
#[inline] | |
pub fn scope<F>(&self, f: F) | |
where | |
F: FnOnce(Scope<'_>), | |
{ | |
let barrier = &AtomicUsize::default(); | |
let scope = Scope { | |
inner: self, | |
barrier, | |
}; | |
let res = panic::catch_unwind(panic::AssertUnwindSafe(move || f(scope))); | |
self.wait(barrier); | |
if let Err(e) = res { | |
panic::resume_unwind(e); | |
} | |
} | |
#[inline] | |
pub fn run(&mut self) { | |
use std::pin::Pin; | |
for i in 0..self.spawn_count { | |
let receiver = self.receiver.clone(); | |
let live_handles: Pin<&AtomicUsize> = Pin::new(&*self.live_handles); | |
// This is unsafe, but we have three options: | |
// A: The destructor for `Pool` is ran, and therefore this handle is dropped before | |
// the `Box<AtomicUsize>` is deallocated. | |
// B: The destructor for `Pool` is _never_ ran, and therefore the `Box` is never | |
// deallocated. | |
// C: The destructor for `Pool` is ran but panicks before all threads join. In this | |
// case the program will abort and so the pointer is not read after freeing. | |
// | |
// If `join` is called manually and it panicks, then this is all still true. | |
let live_handles: Pin<&'static AtomicUsize> = unsafe { mem::transmute(live_handles) }; | |
let handle = thread::Builder::new() | |
.name(format!("worker-{}", i)) | |
.spawn(move || { | |
let res = panic::catch_unwind(|| loop { | |
if let Ok(task) = receiver.recv() { | |
if Progress::Stop == task() { | |
break; | |
} | |
} | |
}); | |
live_handles.fetch_sub(1, Ordering::SeqCst); | |
res | |
}) | |
.unwrap(); | |
self.handles.push(handle); | |
} | |
} | |
} | |
impl Drop for Pool { | |
#[inline] | |
fn drop(&mut self) { | |
self.join() | |
} | |
} | |
pub struct Scope<'a> { | |
inner: &'a Pool, | |
barrier: &'a AtomicUsize, | |
} | |
impl<'a> Scope<'a> { | |
#[inline] | |
pub fn spawn<F>(&'a self, f: F) | |
where | |
F: FnOnce() + Send + 'a, | |
{ | |
unsafe { self.inner.submit_task(f, self.barrier) }; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment