Skip to content

Instantly share code, notes, and snippets.

@BurntNail
Last active November 25, 2024 04:09
Show Gist options
  • Save BurntNail/36a5543738d208305d59cecbd2fb5c9e to your computer and use it in GitHub Desktop.
Save BurntNail/36a5543738d208305d59cecbd2fb5c9e to your computer and use it in GitHub Desktop.
Basic single-threaded async executor
use std::future::Future;
use std::sync::mpsc::{Sender, channel};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{RawWaker, Waker, Context, Poll};
use std::thread::JoinHandle;
use std::time::{Instant, Duration};
use waker::{WakerData, VTABLE};
mod waker {
    //! Waker plumbing for the executor: each waker owns a copy of a task id plus a sender back
    //! to the executor, and waking it sends that id so the task is polled again.
    use std::{
        sync::mpsc::Sender,
        task::{RawWaker, RawWakerVTable},
    };

    /// Data behind every raw waker that uses [`VTABLE`].
    ///
    /// Every data pointer handed to `VTABLE` must come from
    /// `Box::into_raw(Box::new(WakerData { .. }))`; each vtable function relies on that.
    #[derive(Debug, Clone)]
    pub struct WakerData {
        /// Channel the executor reads task ids from.
        tasks_sender: Sender<usize>,
        /// Id of the task this waker belongs to.
        id: usize,
    }

    impl WakerData {
        pub fn new(tasks_sender: Sender<usize>, id: usize) -> Self {
            Self { tasks_sender, id }
        }

        /// Sends this waker's task id to the executor.
        ///
        /// # Panics
        /// Panics if the receiving end of `tasks_sender` has been dropped.
        fn notify(&self) {
            self.tasks_sender
                .send(self.id)
                .expect("unable to send task id to executor");
        }
    }

    pub const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    /// Creates an independent waker whose data is a fresh boxed copy of `data`.
    unsafe fn clone(data: *const ()) -> RawWaker {
        // SAFETY: per the VTABLE contract `data` points to a live boxed `WakerData`. Borrowing it
        // leaves ownership with the waker being cloned.
        let data = unsafe { &*(data as *const WakerData) };
        let raw_ptr = Box::into_raw(Box::new(data.clone()));
        RawWaker::new(raw_ptr as *const (), &VTABLE)
    }

    /// Consuming wake: notifies the executor, then releases the waker data.
    unsafe fn wake(data: *const ()) {
        // SAFETY: `wake` takes ownership of the waker's data. Reclaiming the `Box` frees both the
        // `WakerData` (and its `Sender`) and the heap allocation when `data` goes out of scope.
        let data = unsafe { Box::from_raw(data as *mut WakerData) };
        data.notify();
    }

    /// Non-consuming wake: notifies the executor and leaves the data owned by the waker.
    unsafe fn wake_by_ref(data: *const ()) {
        // SAFETY: `data` points to a live boxed `WakerData` that this call only borrows.
        let data = unsafe { &*(data as *const WakerData) };
        data.notify();
    }

    /// Releases the waker data without waking.
    unsafe fn drop(data: *const ()) {
        // SAFETY: the waker is being dropped, so it owns `data`. `Box::from_raw` + drop frees the
        // allocation as well as its contents (`drop_in_place` alone would leak the allocation).
        // `std::mem::drop` is spelled out because this function shadows the prelude `drop`.
        std::mem::drop(unsafe { Box::from_raw(data as *mut WakerData) });
    }
}
pub type BoxedFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
pub struct Executor {
new_tasks_sender: Sender<BoxedFuture<u32>>,
can_finish: Arc<AtomicBool>,
running_thread: JoinHandle<()>,
}
impl Executor {
pub fn start() -> Executor {
let (new_tasks_sender, new_tasks_receiver) = channel();
let can_finish = Arc::new(AtomicBool::new(false));
let thread_can_finish = can_finish.clone();
let running_thread = std::thread::spawn(move || {
let (tasks_sender, tasks_receiver) = channel();
let mut tasks_to_poll: Vec<Option<BoxedFuture<u32>>> = vec![];
loop {
for future in new_tasks_receiver.try_iter() {
let index = tasks_to_poll.len();
tasks_to_poll.push(Some(future));
println!("[executor] adding new task @ {index}");
tasks_sender.send(index).unwrap();
}
for index in tasks_receiver.try_iter() {
if index >= tasks_to_poll.len() {
panic!("index out of bounds");
}
let waker_data = WakerData::new(tasks_sender.clone(), index);
let boxed_waker_data = Box::new(waker_data);
let raw_waker_data = Box::into_raw(boxed_waker_data);
let raw_waker =
RawWaker::new(raw_waker_data as *const WakerData as *const (), &VTABLE);
let waker = unsafe { Waker::from_raw(raw_waker) };
let mut cx = Context::from_waker(&waker);
if let Some(task) = &mut tasks_to_poll[index] {
if let Poll::Ready(res) = task.as_mut().poll(&mut cx) {
println!("[executor] Received {res} from {index}");
tasks_to_poll[index] = None;
}
}
}
if thread_can_finish.load(Ordering::Relaxed) && tasks_to_poll.iter().all(|x| x.is_none()) {
break;
}
}
});
Executor {
new_tasks_sender,
can_finish,
running_thread
}
}
pub fn join (self) {
self.can_finish.store(true, Ordering::SeqCst);
self.running_thread.join().unwrap();
}
pub fn run<F: Future<Output = u32> + Send + 'static> (&self, f: F) {
self.new_tasks_sender.send(Box::pin(f)).unwrap();
}
}
/// A future that resolves to its own timeout (in milliseconds) once that much time has passed
/// since it was first polled.
///
/// There is no timer thread behind it: every pending poll immediately re-wakes the task, so the
/// executor keeps polling it until the deadline has been reached.
struct TimerFuture {
    /// Instant of the first poll; `None` until the future has been polled once.
    start: Option<Instant>,
    /// `timeout_ms` as a `Duration`, compared against the time elapsed since `start`.
    time: Duration,
    /// The requested timeout, also used as the future's output value.
    timeout_ms: u32
}

impl TimerFuture {
    /// Creates a timer for `timeout_ms` milliseconds; the clock starts on the first poll.
    pub fn new(timeout_ms: u32) -> Self {
        let time = Duration::from_millis(u64::from(timeout_ms));
        Self { start: None, time, timeout_ms }
    }
}

impl Future for TimerFuture {
    type Output = u32;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let expired = match self.start {
            Some(started_at) => started_at.elapsed() >= self.time,
            None => {
                // First poll only starts the clock; it never completes, even for a 0 ms timeout.
                self.start = Some(Instant::now());
                false
            }
        };
        if expired {
            return Poll::Ready(self.timeout_ms);
        }
        // Busy-wait: ask to be polled again right away.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
fn main() {
    let executor = Executor::start();

    // Builds a task that awaits a `TimerFuture` and reports how long the await took.
    // The async block is lazy: none of its body runs until the executor first polls it.
    let make_timer_task = |time, id| async move {
        let timer = TimerFuture::new(time);
        println!("[task {id}] created future");
        let awaited_at = Instant::now();
        let res = timer.await;
        let el = awaited_at.elapsed();
        println!("[task {id}] awaited future, got {res:?} in {el:?}");
        res
    };

    let tasks = [
        make_timer_task(150, 1),
        make_timer_task(150, 2),
        make_timer_task(50, 3),
        make_timer_task(200, 4),
    ];

    let start = Instant::now();
    for task in tasks {
        executor.run(task);
    }
    println!("[main] created all tasks, joining executor");

    executor.join();
    let el = start.elapsed();
    println!("[main] joined, took {el:?}");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment