Skip to content

Instantly share code, notes, and snippets.

@Serikov
Last active April 4, 2023 12:12
Show Gist options
  • Save Serikov/b28115e3b13a7c0ec45ab76468ddb0bd to your computer and use it in GitHub Desktop.
Save Serikov/b28115e3b13a7c0ec45ab76468ddb0bd to your computer and use it in GitHub Desktop.
C++ Coroutines TS: generator<T> with co_await
#include <cstddef>
#include <exception>
#include <experimental/coroutine>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>
namespace detail {

/// Type-erased handle to an in-progress iteration over some range.
///
/// Besides the iteration interface it also implements the awaiter protocol:
/// when a coroutine co_awaits one of these, an already-exhausted range does
/// not suspend at all; otherwise the object registers itself with the
/// awaiting coroutine's promise through store_iterator().
template<typename T>
struct generic_iterable
{
    virtual ~generic_iterable() = default;

    // Element at the current position.
    virtual T& operator*() = 0;
    // Move to the following element.
    virtual generic_iterable& operator++() = 0;
    // True when no elements remain.
    virtual bool empty() const = 0;

    // --- awaiter protocol ------------------------------------------------

    // An exhausted range needs no suspension.
    bool await_ready() const noexcept
    {
        return this->empty();
    }

    // Register this iteration with the promise of the suspending coroutine;
    // a void return keeps the coroutine suspended.
    template<typename Promise>
    void await_suspend(std::experimental::coroutine_handle<Promise> awaiting) noexcept
    {
        auto& promise = awaiting.promise();
        promise.store_iterator(this);
    }

    // Awaiting a range yields no value.
    void await_resume() const noexcept {}
};

/// generic_iterable backed by a current-position iterator and an end
/// sentinel; the two types may differ.
template<typename T, typename StartIterator, typename EndIterator>
struct iterator_iterable : public generic_iterable<T>
{
    iterator_iterable(StartIterator first, EndIterator last) : start(first), end(last) {}

    T& operator*() override { return *start; }

    iterator_iterable& operator++() override
    {
        ++start;
        return *this;
    }

    bool empty() const override { return start == end; }

    StartIterator start; // current position
    EndIterator end;     // end sentinel
};

} // namespace detail
/// Lazily evaluated, single-pass sequence produced by a coroutine.
///
/// Inside the coroutine body:
///   * `co_yield v;`     produces one value (an lvalue or rvalue of type T);
///   * `co_await range;` produces every element of `range` in turn; the
///                       range's `*begin(range)` must be exactly `T&`;
///   * `co_return;` or falling off the end finishes the sequence.
///
/// The coroutine starts suspended and is first resumed by begin(). The
/// generator owns the coroutine frame: it is move-only and destroys the frame
/// in its destructor.
template<typename T>
struct generator
{
    using value_type = T;

    struct promise_type
    {
        // 0: prestart, 1: value, 2: range, 3: done
        std::variant<std::monostate, T*, detail::generic_iterable<T>*, std::monostate> state;

        // Converted to the generator through the implicit
        // generator(promise_type&) constructor.
        promise_type& get_return_object() noexcept
        {
            return *this;
        }

        // Lazy start: nothing in the body runs until begin() (or a
        // dereference) resumes the coroutine.
        std::experimental::suspend_always initial_suspend() const noexcept
        {
            return {};
        }

        // Keep the frame alive after completion so done() can be queried;
        // the owning generator destroys it.
        std::experimental::suspend_always final_suspend() const noexcept
        {
            return {};
        }

        // The yielded object (even a temporary) lives until the co_yield
        // expression completes, i.e. across the suspension, so keeping its
        // address is safe.
        std::experimental::suspend_always yield_value(T& value) noexcept
        {
            state.template emplace<1>(std::addressof(value));
            return {};
        }
        std::experimental::suspend_always yield_value(T&& value) noexcept
        {
            state.template emplace<1>(std::addressof(value));
            return {};
        }

        void return_void() noexcept
        {
            state.template emplace<3>();
        }

        // Every co_await in the body means "yield all elements of this range".
        // The returned awaiter lives in the coroutine frame for the duration
        // of the co_await. Not noexcept: begin()/end() of the range may throw,
        // e.g. a nested generator that fails while producing its first value.
        template<typename Range>
        auto await_transform(Range&& range) const
        {
            using std::begin;
            using std::end;
            auto s = begin(range);
            auto e = end(range);
            // TODO: properly constrain
            static_assert(std::is_same_v<decltype(*s), T&>,
                          "co_await requires a range whose elements are exactly T&");
            return detail::iterator_iterable<T, decltype(s), decltype(e)>{s, e};
        }

        // Mark the sequence finished and propagate the exception to whoever
        // resumed the coroutine (begin(), operator++ or operator*).
        void unhandled_exception()
        {
            state.template emplace<3>();
            auto ex = std::current_exception();
            std::rethrow_exception(ex);
            //// MSVC bug? should be possible to rethrow with "throw;"
            //// rethrow exception immediately
            // throw;
        }

        // Called from generic_iterable::await_suspend when the coroutine
        // suspends on a non-empty awaited range.
        void store_iterator(detail::generic_iterable<T>* iterator) noexcept
        {
            state.template emplace<2>(iterator);
        }

        // Current element; in the prestart state the coroutine is resumed
        // first to produce it.
        // Throws std::logic_error once the coroutine has completed.
        T& value()
        {
            switch (state.index()) {
            case 1:
                return *std::get<1>(state);
            case 2:
                return **std::get<2>(state);
            case 0:
                next();
                return value();
            default:
            case 3:
                throw std::logic_error("Generator already completed!");
            }
        }

        // Producing a value may require resuming the coroutine, which mutates
        // the promise. The promise always lives in the (non-const) coroutine
        // frame, so casting away const here is well-defined.
        const T& value() const
        {
            return const_cast<promise_type*>(this)->value();
        }

        // Advance: step the active awaited range, or resume the coroutine when
        // there is no range or the range has just been exhausted.
        // Throws std::logic_error once the coroutine has completed.
        void next()
        {
            auto handle = std::experimental::coroutine_handle<promise_type>::from_promise(*this);
            switch (state.index()) {
            case 0:
            case 1:
                handle.resume();
                break;
            case 2: {
                auto& iterator = *std::get<2>(state);
                ++iterator;
                if (iterator.empty()) {
                    // Leave the range; the awaiter is destroyed once the
                    // coroutine continues past the co_await.
                    state.template emplace<0>();
                    handle.resume();
                }
                break;
            }
            default:
            case 3:
                throw std::logic_error("Generator already completed!");
            }
        }
    };

    using handle_type = std::experimental::coroutine_handle<promise_type>;

    /// Input iterator over the generator. A null handle is the end iterator;
    /// a non-null iterator becomes null when the coroutine finishes or throws.
    struct iterator
    {
        using iterator_category = std::input_iterator_tag;
        using value_type = T;
        using difference_type = std::ptrdiff_t;
        using pointer = T*;
        using reference = T&;

        handle_type coro_handle;

        iterator() : coro_handle(nullptr) {}
        iterator(handle_type coro_handle) : coro_handle(coro_handle) {}

        iterator& operator++()
        {
            try {
                coro_handle.promise().next();
            } catch (...) {
                coro_handle = nullptr;
                throw;
            }
            if (coro_handle.done())
                coro_handle = nullptr;
            return *this;
        }
        // Single-pass: a copy could not refer to the previous element.
        iterator operator++(int) = delete;

        bool operator==(iterator const& other) const
        {
            return coro_handle == other.coro_handle;
        }
        bool operator!=(iterator const& other) const
        {
            return !(*this == other);
        }

        const T& operator*() const
        {
            return coro_handle.promise().value();
        }
        const T* operator->() const
        {
            return std::addressof(operator*());
        }
        T& operator*()
        {
            return coro_handle.promise().value();
        }
        T* operator->()
        {
            return std::addressof(operator*());
        }
    };

    /// Starts (on first call) the coroutine up to its first value, first
    /// non-empty awaited range, or completion, so that an empty sequence
    /// yields begin() == end(). Repeated calls do not advance the sequence.
    /// Propagates any exception thrown by the coroutine before that point.
    iterator begin()
    {
        if (!coro_handle)
            return {};
        if (coro_handle.promise().state.index() == 0)
            coro_handle.promise().next();
        if (coro_handle.done())
            return {};
        return {coro_handle};
    }

    iterator end()
    {
        return {};
    }

    // Implicit on purpose: the coroutine's return object is initialized from
    // promise_type::get_return_object().
    generator(promise_type& promise) : coro_handle(handle_type::from_promise(promise)) {}
    generator() = default;
    generator(generator const&) = delete;
    generator& operator=(generator const&) = delete;

    generator(generator&& other) noexcept
        : coro_handle(std::exchange(other.coro_handle, nullptr))
    {
    }

    // Releases the currently owned coroutine frame (if any) before taking
    // ownership of other's.
    generator& operator=(generator&& other) noexcept
    {
        if (&other != this) {
            if (coro_handle)
                coro_handle.destroy();
            coro_handle = std::exchange(other.coro_handle, nullptr);
        }
        return *this;
    }

    ~generator()
    {
        if (coro_handle) {
            coro_handle.destroy();
        }
    }

private:
    std::experimental::coroutine_handle<promise_type> coro_handle = nullptr;
};
// Yields every value in [first, last), each as a fresh temporary copy.
template<typename T>
generator<int> range(T first, T last)
{
    for (; first != last;)
        co_yield first++;
}
// Test coroutine: throws before yielding anything when [first, last) is
// non-empty. The unreachable co_yield keeps this function a coroutine.
template<typename T>
generator<int> range1(T first, T last)
{
    for (; first != last;) {
        throw std::logic_error("BEGIN");
        co_yield first++;
    }
}
// Test coroutine: yields the first value, then throws when resumed.
template<typename T>
generator<int> range2(T first, T last)
{
    for (; first != last;) {
        co_yield first++;
        throw std::logic_error("ITERATOR");
    }
}
// Test coroutine: an empty sequence; the arguments are deliberately ignored.
template<typename T>
generator<int> range4([[maybe_unused]] T first, [[maybe_unused]] T last)
{
    co_return;
}
// Test coroutine: yields all of [first, last), then throws after the last
// value, simulating a failure during cleanup.
template<typename T>
generator<int> range5_ex(T first, T last)
{
    for (; first != last;)
        co_yield first++;
    throw std::logic_error("AFTER LAST YIELD (for example cleanup failure)");
}
/////////////////////////
// Yields a newly allocated unique_ptr for every int in [first, last); the
// consumer may move each one out of the generator.
generator<std::unique_ptr<int>> range_unqptr(int first, int last)
{
    for (int current = first; current != last; ++current)
        co_yield std::make_unique<int>(current);
}
// Demonstrates that both lvalues and rvalues of a mutable T can be yielded
// (and moved from by the consumer), while a const lvalue cannot.
generator<std::unique_ptr<int>> gen_refs()
{
    std::unique_ptr<int> owned = std::make_unique<int>(1);
    co_yield owned;                    // lvalue&, can be moved from
    co_yield std::make_unique<int>(2); // rvalue&, can be moved from

    const std::unique_ptr<int> frozen = std::make_unique<int>(3);
    // co_yield frozen; // compile error
}
// With a const element type every kind of expression can be yielded, but
// the consumer can no longer move from the values.
generator<const std::unique_ptr<int>> gen_refs2()
{
    std::unique_ptr<int> owned = std::make_unique<int>(1);
    co_yield owned;                    // lvalue&
    co_yield std::make_unique<int>(2); // rvalue&

    const std::unique_ptr<int> frozen = std::make_unique<int>(3);
    co_yield frozen; // no compile error
}
/////////////////////////
// combining generators and ranges
// Concatenates two generators, yielding 999 after the first one and -999
// after the second one.
generator<int> gen_and_then(generator<int> first, generator<int> second)
{
    co_await first;  // all values of `first`
    co_yield 999;
    co_await second; // all values of `second`
    co_yield -999;
}
// Concatenates two arbitrary ranges of int& elements (generators, vectors,
// ...), yielding 999 after the first one and -999 after the second one.
template<typename Range1, typename Range2>
generator<int> and_then(Range1 first, Range2 second)
{
    co_await first;  // every element of `first`
    co_yield 999;
    co_await second; // every element of `second`
    co_yield -999;
}
#include <stdio.h>
#include <string>
#include <vector>
// Demo / smoke tests for generator<T>: plain ranges, exceptions at various
// points, move-only element types, co_await composition and re-iteration.
int main()
{
    // tests 1
    printf("\nrange\n");
    try {
        for (int i : range(0, 10)) {
            printf("%d\n", i);
        }
    } catch (std::exception const& e) {
        printf("%s\n", e.what());
    }
    printf("\nrange1\n");
    try {
        for (int i : range1(0, 10)) {
            printf("%d\n", i);
        }
    } catch (std::exception const& e) {
        printf("%s\n", e.what());
    }
    printf("\nrange2\n");
    try {
        for (int i : range2(0, 10)) {
            printf("%d\n", i);
        }
    } catch (std::exception const& e) {
        printf("%s\n", e.what());
    }
    // Empty generator: the loop body should never run.
    printf("\nrange4\n");
    try {
        for (int i : range4(0, 10)) {
            printf("%d\n", i);
        }
    } catch (std::exception const& e) {
        printf("%s\n", e.what());
    }
    printf("\nrange5_ex\n");
    try {
        for (int i : range5_ex(0, 10)) {
            printf("%d\n", i);
        }
    } catch (std::exception const& e) {
        printf("%s\n", e.what());
    }
    // tests 2
    try {
        printf("\nrange_unqptr\n");
        for (auto& i : range_unqptr(0, 10)) {
            auto b = std::move(i); // can be moved from, no UB
            printf("%d\n", *b);
        }
        printf("\ngen_refs\n");
        for (auto& i : gen_refs()) {
            auto b = std::move(i); // can be moved from, no UB
            printf("%d\n", *b);
        }
        printf("\ngen_refs2\n");
        for (auto& i : gen_refs2()) {
            // auto b = std::move(i); // compile time error
            printf("%d\n", *i);
        }
        printf("\ngen_and_then\n");
        auto g1 = gen_and_then(range(0, 2), range(5, 10));
        for (auto i : g1) {
            printf("%d\n", i);
        }
        printf("\ngen_and_then\n");
        auto g2 = gen_and_then(range(0, 2), range(5, 10));
        for (auto i : g2) {
            printf("%d\n", i);
        }
        printf("\nand_then gen+gen\n");
        auto g3 = and_then(range(0, 2), range(5, 10));
        for (auto i : g3) {
            printf("%d\n", i);
        }
        printf("\nand_then vec+vec\n");
        auto g4 = and_then(std::vector<int>{0, 1, 2}, std::vector<int>{5, 6, 7});
        for (auto i : g4) {
            printf("%d\n", i);
        }
        printf("\nand_then vec+gen\n");
        auto g5 = and_then(std::vector<int>{0, 1, 2}, range(5, 10));
        for (auto i : g5) {
            printf("%d\n", i);
        }
        // auto g6 = and_then(std::string("Hello world"), range(5, 10)); // compile error
        auto lamda_gen = []() -> generator<const char> {
            co_yield '\n';
            // string_view excludes the terminating '\0' that awaiting the raw
            // char array would also yield (and print).
            co_await std::string_view("Hello");
            co_yield ' ';
            co_await std::string_view("generator!");
            co_yield '\n';
        };
        for (auto c : lamda_gen()) {
            printf("%c", c);
        }
    } catch (std::exception const& e) {
        printf("%s\n", e.what());
    }
    // Iterating an exhausted generator again must be a no-op.
    try {
        auto gen = range(5, 10);
        for (auto i : gen) {
            (void)i;
        }
        for (auto i : gen) {
            (void)i;
            printf("\nERROR: should be noop!\n");
        }
    } catch (std::exception const& e) {
        printf("Error: %s\n", e.what());
    }
    // Calling begin multiple times
    try {
        auto gen = range(5, 8);
        printf("\nShould print 5 6 7!\n");
        gen.begin();
        gen.begin();
        gen.begin();
        gen.begin();
        for (auto i : gen) {
            printf("%d\n", i);
        }
    } catch (std::exception const& e) {
        printf("Error: %s\n", e.what());
    }
}
range
0
1
2
3
4
5
6
7
8
9
range1
BEGIN
range2
0
ITERATOR
Generator already completed!
range5_ex
0
1
2
3
4
5
6
7
8
9
AFTER LAST YIELD (for example cleanup failure)
range_unqptr
0
1
2
3
4
5
6
7
8
9
gen_refs
1
2
gen_refs2
1
2
3
gen_and_then
0
1
999
5
6
7
8
9
-999
gen_and_then
0
1
999
5
6
7
8
9
-999
and_then gen+gen
0
1
999
5
6
7
8
9
-999
and_then vec+vec
0
1
2
999
5
6
7
-999
and_then vec+gen
0
1
2
999
5
6
7
8
9
-999
Hello generator!
Should print 5 6 7!
5
6
7
@getsoubl
Copy link

Reading about the theory behind co_yield: when co_yield is used, the compiler produces the following piece of code:

template<typename T>
generator<int> range(T first, T last)
{
__counter_context* __context = new __counter_context{};
__return = __context->_promise.get_return_object();
co_await __context->_promise.initial_suspend();
while (first != last) {
co_await __context->_promise.yield_value(first++);
}
__final_suspend_label:
co_await __context->_promise.final_suspend();
}
co_await __context->_promise.yield_value(first++) is equivalent to

auto&& __awaitable = __context->_promise.yield_value(first++);
if (!__awaitable.await_ready())
{
__awaitable.await_suspend(<coroutine handle>);
// ...suspend/resume point...
}
__awaitable.await_resume();
My first question is about initial_suspend. initial_suspend returns a suspend_always object. If I understand correctly, at this point the coroutine waits for a resume signal and returns control to the caller. Who wakes up the coroutine at this point? Why not use suspend_never?

The second point that confuses me is how the integer values first and last are converted to iterator objects, and how operator++ is called when first++ is executed.
Moreover, I cannot understand how the begin and end functions are called.

I tried to debug this step by step, but I was not able to verify the processing steps.

The code is clear to me up to the point where get_return_object is called.

@Serikov
Copy link
Author

Serikov commented Apr 18, 2020

This generator type has two parts:

  1. generator::promise_type that is responsible for the coroutine logic - awaiting, yielding and other.
  2. generator object - actual return value of the coroutine. It is responsible for getting values out of coroutine.
    This second object is a non-copyable, movable object that has begin() and end() member functions. Through these member functions the user can get values out of the generator, for example by using a range-for loop:
    for (int i : range(0, 10)) {
        printf("%d\n", i);
    }

The return value of range(0, 10) is actually a generator<int> which is used in range-for loop.

Who wakes up the coroutine at this point?

When generator::iterator's operator++ is called, it calls the next() member function on the promise object. Inside that function, handle.resume() resumes the coroutine.

Why not use suspend_never?

If suspend_never were used, the coroutine would start running even if the generator was never used and was discarded immediately:

generator<int> failed_generator()
{
    // throw always
    if (true)
        throw std::logic_error("Fail!!!");

    co_yield 1;
}
int main()
{
    auto f = failed_generator();
    // f is not used => no exception
}

In this generator implementation the coroutine is not run until the first value from it is needed.

The second point that confuses me is how the integer values first and last are converted to iterator objects, and how operator++ is called when first++ is executed.

The integer values first and last are never converted to iterator objects. These integer values are stored inside the coroutine frame. The generator object creates its iterators in its begin() and end() functions:

    iterator begin()
    {
        if (coro_handle) {
            if (coro_handle.done())
                return {}; // << empty iterator created HERE
        }

        return { coro_handle }; // << normal iterator created HERE
    }

    iterator end()
    {
        return {}; // << empty iterator created HERE
    }

When a normal iterator is advanced or dereferenced, it resumes the coroutine. When the coroutine reaches its end and handle.done() returns true, that normal iterator becomes an empty iterator by discarding the coroutine handle. Empty iterators are equal to each other, so this now-empty iterator becomes equal to the end() iterator of the generator.

So the generator and its iterators never know about the arguments of the coroutine or its inner structure (loops and so on). All the generator's iterators care about is whether the coroutine has finished or not. Moreover, the coroutine can have multiple loops or no loops at all; it does not matter to the generator and its iterators.

Moreover, I cannot understand how the begin and end functions are called.

begin and end are called by the user of the generator. In my examples they are called by the range-for loop.

They can also be used directly:

    generator<int> gen = range(0, 10);

    for (auto it = gen.begin(); it != gen.end(); ++it) {
        printf("%d\n", *it);
    }

@getsoubl
Copy link

getsoubl commented Apr 20, 2020

Thank you for the detailed analysis. I had the impression that range(0,10) is the initial call.
One additional point that confuses me is how the loop finishes.
How does the condition (coro_handle.done()) become true so that the loop exits?
I cannot understand how the done() function works.

@Serikov
Copy link
Author

Serikov commented Apr 20, 2020

I had the impression that range(0,10) is the initial call.

Yes, range(0, 10) is the initial call to the coroutine. It creates the promise object and the coroutine frame and stores the arguments; then, because initial_suspend() returns suspend_always, it suspends the coroutine. After that, the coroutine is resumed through the coroutine handle.

How does the condition (coro_handle.done()) become true so that the loop exits?
I cannot understand how the done() function works.

coro_handle.done() returns whether the coroutine function has reached its end. It does not know about loops inside the coroutine, only whether the function has reached a return statement (or fallen off the end of its last statement). This function is implemented by the compiler.

For example, the coroutine body can have multiple loops or no loops at all.

@MANIKTANEJA3
Copy link

There are very limited resources on coroutines online. Could you recommend any so that I can gain a deeper understanding of them?

@Serikov
Copy link
Author

Serikov commented May 2, 2020

There are very limited resources on coroutines online. Could you recommend any so that I can gain a deeper understanding of them?

There were a couple of blog posts about coroutines, but unfortunately I found nothing worth recommending.

Here is the list of materials that I used:

Lectures from Gor Nishanov:
CppCon 2014: Gor Nishanov "await 2.0: Stackless Resumable Functions"
CppCon 2015: Gor Nishanov “C++ Coroutines - a negative overhead abstraction"

The first paper about coroutines with Gor Nishanov as a co-author: Resumable Functions (revision 3)

Multiple revisions of the Coroutine TS paper. They contain motivation, examples, and a description of the coroutine feature:
Working Draft, C++ Extensions for Coroutines - 2017-07-30
Working Draft, C++ Extensions for Coroutines - 2018-02-11
Working Draft, C++ Extensions for Coroutines - 2018-06-24

Additions to coroutines with motivation and examples:
Add parameter preview to coroutine promise constructor
Add symmetric coroutine control transfer

Resulting C++ standard:
Working Draft, Standard for Programming Language C++

@MANIKTANEJA3
Copy link

MANIKTANEJA3 commented May 2, 2020 via email

@njikmf
Copy link

njikmf commented Dec 23, 2022

Thanks!!! Workable with latest MSVC and easy to learn.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment