Last active
September 1, 2017 17:32
-
-
Save kyren/d010f25cc6d98bcfc4044c44518b5676 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::result::Result; | |
use std::thread::{spawn, JoinHandle}; | |
use std::sync::{Arc, Mutex}; | |
use std::sync::mpsc::{sync_channel, channel, SyncSender, Receiver, TryRecvError, RecvTimeoutError}; | |
use std::boxed::FnBox; | |
use std::time::Duration; | |
/// Holds a piece of data in a particular thread and produces promises that are executed on that | |
/// thread only. Useful for paralellization or concurrency with types that do not implement Send. | |
pub struct ThreadWorker<T> { | |
join: Joiner, | |
send: SyncSender<Visit<T>>, | |
} | |
pub struct ThreadWorkerPromise<T> { | |
join: Joiner, | |
receive: Receiver<T>, | |
} | |
/// Shared, clonable ownership of the worker thread's `JoinHandle`.
///
/// The worker handle and every promise hold a clone of the same slot. Whichever of them first
/// detects that the worker thread has died can take the handle and join it immediately, which
/// propagates the worker's panic.
#[derive(Clone)]
struct Joiner(Arc<Mutex<Option<JoinHandle<()>>>>);
enum Visit<T> { | |
Stop, | |
Visit(Box<for<'a> FnBox(&'a mut T) + Send>), | |
} | |
impl<T: 'static> ThreadWorker<T> { | |
pub fn new<L>(queue_size: usize, load: L) -> ThreadWorker<T> | |
where | |
L: 'static + FnOnce() -> T + Send, | |
{ | |
let r: Result<_, ()> = Self::do_new(queue_size, || Ok(load())); | |
r.unwrap() | |
} | |
pub fn try_new<L, E>(queue_size: usize, load: L) -> Result<ThreadWorker<T>, E> | |
where | |
L: 'static + FnOnce() -> Result<T, E> + Send, | |
E: 'static + Send, | |
{ | |
Self::do_new(queue_size, load) | |
} | |
pub fn visit<F, R>(&self, op: F) -> ThreadWorkerPromise<R> | |
where | |
F: 'static + FnOnce(&T) -> R + Send, | |
R: 'static + Send, | |
{ | |
self.do_visit(|w: &mut T| op(w)) | |
} | |
pub fn visit_mut<F, R>(&mut self, op: F) -> ThreadWorkerPromise<R> | |
where | |
F: 'static + FnOnce(&mut T) -> R + Send, | |
R: 'static + Send, | |
{ | |
self.do_visit(op) | |
} | |
fn do_new<L, E>(queue_size: usize, load: L) -> Result<ThreadWorker<T>, E> | |
where | |
L: 'static + FnOnce() -> Result<T, E> + Send, | |
E: 'static + Send, | |
{ | |
let (send, receive): (_, Receiver<Visit<T>>) = sync_channel(queue_size); | |
let (load_send, load_receive) = sync_channel(0); | |
let visitor = move || match load() { | |
Err(e) => { | |
load_send.send(Err(e)).expect( | |
"controlling thread died during ThreadWorker load", | |
) | |
} | |
Ok(mut wrapped) => { | |
load_send.send(Ok(())).expect( | |
"controlling thread died during ThreadWorker load", | |
); | |
loop { | |
match receive.recv().expect( | |
"Controlling thread died during ThreadWorker receive", | |
) { | |
Visit::Stop => break, | |
Visit::Visit(v) => { | |
v.call_box((&mut wrapped,)); | |
} | |
} | |
} | |
} | |
}; | |
let join = Joiner::new(spawn(visitor)); | |
match load_receive.recv() { | |
Err(_) => join.propagate_panic(), | |
Ok(Err(e)) => Err(e), | |
Ok(Ok(())) => { | |
Ok(ThreadWorker { | |
join: join, | |
send: send, | |
}) | |
} | |
} | |
} | |
fn do_visit<F, R>(&self, op: F) -> ThreadWorkerPromise<R> | |
where | |
F: 'static + FnOnce(&mut T) -> R + Send, | |
R: 'static + Send, | |
{ | |
let (send, receive) = channel(); | |
self.send | |
.send(Visit::Visit(Box::new(move |w: &mut T| { | |
// If the receiver is dropped, simply ignore this, that should not be considered an | |
// error. | |
let _ = send.send(op(w)); | |
}))) | |
.unwrap(); | |
ThreadWorkerPromise { | |
join: self.join.clone(), | |
receive, | |
} | |
} | |
} | |
impl<T> Drop for ThreadWorker<T> { | |
fn drop(&mut self) { | |
// If the receiver is dropped, then the thread has died anyway. | |
let _ = self.send.send(Visit::Stop); | |
self.join.join(); | |
} | |
} | |
impl<T> ThreadWorkerPromise<T> { | |
pub fn poll(&self) -> Option<T> { | |
match self.receive.try_recv() { | |
Ok(t) => Some(t), | |
Err(TryRecvError::Empty) => None, | |
Err(TryRecvError::Disconnected) => self.join.propagate_panic(), | |
} | |
} | |
pub fn wait(&self) -> T { | |
match self.receive.recv() { | |
Ok(t) => t, | |
Err(_) => self.join.propagate_panic(), | |
} | |
} | |
pub fn wait_timeout(&self, timeout: Duration) -> Option<T> { | |
match self.receive.recv_timeout(timeout) { | |
Ok(t) => Some(t), | |
Err(RecvTimeoutError::Timeout) => None, | |
Err(RecvTimeoutError::Disconnected) => self.join.propagate_panic(), | |
} | |
} | |
} | |
impl Joiner { | |
fn new(join: JoinHandle<()>) -> Joiner { | |
Joiner(Arc::new(Mutex::new(Some(join)))) | |
} | |
fn join(&mut self) { | |
let mut join = self.0.lock().unwrap(); | |
if let Some(join) = join.take() { | |
join.join().unwrap(); | |
} | |
} | |
fn propagate_panic(&self) -> ! { | |
let mut join = self.0.lock().unwrap(); | |
if let Some(join) = join.take() { | |
join.join().unwrap(); | |
panic!("internal error, Threadworker worker thread died without panic") | |
} | |
panic!("internal error, Threadworker worker thread died without panic") | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment