Skip to content

Instantly share code, notes, and snippets.

@emonkak
Last active July 17, 2021 02:41
Show Gist options
  • Save emonkak/7ae8e79ede0abdde5113420dc3580450 to your computer and use it in GitHub Desktop.
Save emonkak/7ae8e79ede0abdde5113420dc3580450 to your computer and use it in GitHub Desktop.
use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::ptr;
use std::rc::Rc;
use std::task::{Context, Poll};
/// A generator whose body is the future returned by a producer closure
/// (typically an `async move` block).
///
/// The producer receives a [`Coroutine`] handle. Each `co.suspend(value).await`
/// inside the producer makes [`Generator::resume`] return
/// `GeneratorState::Yielded(value)`, and that `.await` evaluates to the
/// argument of the following `resume` call. When the producer's future
/// finishes, its output is returned as `GeneratorState::Complete`.
pub struct Generator<'a, Yield, Resume, Return> {
    /// State slot shared with the handle that was given to the producer.
    coroutine: Coroutine<Yield, Resume>,
    /// The producer's future, boxed and pinned so it can be polled in place.
    future: Pin<Box<dyn Future<Output = Return> + 'a>>,
}
impl<'a, Yield, Resume, Return> Generator<'a, Yield, Resume, Return> {
    /// Creates a generator from `producer`.
    ///
    /// `producer` is called immediately with a handle that shares state with
    /// the generator. The future it returns is not polled until the first
    /// `resume`/`resume_empty` call.
    pub fn new<F>(producer: impl FnOnce(Coroutine<Yield, Resume>) -> F + 'a) -> Self
    where
        F: Future<Output = Return> + 'a,
    {
        let coroutine = Coroutine::new();
        let future = Box::pin(producer(coroutine.clone()));
        Self { coroutine, future }
    }

    /// Runs the producer until its next `suspend(..).await` or until it
    /// finishes. `arg` becomes the result of the `suspend(..).await` the
    /// producer is currently parked on.
    ///
    /// On the very first call no suspension is pending yet, so the producer
    /// never observes `arg` (it is discarded). Use [`Generator::resume_empty`]
    /// to start the generator without a value.
    ///
    /// # Panics
    ///
    /// Panics if the generator has already completed, or if the producer's
    /// future returns `Pending` without having called `suspend`.
    pub fn resume(&mut self, arg: Resume) -> GeneratorState<Yield, Return> {
        self.assert_resumable();
        self.coroutine.replace_state(CoroutineState::Resume(arg));
        self.coroutine.advance(self.future.as_mut())
    }

    /// Runs the producer until its next `suspend(..).await` or until it
    /// finishes, without supplying a resume value.
    ///
    /// Use this to start the generator. If the producer is parked on a
    /// `suspend(..).await`, there is no value to deliver to it and polling panics.
    ///
    /// # Panics
    ///
    /// Panics if the generator has already completed, if a `suspend(..).await`
    /// is pending, or if the producer's future returns `Pending` without
    /// having called `suspend`.
    pub fn resume_empty(&mut self) -> GeneratorState<Yield, Return> {
        self.assert_resumable();
        self.coroutine.advance(self.future.as_mut())
    }

    /// Returns `true` once the producer's future has returned its output.
    pub fn is_complete(&self) -> bool {
        self.coroutine.peek_state() == CoroutineState::Done
    }

    /// Checks the preconditions shared by `resume` and `resume_empty`.
    fn assert_resumable(&self) {
        // Polling the producer again after it returned `Ready` is outside the
        // `Future` contract, so refuse in release builds too, with a clear message.
        assert!(!self.is_complete(), "Generator resumed after completion");
        // Any other non-empty state means a previous step left a value behind.
        debug_assert_eq!(self.coroutine.peek_state(), CoroutineState::Empty);
    }
}
/// Drives a generator whose resume type is `()` as an [`Iterator`] over the
/// values it yields.
///
/// The iterator only reports yielded values; the generator's final `Return`
/// value is not exposed.
impl<'a, Yield, Return> IntoIterator for Generator<'a, Yield, (), Return> {
    type Item = Yield;
    type IntoIter = iter::IntoIter<'a, Yield, Return>;

    fn into_iter(self) -> Self::IntoIter {
        let generator = self;
        iter::IntoIter { generator }
    }
}
/// Outcome of a single `Generator::resume` / `Generator::resume_empty` call.
#[derive(Debug, Eq, PartialEq)]
pub enum GeneratorState<Yield, Return> {
    /// The producer suspended and handed out this value.
    Yielded(Yield),
    /// The producer's future finished with this output.
    Complete(Return),
}
impl<Yield, Return> GeneratorState<Yield, Return> {
    /// Returns the yielded value, or `None` if this is `Complete`.
    pub fn yielded(self) -> Option<Yield> {
        if let Self::Yielded(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Returns the final output, or `None` if this is `Yielded`.
    pub fn complete(self) -> Option<Return> {
        if let Self::Complete(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
/// Handle through which a generator's producer yields values and receives
/// resume arguments.
///
/// The handle passed to the producer and the one kept by `Generator` share a
/// single state slot. `Rc` makes the handle `!Send`, so the slot is only ever
/// touched from one thread; the `unsafe` accesses below rely on that.
#[derive(Debug)]
pub struct Coroutine<Yield, Resume> {
    state: Rc<UnsafeCell<CoroutineState<Yield, Resume>>>,
}
impl<Yield, Resume> Coroutine<Yield, Resume> {
    /// Creates a handle whose slot starts out `Empty`.
    fn new() -> Self {
        let slot = UnsafeCell::new(CoroutineState::Empty);
        Self {
            state: Rc::new(slot),
        }
    }

    /// Stores `value` for the generator to hand out. Returns a future that
    /// resolves to the next argument passed to `Generator::resume`.
    ///
    /// A resume argument still sitting in the slot (typically the argument of
    /// the very first `resume`) is dropped.
    ///
    /// # Panics
    ///
    /// Panics unless the slot is `Empty` or holds a resume argument, e.g. when
    /// `suspend` is called again before the previous suspension was awaited.
    pub fn suspend(&self, value: Yield) -> impl Future<Output = Resume> {
        let current = self.peek_state();
        if !matches!(current, CoroutineState::Empty | CoroutineState::Resume(_)) {
            panic!("Invalid state: {:?}", current);
        }
        self.replace_state(CoroutineState::Suspend(value));
        self.clone()
    }

    /// Returns the current state with its payload erased, for inspection.
    fn peek_state(&self) -> CoroutineState<(), ()> {
        // SAFETY: `UnsafeCell::get` yields a non-null, aligned pointer into a
        // cell kept alive by the `Rc`. This shared borrow ends before the
        // function returns, no other borrow of the slot is ever held
        // (`replace_state` moves values in and out by value), and `Rc` keeps
        // all access on this thread.
        let slot = unsafe { &*self.state.get() };
        slot.without_values()
    }

    /// Swaps `next_state` into the slot and returns the previous state.
    fn replace_state(&self, next_state: CoroutineState<Yield, Resume>) -> CoroutineState<Yield, Resume> {
        // SAFETY: the pointer is valid for reads and writes, and no reference
        // into the slot is alive while it is overwritten (see `peek_state`).
        unsafe { ptr::replace(self.state.get(), next_state) }
    }

    /// Polls `future` once with a no-op waker and translates the result.
    ///
    /// `Ready` marks the slot `Done` and is returned as `Complete`. `Pending`
    /// must mean the producer is parked in `suspend`; its value is taken out
    /// of the slot and returned as `Yielded`.
    ///
    /// # Panics
    ///
    /// Panics if the future is pending without having stored a yielded value.
    fn advance<Return>(&self, future: Pin<&mut dyn Future<Output = Return>>) -> GeneratorState<Yield, Return> {
        let noop_waker = waker::create();
        let mut cx = Context::from_waker(&noop_waker);
        if let Poll::Ready(output) = future.poll(&mut cx) {
            self.replace_state(CoroutineState::Done);
            return GeneratorState::Complete(output);
        }
        match self.replace_state(CoroutineState::Empty) {
            CoroutineState::Suspend(yielded) => GeneratorState::Yielded(yielded),
            other => panic!("Invalid state: {:?}", other.without_values()),
        }
    }

    /// Returns another handle to the same slot.
    ///
    /// This is a private inherent method rather than a `Clone` impl, so code
    /// outside this module cannot duplicate the handle.
    fn clone(&self) -> Self {
        let state = Rc::clone(&self.state);
        Self { state }
    }
}
/// Awaiting a `Coroutine` (the future returned by `suspend`) waits for the
/// generator's next resume argument.
impl<Yield, Resume> Future for Coroutine<Yield, Resume> {
    type Output = Resume;

    /// Returns `Pending` while the yielded value is still in the slot. Returns
    /// `Ready(arg)` once `Generator::resume` has replaced it with
    /// `Resume(arg)`, leaving the slot `Empty`.
    ///
    /// The context's waker is ignored: the generator re-polls the producer on
    /// every resume.
    ///
    /// # Panics
    ///
    /// Panics if the slot is `Empty` or `Done`.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
        match self.peek_state() {
            CoroutineState::Suspend(()) => Poll::Pending,
            CoroutineState::Resume(()) => match self.replace_state(CoroutineState::Empty) {
                CoroutineState::Resume(arg) => Poll::Ready(arg),
                other => panic!("Invalid state: {:?}", other.without_values()),
            },
            other => panic!("Invalid state: {:?}", other),
        }
    }
}
/// Contents of the slot shared between a generator and its coroutine handle.
#[derive(Debug, Eq, PartialEq)]
enum CoroutineState<Yield, Resume> {
    /// Nothing is waiting to be handed over.
    Empty,
    /// The producer yielded this value and is waiting to be resumed.
    Suspend(Yield),
    /// A resume argument waiting for the producer to pick it up.
    Resume(Resume),
    /// The producer's future has completed.
    Done,
}
impl<Yield, Resume> CoroutineState<Yield, Resume> {
    /// Returns the same variant with its payload replaced by `()`.
    ///
    /// This lets callers compare and `Debug`-print the state without moving
    /// the payload out and without requiring `Yield: Debug` or `Resume: Debug`.
    fn without_values(&self) -> CoroutineState<(), ()> {
        match *self {
            CoroutineState::Empty => CoroutineState::Empty,
            CoroutineState::Suspend(_) => CoroutineState::Suspend(()),
            CoroutineState::Resume(_) => CoroutineState::Resume(()),
            CoroutineState::Done => CoroutineState::Done,
        }
    }
}
mod iter {
    use std::iter::FusedIterator;

    use super::Generator;

    /// Owning iterator over the values yielded by a `Generator` whose resume
    /// type is `()`.
    ///
    /// The generator's final return value is discarded. Once the generator has
    /// completed, every further call to `next` returns `None` instead of
    /// resuming the finished generator.
    pub struct IntoIter<'a, Yield, Return> {
        pub(super) generator: Generator<'a, Yield, (), Return>,
    }

    impl<'a, Yield, Return> Iterator for IntoIter<'a, Yield, Return> {
        type Item = Yield;

        fn next(&mut self) -> Option<Self::Item> {
            // Resuming a completed generator would poll its future after it
            // returned `Ready`; stop here instead so the iterator stays fused.
            if self.generator.is_complete() {
                return None;
            }
            self.generator.resume(()).yielded()
        }
    }

    // After the first `None` the generator is `Done`, so `next` keeps
    // returning `None`.
    impl<'a, Yield, Return> FusedIterator for IntoIter<'a, Yield, Return> {}
}
mod waker {
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Vtable whose every operation does nothing. The data pointer is null,
    /// so there is nothing to clone, wake or free.
    const NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

    /// A raw waker with a null data pointer backed by `NOOP_VTABLE`.
    const NOOP_RAW_WAKER: RawWaker = RawWaker::new(std::ptr::null(), &NOOP_VTABLE);

    /// Returns a [`Waker`] that ignores every wake-up request.
    pub fn create() -> Waker {
        // SAFETY: every vtable entry is a no-op that never dereferences the
        // (null) data pointer, which trivially upholds the `RawWaker` contract.
        unsafe { Waker::from_raw(NOOP_RAW_WAKER) }
    }

    /// `clone` entry: hands out another copy of the same no-op raw waker.
    unsafe fn noop_clone(_data: *const ()) -> RawWaker {
        NOOP_RAW_WAKER
    }

    /// Shared `wake`, `wake_by_ref` and `drop` entry: does nothing.
    unsafe fn noop(_data: *const ()) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resume() {
        // `gen` is a reserved keyword in the 2024 edition, so use a full word.
        let mut generator = Generator::new(|co| async move {
            let x = co.suspend("foo").await;
            let y = co.suspend("foobar").await;
            let z = co.suspend("foobarbaz").await;
            [x, y, z]
        });
        assert!(!generator.is_complete());
        // Nothing is awaiting a resume value yet, so start without one.
        let x = generator.resume_empty().yielded().unwrap();
        let y = generator.resume(x.len()).yielded().unwrap();
        let z = generator.resume(y.len()).yielded().unwrap();
        let result = generator.resume(z.len()).complete().unwrap();
        assert_eq!(result, [3, 6, 9]);
        assert!(generator.is_complete());
    }

    #[test]
    fn test_first_resume_argument_is_discarded() {
        // The first `resume` runs the producer up to its first `suspend`, which
        // has no pending `.await` to receive the argument.
        let mut generator = Generator::new(|co| async move { co.suspend(1).await });
        assert_eq!(generator.resume(10), GeneratorState::Yielded(1));
        assert_eq!(generator.resume(20), GeneratorState::Complete(20));
        assert!(generator.is_complete());
    }

    #[test]
    fn test_generator_state_accessors() {
        assert_eq!(GeneratorState::<i32, ()>::Yielded(5).yielded(), Some(5));
        assert_eq!(GeneratorState::<i32, ()>::Yielded(5).complete(), None);
        assert_eq!(GeneratorState::<(), i32>::Complete(7).complete(), Some(7));
        assert_eq!(GeneratorState::<(), i32>::Complete(7).yielded(), None);
    }

    #[test]
    fn test_iter() {
        let odd_numbers_less_than_ten = Generator::new(|co| async move {
            let mut n = 1;
            while n < 10 {
                co.suspend(n).await;
                n += 2;
            }
        });
        assert_eq!(
            odd_numbers_less_than_ten.into_iter().collect::<Vec<_>>(),
            [1, 3, 5, 7, 9]
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment