Skip to content

Instantly share code, notes, and snippets.

@kyren
Last active September 1, 2017 17:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kyren/d010f25cc6d98bcfc4044c44518b5676 to your computer and use it in GitHub Desktop.
use std::result::Result;
use std::thread::{spawn, JoinHandle};
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, channel, SyncSender, Receiver, TryRecvError, RecvTimeoutError};
use std::boxed::FnBox;
use std::time::Duration;
/// Holds a piece of data in a particular thread and produces promises that are executed on that
/// thread only. Useful for parallelization or concurrency with types that do not implement Send.
///
/// The data is created on the worker thread by the `load` closure given to `new` / `try_new`
/// and lives there. Callers interact with it only by sending closures (`visit`, `visit_mut`)
/// that run on the worker thread. Dropping the `ThreadWorker` tells the worker thread to stop
/// and then joins it.
pub struct ThreadWorker<T> {
// Shared handle to the worker thread. A clone of it goes into every promise so that whoever
// first notices the worker has died can join it.
join: Joiner,
// Bounded queue of requests (closures to run, or the final `Stop`) for the worker thread.
send: SyncSender<Visit<T>>,
}
/// Handle to the eventual result of one `visit` / `visit_mut` call on a `ThreadWorker`.
///
/// The result is produced on the worker thread and delivered over `receive`.
pub struct ThreadWorkerPromise<T> {
// Clone of the owning `ThreadWorker`'s joiner. It is used to propagate a worker panic when the
// result channel disconnects without delivering a value.
join: Joiner,
// Result channel. Its sending half is moved into the queued visit closure, which sends once.
receive: Receiver<T>,
}
// Makes it easier for any detected worker thread panic to immediately join on the worker thread,
// propagating the panic.
//
// The `JoinHandle` sits behind `Arc<Mutex<Option<..>>>` so that the `ThreadWorker` and all of its
// promises can share it. Whichever of them joins first takes the handle out, leaving `None`
// for the rest.
#[derive(Clone)]
struct Joiner(Arc<Mutex<Option<JoinHandle<()>>>>);
// A request sent from the controlling side to the worker thread's receive loop.
enum Visit<T> {
// Ends the worker's receive loop so the thread can finish. Sent by `ThreadWorker`'s `Drop`.
Stop,
// A boxed closure to run once on the worker thread against the held value.
// NOTE(review): `FnBox` is the unstable, pre-Rust-1.35 way to call a boxed `FnOnce`. On current
// Rust, `Box<dyn FnOnce(&mut T) + Send>` is directly callable; consider migrating (this also
// needs the `call_box` site in `do_new` and the `FnBox` import changed).
Visit(Box<for<'a> FnBox(&'a mut T) + Send>),
}
impl<T: 'static> ThreadWorker<T> {
    /// Spawns a worker thread and runs `load` on it to create the held value. Returns once
    /// loading has finished.
    ///
    /// `queue_size` bounds how many visits may be waiting for the worker. `visit` and
    /// `visit_mut` block while the queue is full; with `0`, every visit waits until the worker
    /// takes it.
    ///
    /// # Panics
    ///
    /// Panics on the calling thread if `load` panics on the worker thread.
    pub fn new<L>(queue_size: usize, load: L) -> ThreadWorker<T>
    where
        L: 'static + FnOnce() -> T + Send,
    {
        // `load` is infallible here, so the `()` error case is never produced.
        let r: Result<_, ()> = Self::do_new(queue_size, || Ok(load()));
        r.unwrap()
    }

    /// Like `new`, but `load` may fail.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `load`. The worker thread then exits and no `ThreadWorker`
    /// is created.
    ///
    /// # Panics
    ///
    /// Panics on the calling thread if `load` panics on the worker thread.
    pub fn try_new<L, E>(queue_size: usize, load: L) -> Result<ThreadWorker<T>, E>
    where
        L: 'static + FnOnce() -> Result<T, E> + Send,
        E: 'static + Send,
    {
        Self::do_new(queue_size, load)
    }

    /// Queues `op` to run on the worker thread with shared access to the held value and
    /// returns a promise for its result.
    ///
    /// Visits run one at a time on the worker, in the order they were queued. Blocks while the
    /// visit queue is full.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread has already died (an earlier visit panicked).
    pub fn visit<F, R>(&self, op: F) -> ThreadWorkerPromise<R>
    where
        F: 'static + FnOnce(&T) -> R + Send,
        R: 'static + Send,
    {
        self.do_visit(|w: &mut T| op(w))
    }

    /// Queues `op` to run on the worker thread with mutable access to the held value and
    /// returns a promise for its result.
    ///
    /// Visits run one at a time on the worker, in the order they were queued. Blocks while the
    /// visit queue is full.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread has already died (an earlier visit panicked).
    pub fn visit_mut<F, R>(&mut self, op: F) -> ThreadWorkerPromise<R>
    where
        F: 'static + FnOnce(&mut T) -> R + Send,
        R: 'static + Send,
    {
        self.do_visit(op)
    }

    /// Shared implementation of `new` and `try_new`.
    ///
    /// Spawns the worker thread. The worker runs `load`, reports the outcome back over a
    /// zero-capacity channel, and on success enters its visit loop. This function blocks until
    /// that report arrives.
    fn do_new<L, E>(queue_size: usize, load: L) -> Result<ThreadWorker<T>, E>
    where
        L: 'static + FnOnce() -> Result<T, E> + Send,
        E: 'static + Send,
    {
        let (send, receive): (_, Receiver<Visit<T>>) = sync_channel(queue_size);
        let (load_send, load_receive) = sync_channel(0);
        let visitor = move || match load() {
            Err(e) => {
                // Hand the error back to `do_new`'s caller. The thread then simply exits.
                load_send.send(Err(e)).expect(
                    "controlling thread died during ThreadWorker load",
                )
            }
            Ok(mut wrapped) => {
                load_send.send(Ok(())).expect(
                    "controlling thread died during ThreadWorker load",
                );
                // Run queued visits in order until asked to stop.
                loop {
                    match receive.recv().expect(
                        "Controlling thread died during ThreadWorker receive",
                    ) {
                        Visit::Stop => break,
                        Visit::Visit(v) => {
                            v.call_box((&mut wrapped,));
                        }
                    }
                }
            }
        };
        let join = Joiner::new(spawn(visitor));
        match load_receive.recv() {
            // `load_send` was dropped without sending a report, which only happens when `load`
            // panicked. Re-raise that on this thread.
            Err(_) => join.propagate_panic(),
            Ok(Err(e)) => Err(e),
            Ok(Ok(())) => Ok(ThreadWorker { join, send }),
        }
    }

    /// Boxes `op` into a visit request, queues it, and returns the promise for its result.
    fn do_visit<F, R>(&self, op: F) -> ThreadWorkerPromise<R>
    where
        F: 'static + FnOnce(&mut T) -> R + Send,
        R: 'static + Send,
    {
        let (send, receive) = channel();
        let visit: Visit<T> = Visit::Visit(Box::new(move |w: &mut T| {
            // If the receiver is dropped, simply ignore this, that should not be considered an
            // error.
            let _ = send.send(op(w));
        }));
        // The worker drops its end of the queue only when its closure ends. `Stop` is sent
        // only from `Drop`, so while `self` is alive a failed send means the worker thread
        // died (a visit panicked). Surface that panic instead of an opaque `SendError`.
        if self.send.send(visit).is_err() {
            self.join.propagate_panic();
        }
        ThreadWorkerPromise {
            join: self.join.clone(),
            receive,
        }
    }
}
impl<T> Drop for ThreadWorker<T> {
    /// Asks the worker loop to finish, then waits for the worker thread to end.
    fn drop(&mut self) {
        // A send error only means the worker's receiving end is already gone, i.e. its closure
        // has already returned or unwound. Nothing needs stopping, so the error is discarded;
        // the join below deals with however the thread ended.
        self.send.send(Visit::Stop).ok();
        self.join.join();
    }
}
impl<T> ThreadWorkerPromise<T> {
    /// Returns the result if it has already arrived, or `None` if the visit has not finished
    /// yet. Never blocks waiting for the visit.
    ///
    /// The result is delivered at most once. After any retrieval method has returned it, the
    /// channel reports disconnection, and a further call takes the worker-death path below.
    ///
    /// # Panics
    ///
    /// If the result channel is disconnected with no value waiting, the worker thread is
    /// joined and this panics. The join blocks until the worker thread has ended.
    pub fn poll(&self) -> Option<T> {
        let attempt = self.receive.try_recv();
        if let Err(TryRecvError::Disconnected) = attempt {
            self.join.propagate_panic();
        }
        attempt.ok()
    }

    /// Blocks until the result arrives and returns it.
    ///
    /// # Panics
    ///
    /// Same as `poll`: a disconnected result channel joins the worker thread and panics.
    pub fn wait(&self) -> T {
        self.receive
            .recv()
            .unwrap_or_else(|_| self.join.propagate_panic())
    }

    /// Waits at most `timeout` for the result. Returns `None` if it did not arrive in time.
    ///
    /// # Panics
    ///
    /// Same as `poll`: a disconnected result channel joins the worker thread and panics.
    pub fn wait_timeout(&self, timeout: Duration) -> Option<T> {
        let attempt = self.receive.recv_timeout(timeout);
        if let Err(RecvTimeoutError::Disconnected) = attempt {
            self.join.propagate_panic();
        }
        attempt.ok()
    }
}
impl Joiner {
    /// Wraps the worker thread's handle so it can be shared between a `ThreadWorker` and its
    /// promises.
    fn new(join: JoinHandle<()>) -> Joiner {
        Joiner(Arc::new(Mutex::new(Some(join))))
    }

    /// Removes the join handle from the shared slot, if nobody has taken it yet.
    ///
    /// The lock is released before this returns, so a panic raised later while joining can
    /// never poison the mutex. A poisoned mutex is still accepted, because the slot is a plain
    /// `Option` with no partially-updated state to protect.
    fn take_handle(&self) -> Option<JoinHandle<()>> {
        let mut slot = match self.0.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        slot.take()
    }

    /// Joins the worker thread if it has not been joined already.
    ///
    /// If the worker panicked, its original panic payload is re-raised here. The exception is
    /// when this thread is already unwinding (e.g. a `Drop` running during another panic). A
    /// second panic would abort the process, so in that case the worker's panic is discarded.
    fn join(&self) {
        if let Some(handle) = self.take_handle() {
            if let Err(payload) = handle.join() {
                if !std::thread::panicking() {
                    std::panic::resume_unwind(payload);
                }
            }
        }
    }

    /// Joins the worker thread after it has been detected as dead and re-raises its panic
    /// with the original payload. Never returns.
    ///
    /// # Panics
    ///
    /// Always panics. If the worker exited without panicking, or another party already took
    /// the handle, an internal-error panic is raised instead.
    fn propagate_panic(&self) -> ! {
        if let Some(handle) = self.take_handle() {
            if let Err(payload) = handle.join() {
                std::panic::resume_unwind(payload);
            }
        }
        panic!("internal error, Threadworker worker thread died without panic")
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment