Skip to content

Instantly share code, notes, and snippets.

@ttsuki
Last active August 20, 2023 21:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ttsuki/fd7106cc583bda02fa135a72bad7afcc to your computer and use it in GitHub Desktop.
Save ttsuki/fd7106cc583bda02fa135a72bad7afcc to your computer and use it in GitHub Desktop.
A study of a C++ coroutine task system
// coro_task: A study of a C++20 coroutine task system
#pragma once
#include <array>
#include <cassert>
#include <chrono>
#include <concepts>
#include <coroutine>
#include <exception>
#include <future>
#include <list>
#include <memory>
#include <optional>
#include <stack>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>
namespace coro::detail
{
/// Move-only wrapper around std::future<T> that adds non-blocking polling helpers.
/// `value_type` maps `void` to std::monostate so that "result available" can be expressed
/// as a std::optional for every T.
template <class T>
class future
{
    std::future<T> future_;

public:
    using result_type = T;
    using value_type = std::conditional_t<std::is_void_v<result_type>, std::monostate, result_type>;
    using optional_value_type = std::optional<value_type>;

    future() = default;
    future(std::future<result_type> future) : future_(std::move(future)) {}
    future(const future& other) = delete;
    future(future&& other) noexcept = default;
    future& operator=(const future& other) = delete;
    future& operator=(future&& other) noexcept = default;
    ~future() = default;

    /// True while the wrapped std::future still refers to a shared state (i.e. get() not yet called).
    bool valid() const noexcept { return future_.valid(); }

    /// Non-blocking check: valid and the result (value or exception) is already stored.
    bool ready() const noexcept
    {
        if (!future_.valid())
            return false;
        return future_.wait_for(std::chrono::seconds::zero()) == std::future_status::ready;
    }

    /// Retrieves the result (rethrows a stored exception); the future is invalid afterwards.
    result_type get() { return future_.get(); }

    /// Returns the value if ready (consuming it), std::nullopt otherwise.
    std::optional<T> get_if_ready() requires(!std::is_void_v<T>)
    {
        if (!ready())
            return std::nullopt;
        return std::optional<T>(get());
    }

    /// void flavour: returns an engaged std::monostate once completed (consuming it), std::nullopt otherwise.
    std::optional<std::monostate> get_if_ready() requires(std::is_void_v<T>)
    {
        if (!ready())
            return std::nullopt;
        get();
        return std::optional<std::monostate>(std::in_place);
    }

    /// Releases the wrapped std::future.
    operator std::future<T>() && { return std::move(future_); }
};
/// Result channel mixed into task promises: forwards `co_return` values and escaping
/// exceptions into a std::promise whose future is handed out via get_future().
template <class result_type>
class promise_base
{
public:
    /// `co_return expr;` with a const lvalue: the value is copied into the shared state.
    void return_value(const result_type& value) { result_.set_value(value); }
    /// `co_return expr;` with an rvalue (this overload also accepts `co_return {...};`).
    void return_value(result_type&& value) { result_.set_value(std::move(value)); }
    /// Any other operand type is forwarded unchanged to std::promise::set_value.
    void return_value(auto&& value) { result_.set_value(std::forward<decltype(value)>(value)); }

    /// Stores the in-flight exception so that get() on the future rethrows it.
    void unhandled_exception() noexcept { result_.set_exception(std::current_exception()); }

    /// Retrieves the future paired with this promise (std::promise allows this only once).
    future<result_type> get_future() { return result_.get_future(); }

private:
    std::promise<result_type> result_{};
};
/// void specialization: completion is signalled by return_void() instead of return_value().
template <>
class promise_base<void>
{
public:
    /// Marks the shared state as completed (no value).
    void return_void() { result_.set_value(); }

    /// Stores the in-flight exception so that get() on the future rethrows it.
    void unhandled_exception() noexcept { result_.set_exception(std::current_exception()); }

    /// Retrieves the future paired with this promise (std::promise allows this only once).
    future<void> get_future() { return result_.get_future(); }

private:
    std::promise<void> result_{};
};
// Drives one logical "fiber": a stack of suspended coroutines that is advanced one step per resume().
// The top of call_stack_ is the innermost coroutine in progress; every entry below it is a caller
// suspended in a co_await on the entry above.
class fiber_context final
{
// Coroutines of this fiber, innermost (the one resume() continues) on top.
std::stack<std::coroutine_handle<>> call_stack_{};
public:
// Forward declaration so the constrained constructor/push() below can name it; defined at the end of this class.
class promise_base;
fiber_context() = default;
// Non-copyable and non-movable: push() stores a raw pointer to this context in each registered promise.
fiber_context(const fiber_context& other) = delete;
fiber_context(fiber_context&& other) noexcept = delete;
fiber_context& operator=(const fiber_context& other) = delete;
fiber_context& operator=(fiber_context&& other) noexcept = delete;
// Destroys any coroutine frames that have not finished (see abort()).
~fiber_context() { this->abort(); }
// Creates a context whose root coroutine is `main`. `main` is only registered here; this
// constructor does not resume it.
// NOTE(review): not explicit, so a coroutine handle converts implicitly to a fiber_context.
template <std::derived_from<promise_base> promise_type>
fiber_context(std::coroutine_handle<promise_type> main)
{
this->push(main);
// Debug-only sanity check that `main` is now on top of the stack.
this->suspend(main);
}
// True once every coroutine of this fiber has finished (or has been aborted).
bool done() const noexcept
{
return call_stack_.empty();
}
// Runs the innermost coroutine until its next suspension point.
// Throws std::logic_error if the fiber has already finished (asserts in debug builds).
void resume()
{
assert(!done());
if (done()) throw std::logic_error("invalid stack state");
call_stack_.top().resume();
}
// Destroys all remaining coroutine frames, innermost first. An unfinished frame never sets its
// std::promise, so a waiter on its future observes std::future_error (broken_promise).
void abort() noexcept
{
while (!call_stack_.empty())
{
call_stack_.top().destroy();
call_stack_.pop();
}
}
private:
// Awaiter hook for `co_yield`: `current` stays on top of the stack. Returning void from
// await_suspend suspends the coroutine and returns control to whoever called resume().
auto suspend([[maybe_unused]] std::coroutine_handle<> current) noexcept
{
assert(call_stack_.top() == current);
}
// Awaiter hook for calling a sub-coroutine: records this context in the callee's promise, pushes it,
// and returns its handle so the calling await_suspend symmetrically transfers control to the callee.
template <std::derived_from<promise_base> promise_type>
auto push(std::coroutine_handle<promise_type> callee) noexcept
{
static_cast<promise_base&>(callee.promise()).context_ = this;
this->call_stack_.push(callee);
return callee;
}
// Awaiter hook for final_suspend: pops the finished `callee` and continues its caller,
// or std::noop_coroutine() when the root coroutine has finished.
auto pop() noexcept
{
auto callee = this->call_stack_.top();
this->call_stack_.pop();
auto caller = !this->call_stack_.empty()
? this->call_stack_.top()
: std::noop_coroutine();
// Direct form, kept for reference but not used:
// callee.destroy();
// return caller;
// Instead, destroy-and-continue runs inside a one-shot helper coroutine whose handle is returned
// for symmetric transfer. NOTE(review): the source does not say why the direct form was avoided;
// confirm before simplifying.
// Helper coroutine type: it starts suspended, and its frame frees itself when the body finishes
// (final_suspend is suspend_never).
struct fire_and_forget final
{
std::coroutine_handle<> handle{};
struct promise_type
{
fire_and_forget get_return_object() { return fire_and_forget{std::coroutine_handle<promise_type>::from_promise(*this)}; }
auto initial_suspend() const noexcept { return std::suspend_always{}; }
auto final_suspend() const noexcept { return std::suspend_never{}; }
void return_void() const noexcept {}
void unhandled_exception() noexcept { std::terminate(); }
};
};
// Helper body: destroy the finished callee (ee), then continue the caller (er) with a nested
// resume() call (not symmetric transfer), then finish.
return [](std::coroutine_handle<> ee, std::coroutine_handle<> er) noexcept -> fire_and_forget
{
assert((ee.done()));
ee.destroy();
er.resume();
co_return;
}(callee, caller).handle;
}
public:
// Base that every task promise must derive from. It links a coroutine to the fiber_context running it
// and provides the awaiters a promise returns from initial_suspend/final_suspend/yield_value/await_transform.
class promise_base
{
friend class fiber_context;
// Set by fiber_context::push(); null until the coroutine has been pushed onto a context.
fiber_context* context_{};
public:
// Coroutines start suspended; a fiber_context resumes them.
template <std::derived_from<promise_base> promise_type>
using initial_suspend_awaiter = std::suspend_always;
// Suspends the current coroutine and returns control to the caller of fiber_context::resume().
template <std::derived_from<promise_base> promise_type>
struct suspend_awaiter
{
bool await_ready() const noexcept { return false; }
auto await_suspend(std::coroutine_handle<promise_type> current) const noexcept { return static_cast<promise_base&>(current.promise()).context_->suspend(current); }
void await_resume() const noexcept {}
};
// Placeholder future for a call whose result is not needed; get() does nothing.
struct empty_future
{
constexpr void get() const noexcept {}
};
// Awaiter for calling `callee` from the current coroutine: pushes the callee and transfers to it.
// When the caller is resumed, await_resume() returns `future.get()`; for a std::future-backed
// future_type this also rethrows an exception stored by the callee.
template <std::derived_from<promise_base> promise_type, std::derived_from<promise_base> callee_promise_type, class future_type = empty_future>
requires requires(future_type future) { future.get(); }
struct coroutine_call_awaiter
{
std::coroutine_handle<callee_promise_type> callee;
future_type future;
coroutine_call_awaiter(std::coroutine_handle<callee_promise_type> callee, future_type future = {}) : callee(callee), future(std::move(future)) {}
bool await_ready() const noexcept { return false; }
auto await_suspend(std::coroutine_handle<promise_type> caller) const noexcept { return static_cast<promise_base&>(caller.promise()).context_->push(callee); }
auto await_resume() { return future.get(); }
};
// Awaiter for final_suspend: hands the finished coroutine to fiber_context::pop().
template <std::derived_from<promise_base> promise_type>
struct coroutine_return_awaiter
{
bool await_ready() const noexcept { return false; }
auto await_suspend(std::coroutine_handle<promise_type> callee) const noexcept { return static_cast<promise_base&>(callee.promise()).context_->pop(); }
void await_resume() const noexcept {}
};
};
};
/// Empty tag type: a task body writes `co_yield {};` to build one and give up the rest of the current step.
struct suspend_this
{
};
/// Move-only owner of a lazily-started coroutine producing `result_type`.
/// A task does not run on its own. Hand it to a fiber or runner, or `co_await` it from another
/// task; the callee then runs on the caller's fiber_context.
template <class result_type>
class task final
{
public:
    // NOTE(review): task_trait is not defined anywhere in this file.
    friend struct task_trait;

    /// Promise for task coroutines. It combines scheduling (fiber_context::promise_base) with
    /// result delivery through a std::promise (promise_base<result_type>).
    struct promise_type
        : fiber_context::promise_base
        , promise_base<result_type>
    {
        task get_return_object() noexcept { return std::coroutine_handle<promise_type>::from_promise(*this); }
        // Start suspended: the body runs only when a fiber_context resumes it.
        auto initial_suspend() const noexcept { return initial_suspend_awaiter<promise_type>(); }
        // On completion, let the owning fiber_context pop this frame and continue the caller.
        auto final_suspend() const noexcept { return coroutine_return_awaiter<promise_type>(); }
        // `co_yield {};` suspends until the fiber is resumed again.
        auto yield_value(suspend_this) const noexcept { return suspend_awaiter<promise_type>(); }

        // This promise declares await_transform, so every co_await operand in a task body must match
        // one of the overloads below. Any other awaitable is ill-formed.

        // for `co_await coro_task<T>(...);`: runs the callee on the same fiber and returns its result.
        // Throws std::logic_error if `callee_task` is empty.
        template <class callee_result_type>
        auto await_transform(task<callee_result_type> callee_task)
        {
            using callee_promise_type = typename task<callee_result_type>::promise_type;
            std::coroutine_handle<callee_promise_type> callee_handle = std::move(callee_task).detach_coroutine_handle();
            if (!callee_handle)
                throw std::logic_error("empty task can't be co_await.");
            return coroutine_call_awaiter<promise_type, callee_promise_type, future<callee_result_type>>{callee_handle, callee_handle.promise().get_future()};
        }

        // for `co_await coro::future<T>`: wraps the future in a helper task that co_yields once per
        // resume until the future is ready, then returns its value.
        // Throws std::logic_error if `f` is not valid.
        template <class callee_result_type>
        auto await_transform(future<callee_result_type> f)
        {
            if (!f.valid())
                throw std::logic_error("invalid future can't be co_await.");
            return this->await_transform([](future<callee_result_type> f) -> task<callee_result_type>
            {
                while (!f.ready()) co_yield{};
                co_return f.get();
            }(std::move(f)));
        }

        // for `co_await std::future<T>`: adapts it to coro::future and polls it as above.
        template <class callee_result_type>
        auto await_transform(std::future<callee_result_type> f)
        {
            return this->await_transform(future(std::move(f)));
        }
    };

    task() = default;
    task(const task& other) = delete;
    task(task&& other) noexcept : coroutine_handle_(std::move(other).detach_coroutine_handle()) {}
    task& operator=(const task& other) = delete;
    task& operator=(task&& other) noexcept { return (this != std::addressof(other)) ? this->reset(std::move(other).detach_coroutine_handle()) : *this; }
    /// Destroys the owned coroutine frame, if any has not been detached.
    ~task() { reset(); }

    /// Releases ownership of the coroutine frame; the task is empty afterwards.
    [[nodiscard]] std::coroutine_handle<promise_type> detach_coroutine_handle() && { return std::exchange(coroutine_handle_, nullptr); }

private:
    std::coroutine_handle<promise_type> coroutine_handle_{};

    // Used by promise_type::get_return_object().
    task(std::coroutine_handle<promise_type> co_handle) : coroutine_handle_(co_handle) { }

    // Destroys the currently owned frame (if any), then takes ownership of `h`.
    task& reset(std::coroutine_handle<promise_type> h = nullptr) noexcept
    {
        if (coroutine_handle_) coroutine_handle_.destroy();
        coroutine_handle_ = std::exchange(h, nullptr);
        return *this;
    }
};
template <class result_type>
class fiber final
{
future<result_type> future_;
fiber_context context_;
template <class task_promise_type>
fiber(std::coroutine_handle<task_promise_type> handle)
: future_(handle.promise().get_future())
, context_(handle) { }
public:
fiber(task<result_type> task) : fiber(std::move(task).detach_coroutine_handle()) {}
fiber(const fiber& other) = delete;
fiber(fiber&& other) noexcept = delete;
fiber& operator=(const fiber& other) = delete;
fiber& operator=(fiber&& other) noexcept = delete;
~fiber() = default;
void tick()
{
if (!context_.done())
context_.resume();
}
bool is_done() const { return context_.done(); }
bool is_ready() const { return future_.ready(); }
auto get_result() { return future_.get(); }
auto get_result_if_ready() { return future_.get_if_ready(); }
future<result_type> get_future() { return std::move(future_); }
};
template <class result_type>
fiber(task<result_type>) -> fiber<result_type>;
/// Round-robin scheduler for independent tasks: every tick() advances each live fiber by one step.
class runner
{
    // One call stack per pushed task. std::list keeps each (non-movable) fiber_context in place.
    std::list<fiber_context> fibers_;

public:
    runner() = default;
    runner(const runner& other) = delete;
    runner(runner&& other) noexcept = delete;
    runner& operator=(const runner& other) = delete;
    runner& operator=(runner&& other) noexcept = delete;
    ~runner() = default;

    /// True when no unfinished fiber remains.
    [[nodiscard]] bool empty() const { return fibers_.empty(); }
    /// Number of unfinished fibers.
    [[nodiscard]] size_t size() const { return fibers_.size(); }

    /// Takes ownership of `task` and schedules it; it first runs on the next tick().
    /// @return future receiving the task's result (or exception).
    /// @throws std::logic_error if `task` is empty.
    template <class result_type>
    [[nodiscard]] future<result_type> push_back(task<result_type> task)
    {
        auto handle = std::move(task).detach_coroutine_handle();
        // Without this check, promise() below would be called on a null handle.
        if (!handle)
            throw std::logic_error("empty task can't be pushed to runner.");
        try
        {
            auto result = handle.promise().get_future();
            fibers_.emplace_back(handle);
            return result;
        }
        catch (...)
        {
            // The handle was detached from `task`, so nobody else will free the frame.
            handle.destroy();
            throw;
        }
    }

    /// Resumes every fiber once, in insertion order, and drops the fibers that finished.
    void tick()
    {
        for (auto it = fibers_.begin(); it != fibers_.end();)
        {
            if (it->resume(); it->done())
                it = fibers_.erase(it);
            else
                ++it;
        }
    }
};
}
namespace coro
{
// Public API: the task, future, fiber and runner types re-exported from coro::detail.
using detail::task, detail::future, detail::fiber, detail::runner;
}
#include <cassert>
#include <iostream>
#include <string>
#include "coro_task.h"
// Leaf demo task: suspends `wait` times, then returns `wait`.
// `label` is taken by value so that the coroutine frame owns its own copy.
static coro::task<int> CreateSubSubTask(std::string label, int wait)
{
    int ret = wait;
    std::cout << " " << label << ".SubSubTask: Created. wait = " << wait << std::endl;
    // Suspend once per requested step. Counting up (instead of `while (wait--)`) finishes at once
    // for a negative `wait`, instead of looping until signed overflow.
    for (int i = 0; i < wait; ++i)
        co_yield {};
    std::cout << " " << label << ".SubSubTask: Completed." << std::endl;
    co_return ret;
}
// Mid-level demo task. It awaits three sub-sub-tasks and sums their results, then counts `wait`
// steps down. Finally it awaits a worker thread and returns that thread's computed value.
static coro::task<int> CreateSubTask(std::string label, int wait)
{
    std::cout << " " << label << ".SubTask: Created. wait = " << wait << std::endl;
    int ret = 0;
    ret += co_await CreateSubSubTask(label, 0);
    ret += co_await CreateSubSubTask(label, 1);
    ret += co_await CreateSubSubTask(label, 2);
    // Print the remaining step count, then suspend. The `> 0` test finishes at once for a negative
    // `wait`, instead of looping until signed overflow (as `while (wait--)` did).
    while (wait > 0)
    {
        --wait;
        std::cout << " " << label << ".SubTask: " << wait << std::endl;
        co_yield {};
    }
    std::cout << " " << label << ".SubTask: Completed." << std::endl;
    // waits for another thread: co_await std::future<int>. The task polls the future on each
    // resume (co_yield until ready), so the fiber keeps being ticked while the worker sleeps.
    // NOTE(review): std::this_thread needs <thread>, which this file does not include directly.
    co_return co_await std::async(std::launch::async, [ret]() -> int
    {
        std::this_thread::sleep_for(std::chrono::seconds{2});
        return ret * 2 + 36;
    });
}
// Top-level demo task: runs a sub-task to completion, then performs five body steps (suspending
// after each), and finally reports the sub-task's result in a summary string.
static coro::task<std::string> CreateTask(std::string label)
{
    std::cout << label << ": " << "Created." << std::endl;
    const auto sub_task_result = co_await CreateSubTask(label, 3);

    int step = 0;
    while (step < 5)
    {
        std::cout << " " << label << ".Body: " << step << std::endl;
        co_yield {};
        ++step;
    }

    std::string summary = label;
    summary += " complete. result = ";
    summary += std::to_string(sub_task_result);
    co_return summary;
}
int main()
{
std::cout << "Hello World!\n";
{
coro::fiber<std::string> fiber(CreateTask("FIBER"));
coro::runner runner;
coro::future<std::string> task0_result, task1_result, task2_result;
coro::future<void> task3_result;
for (int frame = 0; frame < 20; ++frame)
{
std::cout << "====" << "Frame " << frame << "====" << std::endl;
if (fiber.tick(); fiber.is_ready()) std::cout << fiber.get_result() << std::endl;
if (frame == 0) task0_result = runner.push_back(CreateTask("TASK 0"));
if (frame == 1) task1_result = runner.push_back(CreateTask("TASK 1"));
if (frame == 2) task2_result = runner.push_back(CreateTask("TASK 2"));
if (frame == 3)
task3_result = runner.push_back([]() -> coro::task<void> // lambda
{
const std::string result = co_await CreateTask("TASK 3");
std::cout << "TASK 3 Complete: " << result << std::endl;
}());
// tick all tasks.
runner.tick();
if (auto ret = task0_result.get_if_ready()) std::cout << "TASK 0 Complete: " << *ret << std::endl;
if (auto ret = task1_result.get_if_ready()) std::cout << "TASK 1 Complete: " << *ret << std::endl;
if (auto ret = task2_result.get_if_ready()) std::cout << "TASK 2 Complete: " << *ret << std::endl;
if (auto ret = task3_result.get_if_ready()) std::cout << "TASK 3 Complete: " << "(void)" << std::endl;
std::this_thread::sleep_for(std::chrono::seconds{1});
}
}
std::cout << "Good-bye World!\n";
}
Hello World!
====Frame 0====
FIBER: Created.
FIBER.SubTask: Created. wait = 3
FIBER.SubSubTask: Created. wait = 0
FIBER.SubSubTask: Completed.
FIBER.SubSubTask: Created. wait = 1
TASK 0: Created.
TASK 0.SubTask: Created. wait = 3
TASK 0.SubSubTask: Created. wait = 0
TASK 0.SubSubTask: Completed.
TASK 0.SubSubTask: Created. wait = 1
====Frame 1====
FIBER.SubSubTask: Completed.
FIBER.SubSubTask: Created. wait = 2
TASK 0.SubSubTask: Completed.
TASK 0.SubSubTask: Created. wait = 2
TASK 1: Created.
TASK 1.SubTask: Created. wait = 3
TASK 1.SubSubTask: Created. wait = 0
TASK 1.SubSubTask: Completed.
TASK 1.SubSubTask: Created. wait = 1
====Frame 2====
TASK 1.SubSubTask: Completed.
TASK 1.SubSubTask: Created. wait = 2
TASK 2: Created.
TASK 2.SubTask: Created. wait = 3
TASK 2.SubSubTask: Created. wait = 0
TASK 2.SubSubTask: Completed.
TASK 2.SubSubTask: Created. wait = 1
====Frame 3====
FIBER.SubSubTask: Completed.
FIBER.SubTask: 2
TASK 0.SubSubTask: Completed.
TASK 0.SubTask: 2
TASK 2.SubSubTask: Completed.
TASK 2.SubSubTask: Created. wait = 2
TASK 3: Created.
TASK 3.SubTask: Created. wait = 3
TASK 3.SubSubTask: Created. wait = 0
TASK 3.SubSubTask: Completed.
TASK 3.SubSubTask: Created. wait = 1
====Frame 4====
FIBER.SubTask: 1
TASK 0.SubTask: 1
TASK 1.SubSubTask: Completed.
TASK 1.SubTask: 2
TASK 3.SubSubTask: Completed.
TASK 3.SubSubTask: Created. wait = 2
====Frame 5====
FIBER.SubTask: 0
TASK 0.SubTask: 0
TASK 1.SubTask: 1
TASK 2.SubSubTask: Completed.
TASK 2.SubTask: 2
====Frame 6====
FIBER.SubTask: Completed.
TASK 0.SubTask: Completed.
TASK 1.SubTask: 0
TASK 2.SubTask: 1
TASK 3.SubSubTask: Completed.
TASK 3.SubTask: 2
====Frame 7====
TASK 1.SubTask: Completed.
TASK 2.SubTask: 0
TASK 3.SubTask: 1
====Frame 8====
FIBER.Body: 0
TASK 0.Body: 0
TASK 2.SubTask: Completed.
TASK 3.SubTask: 0
====Frame 9====
FIBER.Body: 1
TASK 0.Body: 1
TASK 1.Body: 0
TASK 3.SubTask: Completed.
====Frame 10====
FIBER.Body: 2
TASK 0.Body: 2
TASK 1.Body: 1
TASK 2.Body: 0
====Frame 11====
FIBER.Body: 3
TASK 0.Body: 3
TASK 1.Body: 2
TASK 2.Body: 1
TASK 3.Body: 0
====Frame 12====
FIBER.Body: 4
TASK 0.Body: 4
TASK 1.Body: 3
TASK 2.Body: 2
TASK 3.Body: 1
====Frame 13====
FIBER complete. result = 42
TASK 1.Body: 4
TASK 2.Body: 3
TASK 3.Body: 2
TASK 0 Complete: TASK 0 complete. result = 42
====Frame 14====
TASK 2.Body: 4
TASK 3.Body: 3
TASK 1 Complete: TASK 1 complete. result = 42
====Frame 15====
TASK 3.Body: 4
TASK 2 Complete: TASK 2 complete. result = 42
====Frame 16====
TASK 3 Complete: TASK 3 complete. result = 42
TASK 3 Complete: (void)
====Frame 17====
====Frame 18====
====Frame 19====
Good-bye World!
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment