Last active
November 25, 2024 04:09
-
-
Save BurntNail/36a5543738d208305d59cecbd2fb5c9e to your computer and use it in GitHub Desktop.
Basic single-threaded async executor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;
use std::task::{Context, Poll, RawWaker, Waker};
use std::thread::JoinHandle;
use std::time::{Duration, Instant};

use waker::{WakerData, VTABLE};
/// Hand-rolled `RawWaker` support for the executor.
///
/// Invariant shared by every function in [`VTABLE`]: a waker's data pointer was
/// produced by `Box::into_raw(Box<WakerData>)`, and each waker owns its own box
/// (`clone` allocates a fresh one).
mod waker {
    use std::{
        sync::mpsc::Sender,
        task::{RawWaker, RawWakerVTable},
    };

    /// State behind one waker: the executor's wake-up channel and the task index
    /// to send on it.
    #[derive(Debug, Clone)]
    pub struct WakerData {
        tasks_sender: Sender<usize>,
        id: usize,
    }

    impl WakerData {
        /// Creates waker state that will send `id` on `tasks_sender` when woken.
        pub fn new(tasks_sender: Sender<usize>, id: usize) -> Self {
            Self { tasks_sender, id }
        }

        /// Queues this task's id with the executor.
        fn schedule(&self) {
            // `send` only fails once the receiving side is gone, i.e. the executor has
            // stopped. Waking a task that nobody will poll again is a no-op, not a reason
            // to panic in whatever thread happened to call `wake`.
            let _ = self.tasks_sender.send(self.id);
        }
    }

    /// Vtable for wakers whose data pointer is a leaked `Box<WakerData>`.
    pub const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    /// Duplicates a waker by boxing a clone of its data; the original stays untouched.
    unsafe fn clone(data: *const ()) -> RawWaker {
        // SAFETY: per the module invariant `data` points to a live `WakerData`;
        // it is only borrowed here, ownership stays with the original waker.
        let data = unsafe { &*(data as *const WakerData) };
        let copy = Box::into_raw(Box::new(data.clone()));
        RawWaker::new(copy as *const (), &VTABLE)
    }

    /// Wakes the task and consumes the waker, freeing its boxed data.
    unsafe fn wake(data: *const ()) {
        // SAFETY: `wake` takes ownership of the waker, so reclaiming the Box is the
        // single matching release of its `Box::into_raw` allocation. The Box (and its
        // `Sender` clone) is freed when it goes out of scope at the end of this fn.
        let data = unsafe { Box::from_raw(data as *mut WakerData) };
        data.schedule();
    }

    /// Wakes the task without consuming the waker.
    unsafe fn wake_by_ref(data: *const ()) {
        // SAFETY: `data` points to a live `WakerData`; it is only borrowed.
        let data = unsafe { &*(data as *const WakerData) };
        data.schedule();
    }

    /// Releases a waker's data.
    ///
    /// The Box itself must be reclaimed: running the destructor in place
    /// (`drop_in_place`) would drop the `Sender` but leak the heap allocation.
    unsafe fn drop(data: *const ()) {
        // SAFETY: the waker is being dropped, so this is the single release of the
        // `Box::into_raw` allocation behind `data`.
        std::mem::drop(unsafe { Box::from_raw(data as *mut WakerData) });
    }
}
/// A pinned, heap-allocated, type-erased future that may be sent to another thread.
pub type BoxedFuture<T> = Pin<Box<dyn Send + Future<Output = T>>>;

/// Handle to the background thread that drives submitted futures to completion.
pub struct Executor {
    /// Flag raised by `join`; the executor thread exits once it is set and every
    /// task it holds has completed.
    can_finish: Arc<AtomicBool>,
    /// Queue of futures waiting to be picked up by the executor thread.
    new_tasks_sender: Sender<BoxedFuture<u32>>,
    /// The thread running the executor loop.
    running_thread: JoinHandle<()>,
}
impl Executor { | |
pub fn start() -> Executor { | |
let (new_tasks_sender, new_tasks_receiver) = channel(); | |
let can_finish = Arc::new(AtomicBool::new(false)); | |
let thread_can_finish = can_finish.clone(); | |
let running_thread = std::thread::spawn(move || { | |
let (tasks_sender, tasks_receiver) = channel(); | |
let mut tasks_to_poll: Vec<Option<BoxedFuture<u32>>> = vec![]; | |
loop { | |
for future in new_tasks_receiver.try_iter() { | |
let index = tasks_to_poll.len(); | |
tasks_to_poll.push(Some(future)); | |
println!("[executor] adding new task @ {index}"); | |
tasks_sender.send(index).unwrap(); | |
} | |
for index in tasks_receiver.try_iter() { | |
if index >= tasks_to_poll.len() { | |
panic!("index out of bounds"); | |
} | |
let waker_data = WakerData::new(tasks_sender.clone(), index); | |
let boxed_waker_data = Box::new(waker_data); | |
let raw_waker_data = Box::into_raw(boxed_waker_data); | |
let raw_waker = | |
RawWaker::new(raw_waker_data as *const WakerData as *const (), &VTABLE); | |
let waker = unsafe { Waker::from_raw(raw_waker) }; | |
let mut cx = Context::from_waker(&waker); | |
if let Some(task) = &mut tasks_to_poll[index] { | |
if let Poll::Ready(res) = task.as_mut().poll(&mut cx) { | |
println!("[executor] Received {res} from {index}"); | |
tasks_to_poll[index] = None; | |
} | |
} | |
} | |
if thread_can_finish.load(Ordering::Relaxed) && tasks_to_poll.iter().all(|x| x.is_none()) { | |
break; | |
} | |
} | |
}); | |
Executor { | |
new_tasks_sender, | |
can_finish, | |
running_thread | |
} | |
} | |
pub fn join (self) { | |
self.can_finish.store(true, Ordering::SeqCst); | |
self.running_thread.join().unwrap(); | |
} | |
pub fn run<F: Future<Output = u32> + Send + 'static> (&self, f: F) { | |
self.new_tasks_sender.send(Box::pin(f)).unwrap(); | |
} | |
} | |
/// A future that resolves to `timeout_ms` once that many milliseconds have passed
/// since it was first polled.
///
/// It does not register with any external timer. Every `Pending` poll immediately
/// wakes its own task, so the executor keeps re-polling it until the deadline passes.
struct TimerFuture {
    /// Moment of the first poll; `None` until then.
    start: Option<Instant>,
    /// `timeout_ms` expressed as a `Duration`.
    time: Duration,
    /// Requested delay in milliseconds; also the value produced on completion.
    timeout_ms: u32,
}

impl TimerFuture {
    /// Creates a timer for `timeout_ms` milliseconds. The clock starts at the first
    /// poll, not at construction.
    pub fn new(timeout_ms: u32) -> Self {
        let time = Duration::from_millis(u64::from(timeout_ms));
        Self {
            start: None,
            time,
            timeout_ms,
        }
    }
}

impl Future for TimerFuture {
    type Output = u32;

    /// Starts the clock on the first poll. On later polls it returns `Ready` once the
    /// delay has elapsed. Every `Pending` result re-schedules the task immediately.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Some(began) = self.start {
            if began.elapsed() >= self.time {
                return Poll::Ready(self.timeout_ms);
            }
        } else {
            self.start = Some(Instant::now());
        }
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
fn main() { | |
let executor = Executor::start(); | |
let create_task = |time, id| async move { | |
let fut = TimerFuture::new(time); | |
println!("[task {id}] created future"); | |
let start = Instant::now(); | |
let res = fut.await; | |
let el = start.elapsed(); | |
println!("[task {id}] awaited future, got {res:?} in {el:?}"); | |
res | |
}; | |
let task_1 = create_task(150, 1); | |
let task_2 = create_task(150, 2); | |
let task_3 = create_task(50, 3); | |
let task_4 = create_task(200, 4); | |
let start = Instant::now(); | |
executor.run(task_1); | |
executor.run(task_2); | |
executor.run(task_3); | |
executor.run(task_4); | |
println!("[main] created all tasks, joining executor"); | |
executor.join(); | |
let el = start.elapsed(); | |
println!("[main] joined, took {el:?}"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment