Skip to content

Instantly share code, notes, and snippets.

@oliver-giersch
Last active June 21, 2019 08:31
Show Gist options
  • Save oliver-giersch/e81d51ebbdbe438ef775d0e2c5222ab5 to your computer and use it in GitHub Desktop.
Save oliver-giersch/e81d51ebbdbe438ef775d0e2c5222ab5 to your computer and use it in GitHub Desktop.
A minimal scoped-thread implementation, built on the unstable "thread_spawn_unchecked" feature, that allows join handles to safely leak out of a scope.
#![feature(thread_spawn_unchecked)]
use std::any::Any;
use std::cell::UnsafeCell;
use std::io;
use std::marker::PhantomData;
use std::mem;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread::{self, JoinHandle};
/// Runs `f` with a fresh [`Scope`] in which threads that borrow from the
/// environment (`'env`) can be spawned.
///
/// Every thread spawned through the scope is joined before this function
/// returns. That holds whether `f` returns normally or panics; on a panic,
/// dropping the scope does the joining while unwinding.
///
/// # Errors
///
/// Returns one entry per scoped thread that panicked, and the value produced
/// by `f` is dropped. For a thread whose `ScopedJoinHandle` was still alive
/// when the scope joined it, the entry is a placeholder `String`. The real
/// payload stays with that handle.
pub fn scope<'env, F, R>(f: F) -> Result<R, Vec<Box<dyn Any + Send>>>
where
    for<'scope> F: FnOnce(&'scope Scope<'env>) -> R + 'env,
    R: 'env,
{
    let scope = Scope::new();
    // Should `f` unwind, dropping `scope` still joins every spawned thread.
    let output = f(&scope);
    match scope.join_all() {
        Ok(()) => Ok(output),
        Err(panics) => Err(panics),
    }
}
pub struct Scope<'env> {
joins: Mutex<Vec<Arc<dyn Join + 'env>>>,
_marker: PhantomData<&'env ()>,
}
impl<'env> Scope<'env> {
pub fn builder(&self) -> ScopedThreadBuilder<'env, '_> {
ScopedThreadBuilder(self, thread::Builder::new())
}
pub fn spawn<F, T>(&self, f: F) -> ScopedJoinHandle<T>
where
F: FnOnce() -> T + Send + 'env,
T: Send + 'env
{
self.builder().spawn(f).unwrap()
}
fn new() -> Self {
Self {
joins: Mutex::new(Vec::new()),
_marker: PhantomData,
}
}
fn join_all(self) -> Result<(), Vec<Box<dyn Any + Send + 'static>>> {
let panics = self.joins
.lock()
.unwrap()
.drain(..)
.filter_map(|handle| handle.join().err())
.collect::<Vec<_>>();
if panics.len() == 0 {
Ok(())
} else {
Err(panics)
}
}
}
impl<'env> Drop for Scope<'env> {
fn drop(&mut self) {
for handle in self.joins.lock().unwrap().drain(..) {
let _ = handle.join();
}
}
}
/// Builder for a scoped thread, mirroring [`std::thread::Builder`]; obtained
/// from [`Scope::builder`].
pub struct ScopedThreadBuilder<'env, 'scope>(&'scope Scope<'env>, thread::Builder);

impl<'env, 'scope> ScopedThreadBuilder<'env, 'scope> {
    /// Names the thread to be spawned (see [`std::thread::Builder::name`]).
    pub fn name(mut self, name: String) -> Self {
        self.1 = self.1.name(name);
        self
    }

    /// Sets the stack size, in bytes, of the thread to be spawned (see
    /// [`std::thread::Builder::stack_size`]).
    pub fn stack_size(mut self, size: usize) -> Self {
        self.1 = self.1.stack_size(size);
        self
    }

    /// Spawns the configured thread running `f` and registers it with the
    /// scope, so it is joined before the scope ends.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by
    /// [`std::thread::Builder::spawn_unchecked`] if the thread cannot be
    /// created.
    pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<T>>
    where
        F: FnOnce() -> T + Send + 'env,
        T: Send + 'env,
    {
        let ScopedThreadBuilder(scope, builder) = self;
        // SAFETY: `f` and its result may borrow data that lives only for
        // `'env`, so the thread must be joined before `'env` ends. The handle
        // is registered with `scope` just below. `Scope` joins every
        // registered thread in `join_all` or in its `Drop` impl. Code outside
        // this module only ever receives a `&Scope` (the value is owned by
        // `scope()`), so the scope cannot be leaked with `mem::forget`.
        let handle = unsafe { builder.spawn_unchecked(f) }?;
        let inner = Arc::new(UnsafeCell::new(JoinState::Unjoined(handle)));
        // Registration must not panic. If it did, `inner` would be dropped
        // here along with the only `JoinHandle`, detaching a thread that may
        // still borrow `'env` data. So a poisoned lock is recovered rather
        // than unwrapped.
        scope
            .joins
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push(Arc::clone(&inner) as _);
        Ok(ScopedJoinHandle(inner))
    }
}
/// Handle to a thread spawned in a [`Scope`].
///
/// The handle may be returned from the closure passed to `scope` and joined
/// after the scope has ended. In that case the scope joins the thread first
/// and parks its result for the handle to collect.
pub struct ScopedJoinHandle<T>(Arc<UnsafeCell<JoinState<T>>>);

/// Join state of one scoped thread, shared by its [`ScopedJoinHandle`] and
/// the owning [`Scope`].
enum JoinState<T> {
    /// Nobody has joined the thread yet.
    Unjoined(JoinHandle<T>),
    /// The scope joined the thread while the handle was still alive and
    /// parked the thread's result here.
    ScopeJoined(thread::Result<T>),
    /// Final state: the result has been handed out and nothing is left to join.
    Joined,
}

impl<T> ScopedJoinHandle<T> {
    /// Waits for the thread to finish and returns its result. If the scope has
    /// already joined the thread, returns the result the scope parked instead.
    ///
    /// # Errors
    ///
    /// Returns the thread's panic payload if it panicked.
    pub fn join(self) -> thread::Result<T> {
        // SAFETY: `Arc<UnsafeCell<_>>` is neither `Send` nor `Sync`, so every
        // copy of this `Arc` stays on the thread that created it. No other
        // reference into the cell is live during this call, which makes this
        // mutable borrow unique.
        let slot = unsafe { &mut *self.0.get() };
        match mem::replace(slot, JoinState::Joined) {
            JoinState::ScopeJoined(result) => result,
            JoinState::Unjoined(handle) => handle.join(),
            // `join` consumes the handle, so it can only run once.
            JoinState::Joined => unreachable!(),
        }
    }
}
/// Type-erased view of a scoped thread's join state. It lets [`Scope`] keep
/// threads with different return types in a single list.
trait Join {
    /// Joins the thread on behalf of the scope and reports only whether it
    /// panicked; the thread's return value is not passed back here.
    fn join(self: Arc<Self>) -> thread::Result<()>;
}

impl<T> Join for UnsafeCell<JoinState<T>> {
    fn join(self: Arc<Self>) -> thread::Result<()> {
        // SAFETY: the cell is only reachable through `Arc`s that are neither
        // `Send` nor `Sync`, so all accesses happen on a single thread.
        // `ScopedJoinHandle::join` holds its own borrow only for the duration
        // of its call, so this mutable borrow is unique.
        let state = unsafe { &mut *self.get() };
        let handle = match mem::replace(state, JoinState::Joined) {
            JoinState::Unjoined(handle) => handle,
            // The `ScopedJoinHandle` already joined the thread itself.
            JoinState::Joined => return Ok(()),
            // Only the scope parks results, and it joins each thread once.
            JoinState::ScopeJoined(_) => unreachable!(),
        };
        let outcome = handle.join();
        if Arc::strong_count(&self) == 1 {
            // The handle is gone, so nobody else can observe the result: give
            // the scope the real panic payload, if any.
            return outcome.map(drop);
        }
        // The handle is still alive and will receive the real result
        // (including any panic payload), so the scope gets a placeholder.
        let report = if outcome.is_ok() {
            Ok(())
        } else {
            Err(Box::new(String::from("unspecified thread panicked")) as _)
        };
        *state = JoinState::ScopeJoined(outcome);
        report
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    #[test]
    fn simple() {
        let mut a = 0;
        scope(|scope| {
            scope.spawn(|| {
                a = 1;
            });
        })
        .unwrap();
        assert_eq!(a, 1);
    }

    #[test]
    fn multiple_writers() {
        let count = Mutex::new(0);
        scope(|scope| {
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
        })
        .unwrap();
        assert_eq!(*count.lock().unwrap(), 5);
    }

    #[test]
    fn manual_join() {
        let count = Mutex::new(0);
        scope(|scope| {
            let handles = (0..5)
                .map(|_| scope.spawn(|| *count.lock().unwrap() += 1))
                .collect::<Vec<_>>();
            for handle in handles {
                // Every thread must finish without panicking.
                handle.join().unwrap();
            }
            assert_eq!(*count.lock().unwrap(), 5);
        })
        .unwrap();
    }

    /// A handle returned from the scope closure can still be joined after the
    /// scope has ended; the scope parks the result for it.
    #[test]
    fn handle_outlives_scope() {
        let data = vec![1, 2, 3];
        let handle = scope(|scope| scope.spawn(|| data.iter().sum::<i32>())).unwrap();
        assert_eq!(handle.join().unwrap(), 6);
    }

    /// With the handle already dropped, the scope reports the real payload.
    #[test]
    fn panic_of_dropped_handle_is_reported() {
        let panics = scope(|scope| {
            scope.spawn(|| -> i32 { panic!("boom") });
        })
        .unwrap_err();
        assert_eq!(panics.len(), 1);
        assert_eq!(panics[0].downcast_ref::<&str>(), Some(&"boom"));
    }

    /// With the handle still alive, the scope reports a placeholder payload.
    #[test]
    fn panic_of_live_handle_is_reported_as_placeholder() {
        let result = scope(|scope| scope.spawn(|| -> i32 { panic!("boom") }));
        let panics = match result {
            Ok(_) => panic!("the scope should report the panicked thread"),
            Err(panics) => panics,
        };
        assert_eq!(panics.len(), 1);
        assert_eq!(
            panics[0].downcast_ref::<String>().map(String::as_str),
            Some("unspecified thread panicked")
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment