Skip to content

Instantly share code, notes, and snippets.

@bonzini
Last active March 11, 2022 21:39
Show Gist options
  • Save bonzini/78f37bd562e1e18f7bd214dd94bcbea7 to your computer and use it in GitHub Desktop.
Save bonzini/78f37bd562e1e18f7bd214dd94bcbea7 to your computer and use it in GitHub Desktop.
Simple C++ coroutine runtime
#include "coro.h"
#include <cstdio>
// The coroutine that is executing on this thread, or null when none is.
// WithCurrent pushes a coroutine here for the duration of a resume and
// restores the previous value afterwards.  Declared with the standard
// thread_local keyword rather than the GNU-specific __thread spelling.
static thread_local Coroutine *current;
// Called when a function does "co_await Yield{}": record the handle of
// the function that is yielding as the resume point ("top") of the
// coroutine that is running on this thread.
void Yield::await_suspend(std::coroutine_handle<> parent) const noexcept {
    //printf("!!!! top = %p, yielding from %p\n", current->top, parent);
    Coroutine &running = *current;
    running.top = parent;
}
// RAII wrapper to set and restore the current coroutine.
//
// The constructor remembers the previously current coroutine in
// _co.caller and makes _co the current one; the destructor undoes both
// steps.  Copying is disabled because a copy would run the destructor a
// second time and restore "current" from an already cleared caller.
struct WithCurrent {
    Coroutine &_co;

    WithCurrent(Coroutine &co): _co(co) {
        _co.caller = current;
        current = &_co;
    }

    WithCurrent(WithCurrent const&) = delete;
    WithCurrent& operator=(WithCurrent const&) = delete;

    ~WithCurrent() {
        // Pop _co: whoever was current before it becomes current again.
        current = _co.caller;
        _co.caller = nullptr;
    }
};
void Coroutine::resume() {
auto w = WithCurrent(*this);
std::coroutine_handle<> old_top = top;
//printf("$$$$ resume %p %d\n", old_top.address(), old_top.done());
top = nullptr;
old_top.resume();
}
// ---------------------------
#include <cstdio>
// Run co until it yields again or stops, and free it in the latter case.
//
// Coroutine::resume() clears co->top before running; a non-null value
// afterwards means that a yield recorded a new resume point, so the
// object has to stay alive.  Otherwise the Coroutine is deleted and the
// pointer must not be used by the caller any more.
void qemu_coroutine_enter(Coroutine *co)
{
    co->resume();
    if (co->top) {
        // Yielded: keep the coroutine around for the next enter.
        return;
    }
    //printf("$$$$ deleting\n");
    delete co;
}
// Adapter that wraps a CoroutineFunc in a Coroutine.
//
// A CoroutineFn<void> starts running as soon as it is called, whereas a
// Coroutine is suspended at its initial suspend point.  Wrapping the call
// to func in a Coroutine therefore delays it until the first
// qemu_coroutine_enter().
//
// func:   the function to run inside the coroutine.
// opaque: passed through unchanged as func's only argument.
Coroutine coroutine_trampoline(CoroutineFunc *func, void *opaque)
{
// Runs func and waits for it to finish, however often it yields in between.
co_await func(opaque);
}
// Allocate a new Coroutine that will call func(opaque) once it is
// entered.  The returned object lives on the heap; it is released by
// qemu_coroutine_enter() when an enter comes back without a new yield.
Coroutine *qemu_coroutine_create(CoroutineFunc *func, void *opaque)
{
    Coroutine *created = new Coroutine(coroutine_trampoline(func, opaque));
    return created;
}
#pragma once
#include <coroutine>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <utility>
// Forward declaration; the full definition follows further down.
struct Coroutine;
// Declared with C linkage, so the symbol name is not mangled
// (presumably so that C code can call it -- confirm against the callers).
extern "C" {
// Resumes co; see the definition for when co is freed.
void qemu_coroutine_enter(Coroutine *co);
}
// BaseCoroutine is a simple wrapper type for a Promise.  It mostly
// exists because C++ says so, but it also provides two extra features:
//
// - RAII destruction of the coroutine: the wrapper owns the coroutine
//   frame and destroys it in its destructor.  Beware, the promise's
//   final_suspend must always suspend, otherwise the frame would free
//   itself and the destructor would free it a second time.
//
// - a conversion to std::coroutine_handle<>, which makes it resumable.
//
// The type is move-only; a moved-from wrapper holds a null handle.
template<typename Promise> struct BaseCoroutine
{
    using promise_type = Promise;

    // An empty wrapper that owns nothing.
    BaseCoroutine() = default;

    // Take ownership of the coroutine whose promise object is "promise".
    explicit BaseCoroutine (Promise &promise) :
        _coroutine{std::coroutine_handle<Promise>::from_promise(promise)} {}

    BaseCoroutine(BaseCoroutine const&) = delete;
    BaseCoroutine(BaseCoroutine&& other) noexcept :
        _coroutine{std::exchange(other._coroutine, nullptr)} {}

    BaseCoroutine& operator=(BaseCoroutine const&) = delete;
    BaseCoroutine& operator=(BaseCoroutine&& other) noexcept {
        if (&other != this) {
            // Release the frame owned so far before adopting the new
            // one; overwriting the handle would leak it.
            if (_coroutine) _coroutine.destroy();
            _coroutine = std::exchange(other._coroutine, nullptr);
        }
        return *this;
    }

    ~BaseCoroutine() {
        //printf("!!!! destroying %p\n", _coroutine);
        if (_coroutine) _coroutine.destroy();
    }

    // True if the wrapper owns a coroutine.  The handle's own operator
    // bool is explicit, hence the cast.
    operator bool() const noexcept {
        return static_cast<bool>(_coroutine);
    }

    // Type-erased handle for the owned coroutine; ownership is not
    // transferred.
    operator std::coroutine_handle<>() const noexcept {
        return _coroutine;
    }

    // The promise object of the owned coroutine.  Must not be called on
    // an empty wrapper.
    Promise &promise() const noexcept {
        return _coroutine.promise();
    }

private:
    std::coroutine_handle<Promise> _coroutine = nullptr;
};
// This is a simple awaitable object that takes care of resuming a
// parent coroutine.  It's needed because co_await suspends all
// parent coroutines on the stack.  It does not need a specific
// "kind" of coroutine_handle, so no need to put it inside the
// templates below.
//
// The parent is resumed by returning its handle from await_suspend()
// (symmetric transfer) instead of calling resume() on it.  That way the
// finishing coroutine is fully suspended before the parent runs, and
// the parent does not execute nested inside the child's await_suspend().
//
// If next is NULL, then this degrades to std::suspend_always.
struct ResumeAndFinish {
    explicit ResumeAndFinish(std::coroutine_handle<> next) noexcept :
        _next{next} {}

    // Always suspend the coroutine that awaits this object.
    bool await_ready() const noexcept {
        return false;
    }

    // Hand control to the parent if there is one; otherwise return to
    // whoever resumed the coroutine that is now suspending.
    std::coroutine_handle<> await_suspend(std::coroutine_handle<>) const noexcept {
        if (_next) {
            return _next;
        }
        return std::noop_coroutine();
    }

    void await_resume() const noexcept {}

private:
    std::coroutine_handle<> _next;
};
// ------------------------
// Coroutine is the object that callers enter.  It remembers, in "top",
// the coroutine_handle that last yielded; Coroutine::resume() continues
// from that handle.
//
// A thread-local variable "current" plus the "caller" member form a
// stack of the coroutines that are active on a thread, which is how a
// yield finds out which Coroutine it belongs to.
//
// The promise type, EntryPromise, is about as plain as it gets.  It
// suspends on entry, so nothing runs before the first call to
// qemu_coroutine_enter(), and it suspends on exit as well, because the
// frame is released explicitly by BaseCoroutine's destructor.
struct EntryPromise;

struct Coroutine: BaseCoroutine<EntryPromise> {
    // The coroutine that was current when this one was resumed; null
    // while this coroutine is not being resumed.
    Coroutine *caller = nullptr;

    // Where resume() continues.  Starts out as the handle of this
    // coroutine itself.
    std::coroutine_handle<> top;

    explicit Coroutine(promise_type &promise) :
        BaseCoroutine(promise), top(*this) {}

    void resume();
};
struct EntryPromise
{
Coroutine get_return_object() noexcept { return Coroutine{*this}; }
void unhandled_exception() { std::terminate(); }
auto initial_suspend() const noexcept { return std::suspend_always{}; }
auto final_suspend() const noexcept { return std::suspend_always{}; }
void return_void() const noexcept {}
};
// ------------------------
// CoroutineFn needs nothing beyond what BaseCoroutine already offers,
// so it is merely an alias.  All the interesting parts live in
// ValuePromise<T>.
//
// Suspended CoroutineFns form a chain.  When one of them suspends,
// every function that is waiting on it through co_await is suspended
// too; when one of them finishes, it has to check whether its parent
// can continue.
//
// Two helper classes cover the two halves of this.  Awaiter's
// await_suspend() stores the parent coroutine in the promise, and
// ResumeAndFinish, which runs after a coroutine has returned, resumes
// that parent.
template<typename T> struct ValuePromise;

template<typename T>
using CoroutineFn = BaseCoroutine<ValuePromise<T>>;

// Signature of the functions accepted by qemu_coroutine_create().
using CoroutineFunc = CoroutineFn<void>(void *);
// Unfortunately it is forbidden to define both return_void() and
// return_value() in the same class.  In order to cut on the
// code duplication, define a superclass for both ValuePromise<T>
// and ValuePromise<void>.
//
// The "curiously recurring template pattern" is used to substitute
// ValuePromise<T> into the methods of the base class and its Awaiter.
// For example await_resume() needs to retrieve a value with the
// correct type from the subclass's value() method.
//
// Derived must provide a member type await_resume_type and a value()
// method returning it.
template<typename T, typename Derived>
struct BasePromise
{
    using coro_handle_type = std::coroutine_handle<Derived>;

#if 0
    // Same as get_return_object().address() but actually works.
    // Useful as an identifier to identify the promise in debugging
    // output, because it matches the values passed to await_suspend().
    void *coro_address() const {
        return __builtin_coro_promise((char *)this, __alignof(*this), true);
    }
    BasePromise() {
        printf("!!!! created %p\n", coro_address());
    }
    ~BasePromise() {
        printf("!!!! destroyed %p\n", coro_address());
    }
#endif

    CoroutineFn<T> get_return_object() noexcept { return CoroutineFn<T>{downcast()}; }
    void unhandled_exception() { std::terminate(); }

    // The body starts running immediately when the function is called.
    auto initial_suspend() const noexcept { return std::suspend_never{}; }

    // Pick up the parent recorded by then(), if any, *before* _next is
    // overwritten with the ready marker; the returned awaiter resumes
    // that parent.
    auto final_suspend() noexcept {
        auto continuation = ResumeAndFinish{_next};
        mark_ready();
        return continuation;
    }

private:
    // Either null (no parent is waiting), the handle of the waiting
    // parent, or the READY_MARKER sentinel once the body has finished.
    std::coroutine_handle<> _next = nullptr;

    // Sentinel address stored in _next to flag completion.
    static constexpr std::uintptr_t READY_MARKER = 1;

    void mark_ready() {
        _next = std::coroutine_handle<>::from_address(
            reinterpret_cast<void *>(READY_MARKER));
    }
    bool is_ready() const {
        return _next.address() == reinterpret_cast<void *>(READY_MARKER);
    }

    Derived& downcast() noexcept { return *static_cast<Derived*>(this); }
    Derived const& downcast() const noexcept { return *static_cast<const Derived*>(this); }

    // This records the parent coroutine, before a co_await suspends
    // all parent coroutines on the stack.
    void then(std::coroutine_handle<> parent) { _next = parent; }

    // This is the object that lets us co_await a CoroutineFn<T> (of which
    // this class is the corresponding promise object).  This is just mapping
    // C++ awaitable naming into the more conventional promise naming.
    struct Awaiter {
        Derived &_promise;

        explicit Awaiter(Derived &promise) : _promise{promise} {}

        // No need to suspend if the awaited function has already finished.
        bool await_ready() const noexcept {
            return _promise.is_ready();
        }
        void await_suspend(std::coroutine_handle<> parent) const noexcept {
            _promise.then(parent);
        }
        // "typename" is spelled out for compilers that still require it
        // in front of a dependent type name in this position.
        typename Derived::await_resume_type await_resume() const noexcept {
            return _promise.value();
        }
    };

    // C++ connoisseurs will tell you that this is not private: a friend
    // function is not a member, so the access specifier above does not
    // apply to it.
    friend Awaiter operator co_await(CoroutineFn<T> co) {
        return Awaiter{co.promise()};
    }
};
// The concrete promise for functions that return a T.  On top of
// BasePromise it only has to keep the value handed to co_return and
// give it back, as an rvalue, to whoever awaits the function.
template<typename T>
struct ValuePromise: BasePromise<T, ValuePromise<T>>
{
    using await_resume_type = T&&;

    // Storage for the result; assigned by return_value().
    T _value;

    void return_value(T const& value) { _value = value; }
    void return_value(T&& value) { _value = std::move(value); }

    // Lets the awaiting side move the stored result out.
    T&& value() noexcept { return std::move(_value); }
};
// The promise for functions that return nothing: there is no value to
// store, so both hooks are empty.
template<>
struct ValuePromise<void>: BasePromise<void, ValuePromise<void>>
{
    using await_resume_type = void;

    void return_void() const {}
    void value() const {}
};
// ---------------------------
// Awaitable used for yielding: write "co_await Yield{}".  It inherits
// from std::suspend_always, so the awaiting function always suspends,
// and its own await_suspend() additionally records the suspending
// CoroutineFn in current->top.
struct Yield: std::suspend_always {
    void await_suspend(std::coroutine_handle<> parent) const noexcept;
};
// ---------------------------
// Creates a coroutine that will run func(opaque).  Nothing runs until
// the result is passed to qemu_coroutine_enter().
Coroutine *qemu_coroutine_create(CoroutineFunc *func, void *opaque);
// Convenience wrapper so that yielding can be spelled
// "co_await qemu_coroutine_yield()".
static inline Yield qemu_coroutine_yield()
{
    return {};
}
#include "coro.h"
#include <iostream>
#include <string>
// Demo function: announces itself, yields once, then returns 30.
CoroutineFn<int> return_int() {
    std::cout << ">>suspending to ";
    std::cout << __func__;
    std::cout << '\n';

    Yield pause = qemu_coroutine_yield();
    co_await pause;

    std::cout << ">>back\n";
    co_return 30;
}
// Demo function: announces itself, yields once, then finishes without
// a result.
CoroutineFn<void> return_void() {
    std::cout << ">>suspending to ";
    std::cout << __func__;
    std::cout << '\n';

    Yield pause = qemu_coroutine_yield();
    co_await pause;

    std::cout << ">>back\n";
}
// Body of the demo coroutine: awaits the two helpers, printing the
// integer that the second one returns, and then yields once itself.
CoroutineFn<void> co(void *) {
    co_await return_void();

    const int answer = co_await return_int();
    std::cout << answer << '\n';

    std::cout << "suspending\n";
    co_await qemu_coroutine_yield();
    std::cout << "back\n";
}
int main() {
auto f = qemu_coroutine_create(co, NULL);
std::cout << "--- 0\n";
qemu_coroutine_enter(f);
std::cout << "--- 1\n";
qemu_coroutine_enter(f);
std::cout << "--- 2\n";
qemu_coroutine_enter(f);
std::cout << "--- 3\n";
qemu_coroutine_enter(f);
std::cout << "--- 4\n";
}
#include "coro.h"
#include <iostream>
#include <string>
// Yields once and, after being resumed, returns s as a std::string.
CoroutineFn<std::string> yield_and_resume(const char *s)
{
    co_await qemu_coroutine_yield();

    std::string text(s);
    co_return text;
}
// Prints the two messages produced by yield_and_resume(), with a yield
// of its own in between.
CoroutineFn<void> counter(void *opaque)
{
    std::string first = co_await yield_and_resume("counter: resumed (#1)\n");
    std::cout << first;

    co_await qemu_coroutine_yield();

    std::string second = co_await yield_and_resume("counter: resumed (#2)\n");
    std::cout << second;
}
int main ()
{
std::cout << "main: calling counter\n";
Coroutine *the_counter = qemu_coroutine_create(counter, NULL);
qemu_coroutine_enter(the_counter); // in resumed()
qemu_coroutine_enter(the_counter);
qemu_coroutine_enter(the_counter); // in resumed()
qemu_coroutine_enter(the_counter);
std::cout << "main: done\n";
}
#include "coro.h"
#include <iostream>
#include <string>
// Returns s as a std::string right away, without ever suspending.
CoroutineFn<std::string> resumed(const char *s)
{
    std::string text(s);
    co_return text;
}
// Prints the two messages produced by resumed(), with a single yield
// in between.
CoroutineFn<void> counter(void *opaque)
{
    std::string first = co_await resumed("counter: resumed (#1)\n");
    std::cout << first;

    co_await qemu_coroutine_yield();

    std::string second = co_await resumed("counter: resumed (#2)\n");
    std::cout << second;
}
int main ()
{
std::cout << "main: calling counter\n";
Coroutine *the_counter = qemu_coroutine_create(counter, NULL);
qemu_coroutine_enter(the_counter);
qemu_coroutine_enter(the_counter);
std::cout << "main: done\n";
}
#include "coro.h"
#include <iostream>
#include <string>
#include <vector>
// Returns the vector {1, 2, 3} without ever suspending.
CoroutineFn<std::vector<int>> resumed()
{
    std::vector<int> x{1, 2, 3};
    co_return x;
}
// Takes the vector returned by resumed() and prints its three elements
// from the back, removing each one after it has been printed.
CoroutineFn<void> vec(void *opaque)
{
    std::vector<int> v = co_await resumed();
    for (int printed = 0; printed < 3; ++printed) {
        std::cout << v.back() << '\n';
        v.pop_back();
    }
}
int main ()
{
Coroutine *co = qemu_coroutine_create(vec, NULL);
qemu_coroutine_enter(co);
std::cout << "main: done\n";
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment