Last active
July 17, 2021 02:41
-
-
Save emonkak/7ae8e79ede0abdde5113420dc3580450 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cell::UnsafeCell; | |
use std::future::Future; | |
use std::pin::Pin; | |
use std::ptr; | |
use std::rc::Rc; | |
use std::task::{Context, Poll}; | |
pub struct Generator<'a, Yield, Resume, Return> { | |
coroutine: Coroutine<Yield, Resume>, | |
future: Pin<Box<dyn Future<Output = Return> + 'a>>, | |
} | |
impl<'a, Yield, Resume, Return> Generator<'a, Yield, Resume, Return> { | |
pub fn new<F>(producer: impl FnOnce(Coroutine<Yield, Resume>) -> F + 'a) -> Self | |
where | |
F: Future<Output = Return> + 'a, | |
{ | |
let coroutine = Coroutine::new(); | |
let future = Box::pin(producer(coroutine.clone())); | |
Self { coroutine, future } | |
} | |
pub fn resume(&mut self, arg: Resume) -> GeneratorState<Yield, Return> { | |
debug_assert_eq!(self.coroutine.peek_state(), CoroutineState::Empty); | |
self.coroutine.replace_state(CoroutineState::Resume(arg)); | |
self.coroutine.advance(self.future.as_mut()) | |
} | |
pub fn resume_empty(&mut self) -> GeneratorState<Yield, Return> { | |
debug_assert_eq!(self.coroutine.peek_state(), CoroutineState::Empty); | |
self.coroutine.advance(self.future.as_mut()) | |
} | |
pub fn is_complete(&self) -> bool { | |
self.coroutine.peek_state() == CoroutineState::Done | |
} | |
} | |
impl<'a, Yield, Return> IntoIterator for Generator<'a, Yield, (), Return> { | |
type Item = Yield; | |
type IntoIter = iter::IntoIter<'a, Yield, Return>; | |
fn into_iter(self) -> Self::IntoIter { | |
Self::IntoIter { generator: self } | |
} | |
} | |
/// Outcome of resuming a generator once.
#[derive(Debug, Eq, PartialEq)]
pub enum GeneratorState<Yield, Return> {
    /// The body suspended and handed out a value.
    Yielded(Yield),
    /// The body finished and produced its final value.
    Complete(Return),
}

impl<Yield, Return> GeneratorState<Yield, Return> {
    /// Consumes the state, returning the yielded value if there is one.
    pub fn yielded(self) -> Option<Yield> {
        if let Self::Yielded(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Consumes the state, returning the final value if the body completed.
    pub fn complete(self) -> Option<Return> {
        if let Self::Complete(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
#[derive(Debug)] | |
pub struct Coroutine<Yield, Resume> { | |
state: Rc<UnsafeCell<CoroutineState<Yield, Resume>>>, | |
} | |
impl<Yield, Resume> Coroutine<Yield, Resume> { | |
fn new() -> Self { | |
Self { | |
state: Rc::new(UnsafeCell::new(CoroutineState::Empty)), | |
} | |
} | |
pub fn suspend(&self, value: Yield) -> impl Future<Output = Resume> { | |
let state = self.peek_state(); | |
match state { | |
CoroutineState::Empty | CoroutineState::Resume(_) => {} | |
_ => panic!("Invalid state: {:?}", state), | |
} | |
self.replace_state(CoroutineState::Suspend(value)); | |
self.clone() | |
} | |
fn peek_state(&self) -> CoroutineState<(), ()> { | |
unsafe { self.state.get().as_ref().unwrap().without_values() } | |
} | |
fn replace_state(&self, next_state: CoroutineState<Yield, Resume>) -> CoroutineState<Yield, Resume> { | |
unsafe { ptr::replace(self.state.get(), next_state) } | |
} | |
fn advance<Return>(&self, future: Pin<&mut dyn Future<Output = Return>>) -> GeneratorState<Yield, Return> { | |
let waker = waker::create(); | |
let mut context = Context::from_waker(&waker); | |
match future.poll(&mut context) { | |
Poll::Pending => { | |
let state = self.replace_state(CoroutineState::Empty); | |
match state { | |
CoroutineState::Suspend(value) => GeneratorState::Yielded(value), | |
_ => panic!("Invalid state: {:?}", state.without_values()), | |
} | |
} | |
Poll::Ready(value) => { | |
self.replace_state(CoroutineState::Done); | |
GeneratorState::Complete(value) | |
} | |
} | |
} | |
fn clone(&self) -> Self { | |
Self { | |
state: Rc::clone(&self.state), | |
} | |
} | |
} | |
impl<Yield, Resume> Future for Coroutine<Yield, Resume> { | |
type Output = Resume; | |
fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> { | |
let state = self.peek_state(); | |
match state { | |
CoroutineState::Suspend(_) => Poll::Pending, | |
CoroutineState::Resume(_) => { | |
let state = self.replace_state(CoroutineState::Empty); | |
match state { | |
CoroutineState::Resume(arg) => Poll::Ready(arg), | |
_ => panic!("Invalid state: {:?}", state.without_values()), | |
} | |
} | |
_ => panic!("Invalid state: {:?}", state), | |
} | |
} | |
} | |
/// Contents of the cell shared between a generator and its body.
#[derive(Debug, Eq, PartialEq)]
enum CoroutineState<Yield, Resume> {
    /// Nothing is stored.
    Empty,
    /// The body has parked a value to be yielded.
    Suspend(Yield),
    /// The caller has parked an argument for the body.
    Resume(Resume),
    /// The body has run to completion.
    Done,
}

impl<Yield, Resume> CoroutineState<Yield, Resume> {
    /// Returns the same variant with any payload replaced by `()`, so the
    /// state can be compared and printed without bounds on the payload types.
    fn without_values(&self) -> CoroutineState<(), ()> {
        match *self {
            Self::Done => CoroutineState::Done,
            Self::Resume(..) => CoroutineState::Resume(()),
            Self::Suspend(..) => CoroutineState::Suspend(()),
            Self::Empty => CoroutineState::Empty,
        }
    }
}
mod iter { | |
use super::Generator; | |
pub struct IntoIter<'a, Yield, Return> { | |
pub(super) generator: Generator<'a, Yield, (), Return>, | |
} | |
impl<'a, Yield, Return> Iterator for IntoIter<'a, Yield, Return> { | |
type Item = Yield; | |
fn next(&mut self) -> Option<Self::Item> { | |
self.generator.resume(()).yielded() | |
} | |
} | |
} | |
mod waker {
    use std::ptr;
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Vtable whose wake, wake-by-ref and drop entries all do nothing.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, noop, noop, noop);

    /// Raw waker with a null data pointer; it carries no state.
    const RAW_WAKER: RawWaker = RawWaker::new(ptr::null(), &VTABLE);

    /// Creates a waker that ignores every wake-up request.
    pub fn create() -> Waker {
        // SAFETY: every vtable entry ignores the data pointer and touches no
        // shared state, so the null pointer is never dereferenced and there
        // is no resource to manage.
        unsafe { Waker::from_raw(RAW_WAKER) }
    }

    /// Cloning yields the same stateless raw waker.
    unsafe fn clone_raw(_: *const ()) -> RawWaker {
        RAW_WAKER
    }

    /// Shared body for wake, wake-by-ref and drop: nothing to do.
    unsafe fn noop(_: *const ()) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a generator by hand, feeding the length of each yielded string
    /// back in as the next resume argument.
    #[test]
    fn test_resume() {
        let mut generator = Generator::new(|co| async move {
            let x = co.suspend("foo").await;
            let y = co.suspend("foobar").await;
            let z = co.suspend("foobarbaz").await;
            [x, y, z]
        });
        assert!(!generator.is_complete());

        let first = generator.resume_empty().yielded().unwrap();
        let second = generator.resume(first.len()).yielded().unwrap();
        let third = generator.resume(second.len()).yielded().unwrap();
        let result = generator.resume(third.len()).complete().unwrap();

        assert_eq!(result, [3, 6, 9]);
        assert!(generator.is_complete());
    }

    /// Consumes a unit-resume generator through its iterator adapter.
    #[test]
    fn test_iter() {
        let odd_numbers_less_than_ten = Generator::new(|co| async move {
            let mut n = 1;
            while n < 10 {
                co.suspend(n).await;
                n += 2;
            }
        });

        let collected: Vec<_> = odd_numbers_less_than_ten.into_iter().collect();
        assert_eq!(collected, [1, 3, 5, 7, 9]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment