Skip to content

Instantly share code, notes, and snippets.

@eira-fransham
Created September 24, 2021 08:39
Show Gist options
  • Save eira-fransham/0e28081aad36f91655d603089ce8ef31 to your computer and use it in GitHub Desktop.
Toy threadpool implementation
use std::{
any, mem, panic,
sync::atomic::{AtomicUsize, Ordering},
thread,
};
/// Control-flow signal returned by a queued task: `Continue` keeps the
/// worker polling the channel, `Stop` breaks its loop (sent once per
/// worker during shutdown).
///
/// `Eq` is derived alongside `PartialEq` (the comparison is total), and
/// `Debug` is derived for diagnostics.
#[derive(Debug, PartialEq, Eq)]
enum Progress {
    /// Keep the worker loop running.
    Continue,
    /// Terminate the worker loop.
    Stop,
}
/// A toy thread pool: tasks are boxed closures pushed onto a bounded
/// crossbeam MPMC channel and executed by worker threads spawned in
/// `run`; the submitting thread also helps drain the queue in `wait`.
pub struct Pool {
    /// Producer side of the task queue; `Stop` tasks are sent here during
    /// shutdown (`join`).
    sender: crossbeam::channel::Sender<Box<dyn FnOnce() -> Progress + Send>>,
    /// Consumer side; cloned into each worker thread and also polled by
    /// `wait` on the submitting thread.
    receiver: crossbeam::channel::Receiver<Box<dyn FnOnce() -> Progress + Send>>,
    /// Join handles of spawned workers; each returns `Err(payload)` if a
    /// task panicked on that worker (captured via `catch_unwind`).
    handles: Vec<thread::JoinHandle<Result<(), Box<dyn any::Any + Send>>>>,
    /// Number of workers still running; decremented by a worker as it
    /// exits. Boxed so a stable address can be handed to worker threads
    /// with an (unsafely) extended lifetime — see `run`.
    live_handles: Box<AtomicUsize>,
    /// How many worker threads to spawn (total thread count minus the
    /// caller's own thread).
    spawn_count: usize,
}
impl Pool {
    /// Creates a pool sized for `thread_count` threads in total, counting
    /// the caller's thread; `run` later spawns `thread_count - 1` workers.
    ///
    /// NOTE(review): `thread_count == 0` underflows `thread_count - 1`
    /// (panic in debug builds, wrap in release) — callers must pass >= 1.
    #[inline]
    pub fn new(thread_count: usize) -> Self {
        // Bounded queue: at most 4 pending tasks per thread, providing
        // backpressure on submitters.
        let (tx, rx) = crossbeam::channel::bounded(4 * thread_count);
        let spawn_count = thread_count - 1;
        Pool {
            sender: tx,
            receiver: rx,
            handles: vec![],
            // Exclude current thread from worker thread count.
            spawn_count,
            live_handles: AtomicUsize::new(spawn_count).into(),
        }
    }

    /// Shuts the pool down: sends one `Stop` task per worker, then joins
    /// every worker handle. If any task panicked on a worker, the last
    /// observed panic payload is re-raised on the calling thread after
    /// all handles have been joined.
    #[inline]
    fn join(&mut self) {
        for _ in 0..self.spawn_count {
            self.sender.send(Box::new(|| Progress::Stop)).unwrap();
        }
        let mut err = None;
        for handle in self.handles.drain(..) {
            // The worker closure wraps task execution in `catch_unwind`,
            // so the outer join `Result` is always `Ok`; the inner
            // `Result` carries any captured task panic payload.
            if let Err(e) = handle.join().unwrap() {
                err = Some(e);
            }
        }
        if let Some(e) = err {
            panic::resume_unwind(e);
        }
    }

    /// Spins until `barrier` reaches zero, helping to drain the shared
    /// queue on the current thread while waiting, then asserts that no
    /// worker thread has exited (i.e. panicked).
    ///
    /// NOTE(review): a task that panics on a worker never runs its
    /// `barrier.fetch_sub`, so this loop would spin forever in that case
    /// and the assert below is never reached — the liveness check likely
    /// needs to happen inside the loop as well. Verify against intended
    /// panic semantics before changing, since bailing out early lets
    /// still-queued tasks outlive the scope's stack frame.
    #[inline]
    fn wait(&self, barrier: &AtomicUsize) {
        while barrier.load(Ordering::SeqCst) != 0 {
            // Help execute queued tasks rather than blocking, so the
            // submitting thread contributes to finishing its own scope.
            if let Ok(task) = self.receiver.try_recv() {
                task();
            } else {
                thread::yield_now();
            }
        }
        // Workers only decrement `live_handles` when they exit, and during
        // a scope the only way out of the worker loop is a panic, so any
        // mismatch here means a worker died.
        assert_eq!(
            self.live_handles.load(Ordering::SeqCst),
            self.spawn_count,
            "Worker thread panicked"
        );
    }

    /// Boxes `task`, wires it to `barrier` accounting (incremented before
    /// the send, decremented by the task after it runs), and pushes it
    /// onto the shared queue so `wait(barrier)` blocks until it finishes.
    ///
    /// # Safety
    ///
    /// When submitting a task, it is the client's responsibility to pass in references which live long
    /// enough and that the memory regions don't overlap at the time of access, as different tasks could
    /// operate on them simultaneously.
    #[inline]
    unsafe fn submit_task<T>(&self, task: T, barrier: &AtomicUsize)
    where
        T: FnOnce() + Send,
    {
        // Count the task before it becomes visible to workers so `wait`
        // cannot observe zero between the send and the task's execution.
        barrier.fetch_add(1, Ordering::SeqCst);
        let task: Box<dyn FnOnce() -> Progress + Send> = Box::new(move || {
            task();
            barrier.fetch_sub(1, Ordering::SeqCst);
            Progress::Continue
        });
        // NOTE(review): this transmute erases the non-'static lifetimes
        // captured by `task` and `barrier`; soundness relies on `scope`
        // waiting for every task before those borrows expire (and on the
        // caller upholding this function's safety contract).
        self.sender.send(mem::transmute(task)).unwrap();
    }

    /// Runs `f` with a `Scope` that can spawn borrowing tasks, then
    /// blocks until every spawned task has completed. A panic from `f`
    /// itself is caught and re-raised only after the outstanding tasks
    /// have drained.
    #[inline]
    pub fn scope<F>(&self, f: F)
    where
        F: FnOnce(Scope<'_>),
    {
        // Per-scope count of in-flight tasks; lives on this stack frame
        // for the duration of the scope.
        let barrier = &AtomicUsize::default();
        let scope = Scope {
            inner: self,
            barrier,
        };
        // Catch `f`'s panic so `wait` still runs before unwinding — no
        // queued task may be left referencing this stack frame.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(move || f(scope)));
        self.wait(barrier);
        if let Err(e) = res {
            panic::resume_unwind(e);
        }
    }

    /// Spawns the pool's worker threads. Each worker pops tasks off the
    /// shared channel until one returns `Progress::Stop`, and decrements
    /// `live_handles` on the way out (whether stopping or panicking).
    #[inline]
    pub fn run(&mut self) {
        use std::pin::Pin;
        for i in 0..self.spawn_count {
            let receiver = self.receiver.clone();
            let live_handles: Pin<&AtomicUsize> = Pin::new(&*self.live_handles);
            // This is unsafe, but we have three options:
            // A: The destructor for `Pool` is ran, and therefore this handle is dropped before
            //    the `Box<AtomicUsize>` is deallocated.
            // B: The destructor for `Pool` is _never_ ran, and therefore the `Box` is never
            //    deallocated.
            // C: The destructor for `Pool` is ran but panicks before all threads join. In this
            //    case the program will abort and so the pointer is not read after freeing.
            //
            // If `join` is called manually and it panicks, then this is all still true.
            let live_handles: Pin<&'static AtomicUsize> = unsafe { mem::transmute(live_handles) };
            let handle = thread::Builder::new()
                .name(format!("worker-{}", i))
                .spawn(move || {
                    // Any task panic is captured here and surfaced through
                    // the join handle's return value (see `join`).
                    let res = panic::catch_unwind(|| loop {
                        if let Ok(task) = receiver.recv() {
                            if Progress::Stop == task() {
                                break;
                            }
                        }
                    });
                    // Signal this worker's exit to `wait`'s liveness check.
                    live_handles.fetch_sub(1, Ordering::SeqCst);
                    res
                })
                .unwrap();
            self.handles.push(handle);
        }
    }
}
impl Drop for Pool {
#[inline]
fn drop(&mut self) {
self.join()
}
}
/// Handle passed to the closure given to `Pool::scope`; lets that closure
/// spawn tasks whose captured borrows must live at least as long as `'a`
/// (the duration of the scope).
pub struct Scope<'a> {
    /// The pool that will execute spawned tasks.
    inner: &'a Pool,
    /// In-flight task counter owned by the enclosing `Pool::scope` stack
    /// frame; `Pool::wait` spins on it reaching zero.
    barrier: &'a AtomicUsize,
}
impl<'a> Scope<'a> {
    /// Submits `f` to the pool for execution. The `'a` bound ties every
    /// borrow captured by `f` to the enclosing `Pool::scope` call, which
    /// waits for all submitted tasks before returning.
    #[inline]
    pub fn spawn<F>(&'a self, f: F)
    where
        F: FnOnce() + Send + 'a,
    {
        let pool = self.inner;
        // SAFETY: `Pool::scope` blocks on `barrier` until this task has
        // run, so the references captured by `f` remain valid for the
        // task's entire execution.
        unsafe {
            pool.submit_task(f, self.barrier);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment