-
-
Save Matthias247/632c290cbee977393a9c8fed2bf5be9b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use pin_project::pin_project; | |
use std::future::Future; | |
use std::marker::PhantomData; | |
use std::mem::transmute; | |
use std::pin::Pin; | |
use std::task::{Context, Poll}; | |
use tokio::sync::mpsc; | |
use pin_utils::pin_mut; | |
#[derive(Clone)] | |
pub struct Scope<'env> { | |
handle: tokio::runtime::Handle, | |
chan: mpsc::Sender<()>, | |
_marker: PhantomData<&'env mut &'env ()>, | |
} | |
#[pin_project] | |
pub struct ScopedJoinHandle<'scope, R> { | |
#[pin] | |
handle: tokio::task::JoinHandle<R>, | |
_marker: PhantomData<&'scope ()>, | |
} | |
impl<'env, R> Future for ScopedJoinHandle<'env, R> { | |
type Output = Result<R, tokio::task::JoinError>; | |
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { | |
self.project().handle.poll(cx) | |
} | |
} | |
impl<'env> Scope<'env> { | |
pub fn spawn<'scope, F, R>(&'scope self, fut: F) -> ScopedJoinHandle<'scope, R> | |
where | |
F: Future<Output = R> + Send + 'env, | |
R: Send + 'static, // TODO: weaken to 'env | |
{ | |
let chan = self.chan.clone(); | |
let future_env: Pin<Box<dyn Future<Output = R> + Send + 'env>> = Box::pin(async move { | |
// the cloned channel gets dropped at the end of the future | |
let _chan = chan; | |
fut.await | |
}); | |
// SAFETY: scoped API ensures the spawned tasks will not outlive the parent scope | |
let future_static: Pin<Box<dyn Future<Output = R> + Send + 'static>> = | |
unsafe { transmute(future_env) }; | |
let handle = self.handle.spawn(future_static); | |
ScopedJoinHandle { | |
handle, | |
_marker: PhantomData, | |
} | |
} | |
} | |
// TODO: if `Func` takes a reference to the scope, `scope.spawn` will generate a cryptic error | |
#[doc(hidden)] | |
pub async fn scope_impl<'env, Func, Fut, R>(handle: tokio::runtime::Handle, func: Func) -> R | |
where | |
Func: FnOnce(Scope<'env>) -> Fut, | |
Fut: Future<Output = R> + Send, | |
R: Send, | |
{ | |
// we won't send data through this channel, so reserve the minimal buffer (buffer size must be | |
// greater than 0). | |
let (tx, mut rx) = mpsc::channel(1); | |
let scope = Scope::<'env> { | |
handle, | |
chan: tx, | |
_marker: PhantomData, | |
}; | |
// TODO: `func` and the returned future can panic during the execution. | |
// In that case, we need to cancel all the spawned subtasks forcibly, but we cannot cancel | |
// spawned tasks from the outside of tokio. | |
let result = func(scope).await; | |
// yield the control until all spawned task finish(drop). | |
assert!(rx.recv().await.is_none()); | |
result | |
} | |
/// Runs a scoped closure on the given runtime handle and awaits its result.
///
/// Expands to `$crate::scope_impl($handle, $func).await`, so it must be invoked from
/// an async context. A trailing comma after the closure is accepted.
#[macro_export]
macro_rules! scope {
    ($handle:expr, $func:expr $(,)?) => {{
        // `$crate` (not `crate`) keeps the path pointing at the defining crate even
        // when the exported macro is invoked from a different crate.
        $crate::scope_impl($handle, $func).await
    }};
}
use std::time::Duration; | |
use tokio::time::delay_for; | |
fn main() { | |
let mut rt = tokio::runtime::Runtime::new().unwrap(); | |
let handle = rt.handle().clone(); | |
rt.block_on(async { | |
let result = { | |
let local = String::from("hello world"); | |
let local = &local; | |
scope!(handle.clone(), |scope| { | |
// TODO: without this `move`, we get a compilation error. why? | |
async move { | |
// this spawned subtask will continue running after the scoped task | |
// finished, but `scope!` will wait until this task completes. | |
scope.spawn(async move { | |
delay_for(Duration::from_millis(500)).await; | |
println!("spanwed task is done: {}", local); | |
}); | |
// since spawned tasks can outlive the scoped task, they cannot have | |
// references to the scoped task's stack | |
// let evil = String::from("may dangle"); | |
// scope.spawn(async { | |
// delay_for(Duration::from_millis(200)).await; | |
// println!("spanwed task cannot access evil: {}", evil); | |
// }); | |
let handle = scope.spawn(async { | |
println!("another spawned task"); | |
}); | |
handle.await.unwrap(); // you can await the returned handle | |
delay_for(Duration::from_millis(100)).await; | |
println!("scoped task is done: {}", local); | |
} | |
}); | |
delay_for(Duration::from_millis(110)).await; | |
println!("local can be used here: {}", local); | |
}; | |
println!("local is freed"); | |
delay_for(Duration::from_millis(600)).await; | |
{ | |
let data = vec![1, 2, 3]; | |
let data_ref = &data; | |
let mut scope_fut = Box::pin(scope_impl(handle.clone(), |scope| async move { | |
scope.spawn(async { | |
println!("Hello from 2nd scope"); | |
println!("Data: {:?}", data_ref); | |
delay_for(Duration::from_millis(500)).await; | |
println!("End from 2nd scope"); | |
println!("Data: {:?}", data_ref); | |
}); | |
5u32 | |
})); | |
futures::poll!(scope_fut.as_mut()); | |
std::mem::forget(scope_fut); | |
} | |
result | |
}); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment