Last active
August 20, 2023 21:47
-
-
Save ttsuki/fd7106cc583bda02fa135a72bad7afcc to your computer and use it in GitHub Desktop.
A study of a C++20 coroutine task system.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// coro_task: A study of a C++20 coroutine task system
#pragma once | |
#include <array>
#include <cassert>
#include <chrono>
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <exception>
#include <future>
#include <list>
#include <memory>
#include <optional>
#include <stack>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>
namespace coro::detail | |
{ | |
template <class T> | |
class future | |
{ | |
std::future<T> future_; | |
public: | |
using result_type = T; | |
using value_type = std::conditional_t<!std::is_void_v<result_type>, result_type, std::monostate>; | |
using optional_value_type = std::optional<value_type>; | |
future() = default; | |
future(std::future<result_type> future) : future_(std::move(future)) {} | |
future(const future& other) = delete; | |
future(future&& other) noexcept = default; | |
future& operator=(const future& other) = delete; | |
future& operator=(future&& other) noexcept = default; | |
~future() = default; | |
bool valid() const noexcept { return future_.valid(); } | |
bool ready() const noexcept { return future_.valid() && future_.wait_for(std::chrono::seconds::zero()) == std::future_status::ready; } | |
result_type get() { return future_.get(); } | |
std::optional<T> get_if_ready() requires(!std::is_void_v<T>) { return ready() ? std::optional<T>(get()) : std::optional<T>(std::nullopt); } | |
std::optional<std::monostate> get_if_ready() requires(std::is_void_v<T>) { return ready() ? (get(), std::optional<std::monostate>(std::in_place)) : std::optional<std::monostate>(std::nullopt); } | |
operator std::future<T>() && { return std::move(future_); } | |
}; | |
template <class result_type> | |
class promise_base | |
{ | |
std::promise<result_type> result_{}; | |
public: | |
void return_value(const result_type& o) { result_.set_value(std::forward<decltype(o)>(o)); } | |
void return_value(result_type&& o) { result_.set_value(std::forward<decltype(o)>(o)); } | |
void return_value(auto&& o) { result_.set_value(std::forward<decltype(o)>(o)); } | |
void unhandled_exception() noexcept { result_.set_exception(std::current_exception()); } | |
future<result_type> get_future() { return result_.get_future(); } | |
}; | |
template <> | |
class promise_base<void> | |
{ | |
std::promise<void> result_{}; | |
public: | |
void return_void() { result_.set_value(); } | |
void unhandled_exception() noexcept { result_.set_exception(std::current_exception()); } | |
future<void> get_future() { return result_.get_future(); } | |
}; | |
/// An explicit call stack of coroutine handles that is driven by resume().
///
/// Only the top handle runs. Awaiting a child task pushes the child and
/// transfers control to it; when the child reaches its final suspend point it
/// is popped, and its caller is resumed. The context owns every handle on its
/// stack and destroys the remaining ones when it is destroyed (see abort()).
class fiber_context final
{
    // Bottom: the fiber's root coroutine. Top: the coroutine that resume() runs.
    std::stack<std::coroutine_handle<>> call_stack_{};
public:
    class promise_base;
    fiber_context() = default;
    // Not copyable or movable: push() stores a raw pointer to this context in
    // each promise on the stack.
    fiber_context(const fiber_context& other) = delete;
    fiber_context(fiber_context&& other) noexcept = delete;
    fiber_context& operator=(const fiber_context& other) = delete;
    fiber_context& operator=(fiber_context&& other) noexcept = delete;
    // Destroys every coroutine frame still on the stack.
    ~fiber_context() { this->abort(); }

    /// Creates a fiber rooted at `main` and registers it on the stack.
    /// This constructor does not resume `main`; resume() does.
    template <std::derived_from<promise_base> promise_type>
    fiber_context(std::coroutine_handle<promise_type> main)
    {
        this->push(main);
        this->suspend(main);   // only asserts that `main` is now on top
    }

    /// True once the stack is empty: every coroutine finished or was aborted.
    bool done() const noexcept
    {
        return call_stack_.empty();
    }

    /// Runs the top coroutine until it next suspends (co_yield) or the
    /// resumption chain it starts returns.
    /// @throws std::logic_error if the fiber is already done.
    void resume()
    {
        assert(!done());
        if (done()) throw std::logic_error("invalid stack state");
        call_stack_.top().resume();
    }

    /// Destroys all remaining coroutine frames, innermost (top) first.
    void abort() noexcept
    {
        while (!call_stack_.empty())
        {
            call_stack_.top().destroy();
            call_stack_.pop();
        }
    }
private:
    // co_yield bookkeeping: the stack is left unchanged, so control simply goes
    // back to whoever called resume(). Only checks that `current` is on top.
    auto suspend([[maybe_unused]] std::coroutine_handle<> current) noexcept
    {
        assert(call_stack_.top() == current);
    }

    // Makes `callee` the new top and records this context in its promise.
    // Returns `callee` so that an await_suspend can transfer control straight into it.
    template <std::derived_from<promise_base> promise_type>
    auto push(std::coroutine_handle<promise_type> callee) noexcept
    {
        static_cast<promise_base&>(callee.promise()).context_ = this;
        this->call_stack_.push(callee);
        return callee;
    }

    // Called from a finishing coroutine's final_suspend. Removes it from the
    // stack and returns a handle to transfer to. That handle destroys the
    // finished frame and then resumes the caller, or a no-op coroutine if the
    // stack is now empty.
    auto pop() noexcept
    {
        auto callee = this->call_stack_.top();
        this->call_stack_.pop();
        auto caller = !this->call_stack_.empty()
            ? this->call_stack_.top()
            : std::noop_coroutine();
        // Conceptually:
        //     callee.destroy();
        //     return caller;
        // The destroy is instead deferred to the trampoline coroutine below, so
        // that it runs only after the finishing coroutine's await_suspend has
        // returned. NOTE(review): presumed rationale for not using the direct form.
        struct fire_and_forget final
        {
            std::coroutine_handle<> handle{};
            struct promise_type
            {
                fire_and_forget get_return_object() { return fire_and_forget{std::coroutine_handle<promise_type>::from_promise(*this)}; }
                // Created suspended: it starts when await_suspend returns its handle.
                auto initial_suspend() const noexcept { return std::suspend_always{}; }
                // Not suspended at the end: the trampoline frame frees itself after co_return.
                auto final_suspend() const noexcept { return std::suspend_never{}; }
                void return_void() const noexcept {}
                void unhandled_exception() noexcept { std::terminate(); }
            };
        };
        // Trampoline: destroy the finished callee, then resume the caller.
        // Note that `er.resume()` is an ordinary nested call, not a symmetric transfer.
        return [](std::coroutine_handle<> ee, std::coroutine_handle<> er) noexcept -> fire_and_forget
        {
            assert((ee.done()));
            ee.destroy();
            er.resume();
            co_return;
        }(callee, caller).handle;
    }
public:
    /// Base for every promise type that runs on a fiber_context. It holds the
    /// owning context pointer (set by push()) and provides the awaiters that
    /// drive the call stack.
    class promise_base
    {
        friend class fiber_context;
        fiber_context* context_{};
    public:
        // Coroutines start suspended; a fiber_context decides when they first run.
        template <std::derived_from<promise_base> promise_type>
        using initial_suspend_awaiter = std::suspend_always;

        // co_yield awaiter: always suspends, handing control back to the caller of resume().
        template <std::derived_from<promise_base> promise_type>
        struct suspend_awaiter
        {
            bool await_ready() const noexcept { return false; }
            auto await_suspend(std::coroutine_handle<promise_type> current) const noexcept { return static_cast<promise_base&>(current.promise()).context_->suspend(current); }
            void await_resume() const noexcept {}
        };

        // Default result source for coroutine_call_awaiter: get() yields nothing.
        struct empty_future
        {
            constexpr void get() const noexcept {}
        };

        // co_await-a-child awaiter: pushes `callee` onto the caller's context and
        // transfers control to it. When the caller is resumed again (after the
        // callee finished), the expression evaluates to future.get(), which
        // rethrows any exception the callee stored.
        template <std::derived_from<promise_base> promise_type, std::derived_from<promise_base> callee_promise_type, class future_type = empty_future>
            requires requires(future_type future) { future.get(); }
        struct coroutine_call_awaiter
        {
            std::coroutine_handle<callee_promise_type> callee;
            future_type future;
            coroutine_call_awaiter(std::coroutine_handle<callee_promise_type> callee, future_type future = {}) : callee(callee), future(std::move(future)) {}
            bool await_ready() const noexcept { return false; }
            auto await_suspend(std::coroutine_handle<promise_type> caller) const noexcept { return static_cast<promise_base&>(caller.promise()).context_->push(callee); }
            auto await_resume() { return future.get(); }
        };

        // final_suspend awaiter: pops the finishing coroutine and transfers control
        // to the trampoline returned by pop().
        template <std::derived_from<promise_base> promise_type>
        struct coroutine_return_awaiter
        {
            bool await_ready() const noexcept { return false; }
            auto await_suspend(std::coroutine_handle<promise_type> callee) const noexcept { return static_cast<promise_base&>(callee.promise()).context_->pop(); }
            void await_resume() const noexcept {}
        };
    };
};
/// Tag value yielded to suspend the current task until the next tick:
/// `co_yield suspend_this{};`, or simply `co_yield {};`.
struct suspend_this
{
};
template <class result_type> | |
class task final | |
{ | |
public: | |
friend struct task_trait; | |
struct promise_type | |
: fiber_context::promise_base | |
, promise_base<result_type> | |
{ | |
task get_return_object() noexcept { return std::coroutine_handle<promise_type>::from_promise(*this); } | |
auto initial_suspend() const noexcept { return initial_suspend_awaiter<promise_type>(); } | |
auto final_suspend() const noexcept { return coroutine_return_awaiter<promise_type>(); } | |
auto yield_value(suspend_this) const noexcept { return suspend_awaiter<promise_type>(); } | |
// for `co_await coro_task<T>(...);` | |
template <class callee_result_type> | |
auto await_transform(task<callee_result_type> callee_task) | |
{ | |
using callee_promise_type = typename task<callee_result_type>::promise_type; | |
std::coroutine_handle<callee_promise_type> callee_handle = std::move(callee_task).detach_coroutine_handle(); | |
if (!callee_handle) | |
throw std::logic_error(" empty task can't be co_await."); | |
return coroutine_call_awaiter<promise_type, callee_promise_type, future<callee_result_type>>{callee_handle, callee_handle.promise().get_future()}; | |
} | |
template <class callee_result_type> | |
auto await_transform(future<callee_result_type> f) | |
{ | |
if (!f.valid()) | |
throw std::logic_error("invalid future can't be co_await."); | |
return this->await_transform([](future<callee_result_type> f) -> task<callee_result_type> | |
{ | |
while (!f.ready()) co_yield{}; | |
co_return f.get(); | |
}(std::move(f))); | |
} | |
template <class callee_result_type> | |
auto await_transform(std::future<callee_result_type> f) | |
{ | |
return this->await_transform(future(std::move(f))); | |
} | |
}; | |
task() = default; | |
task(const task& other) = delete; | |
task(task&& other) noexcept : coroutine_handle_(std::move(other).detach_coroutine_handle()) {} | |
task& operator=(const task& other) = delete; | |
task& operator=(task&& other) noexcept { return (this != std::addressof(other)) ? this->reset(std::move(other).detach_coroutine_handle()) : *this; } | |
~task() { reset(); } | |
[[nodiscard]] std::coroutine_handle<promise_type> detach_coroutine_handle() && { return std::exchange(coroutine_handle_, nullptr); } | |
private: | |
std::coroutine_handle<promise_type> coroutine_handle_{}; | |
task(std::coroutine_handle<promise_type> co_handle) : coroutine_handle_(co_handle) { } | |
task& reset(std::coroutine_handle<promise_type> h = nullptr) noexcept | |
{ | |
if (coroutine_handle_) coroutine_handle_.destroy(); | |
coroutine_handle_ = std::exchange(h, nullptr); | |
return *this; | |
} | |
}; | |
template <class result_type> | |
class fiber final | |
{ | |
future<result_type> future_; | |
fiber_context context_; | |
template <class task_promise_type> | |
fiber(std::coroutine_handle<task_promise_type> handle) | |
: future_(handle.promise().get_future()) | |
, context_(handle) { } | |
public: | |
fiber(task<result_type> task) : fiber(std::move(task).detach_coroutine_handle()) {} | |
fiber(const fiber& other) = delete; | |
fiber(fiber&& other) noexcept = delete; | |
fiber& operator=(const fiber& other) = delete; | |
fiber& operator=(fiber&& other) noexcept = delete; | |
~fiber() = default; | |
void tick() | |
{ | |
if (!context_.done()) | |
context_.resume(); | |
} | |
bool is_done() const { return context_.done(); } | |
bool is_ready() const { return future_.ready(); } | |
auto get_result() { return future_.get(); } | |
auto get_result_if_ready() { return future_.get_if_ready(); } | |
future<result_type> get_future() { return std::move(future_); } | |
}; | |
template <class result_type> | |
fiber(task<result_type>) -> fiber<result_type>; | |
class runner | |
{ | |
std::list<fiber_context> fibers_; | |
public: | |
runner() = default; | |
runner(const runner& other) = delete; | |
runner(runner&& other) noexcept = delete; | |
runner& operator=(const runner& other) = delete; | |
runner& operator=(runner&& other) noexcept = delete; | |
~runner() = default; | |
[[nodiscard]] bool empty() const { return fibers_.empty(); } | |
[[nodiscard]] size_t size() const { return fibers_.size(); } | |
template <class result_type> | |
[[nodiscard]] future<result_type> push_back(task<result_type> task) | |
{ | |
auto handle = std::move(task).detach_coroutine_handle(); | |
auto future = handle.promise().get_future(); | |
fibers_.emplace_back(handle); | |
return future; | |
} | |
void tick() | |
{ | |
for (auto it = fibers_.begin(); it != fibers_.end();) | |
{ | |
if (it->resume(); it->done()) | |
it = fibers_.erase(it); | |
else | |
++it; | |
} | |
} | |
}; | |
} | |
namespace coro | |
{ | |
using detail::task; | |
using detail::future; | |
using detail::fiber; | |
using detail::runner; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert> | |
#include <iostream> | |
#include <string> | |
#include "coro_task.h" | |
/// Leaf demo task: yields once per tick for `wait` ticks, then returns the
/// original `wait`. A zero or negative `wait` completes without yielding.
static coro::task<int> CreateSubSubTask(std::string label, int wait)
{
    const int ret = wait;
    std::cout << " " << label << ".SubSubTask: Created. wait = " << wait << std::endl;
    // `> 0` stops on negative input. `while (wait--)` would instead keep
    // counting down past INT_MIN (signed overflow, undefined behavior).
    while (wait-- > 0)
        co_yield {};
    std::cout << " " << label << ".SubSubTask: Completed." << std::endl;
    co_return ret;
}
/// Mid-level demo task.
/// 1. Awaits three sub-sub-tasks (0, 1 and 2 ticks; their results sum to 3).
/// 2. Yields once per tick for `wait` ticks.
/// 3. Awaits a value computed on another thread (ret * 2 + 36 after a 2 s sleep).
/// A zero or negative `wait` skips the countdown.
static coro::task<int> CreateSubTask(std::string label, int wait)
{
    std::cout << " " << label << ".SubTask: Created. wait = " << wait << std::endl;
    int ret = 0;
    ret += co_await CreateSubSubTask(label, 0);
    ret += co_await CreateSubSubTask(label, 1);
    ret += co_await CreateSubSubTask(label, 2);
    // `> 0` stops on negative input instead of overflowing `wait`.
    while (wait-- > 0)
    {
        std::cout << " " << label << ".SubTask: " << wait << std::endl;
        co_yield {};
    }
    std::cout << " " << label << ".SubTask: Completed." << std::endl;
    // waits for another thread: co_await std::future<int>
    co_return co_await std::async(std::launch::async, [ret]() -> int
    {
        std::this_thread::sleep_for(std::chrono::seconds{2});
        return ret * 2 + 36;
    });
}
/// Top-level demo task: awaits a sub-task, yields once per tick for five
/// ticks, then returns a summary string.
static coro::task<std::string> CreateTask(std::string label)
{
    std::cout << label << ": " << "Created." << std::endl;
    const int sub_task_result = co_await CreateSubTask(label, 3);

    int step = 0;
    while (step < 5)
    {
        std::cout << " " << label << ".Body: " << step << std::endl;
        co_yield {};
        ++step;
    }

    co_return label + " complete. result = " + std::to_string(sub_task_result);
}
int main() | |
{ | |
std::cout << "Hello World!\n"; | |
{ | |
coro::fiber<std::string> fiber(CreateTask("FIBER")); | |
coro::runner runner; | |
coro::future<std::string> task0_result, task1_result, task2_result; | |
coro::future<void> task3_result; | |
for (int frame = 0; frame < 20; ++frame) | |
{ | |
std::cout << "====" << "Frame " << frame << "====" << std::endl; | |
if (fiber.tick(); fiber.is_ready()) std::cout << fiber.get_result() << std::endl; | |
if (frame == 0) task0_result = runner.push_back(CreateTask("TASK 0")); | |
if (frame == 1) task1_result = runner.push_back(CreateTask("TASK 1")); | |
if (frame == 2) task2_result = runner.push_back(CreateTask("TASK 2")); | |
if (frame == 3) | |
task3_result = runner.push_back([]() -> coro::task<void> // lambda | |
{ | |
const std::string result = co_await CreateTask("TASK 3"); | |
std::cout << "TASK 3 Complete: " << result << std::endl; | |
}()); | |
// tick all tasks. | |
runner.tick(); | |
if (auto ret = task0_result.get_if_ready()) std::cout << "TASK 0 Complete: " << *ret << std::endl; | |
if (auto ret = task1_result.get_if_ready()) std::cout << "TASK 1 Complete: " << *ret << std::endl; | |
if (auto ret = task2_result.get_if_ready()) std::cout << "TASK 2 Complete: " << *ret << std::endl; | |
if (auto ret = task3_result.get_if_ready()) std::cout << "TASK 3 Complete: " << "(void)" << std::endl; | |
std::this_thread::sleep_for(std::chrono::seconds{1}); | |
} | |
} | |
std::cout << "Good-bye World!\n"; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Hello World! | |
====Frame 0==== | |
FIBER: Created. | |
FIBER.SubTask: Created. wait = 3 | |
FIBER.SubSubTask: Created. wait = 0 | |
FIBER.SubSubTask: Completed. | |
FIBER.SubSubTask: Created. wait = 1 | |
TASK 0: Created. | |
TASK 0.SubTask: Created. wait = 3 | |
TASK 0.SubSubTask: Created. wait = 0 | |
TASK 0.SubSubTask: Completed. | |
TASK 0.SubSubTask: Created. wait = 1 | |
====Frame 1==== | |
FIBER.SubSubTask: Completed. | |
FIBER.SubSubTask: Created. wait = 2 | |
TASK 0.SubSubTask: Completed. | |
TASK 0.SubSubTask: Created. wait = 2 | |
TASK 1: Created. | |
TASK 1.SubTask: Created. wait = 3 | |
TASK 1.SubSubTask: Created. wait = 0 | |
TASK 1.SubSubTask: Completed. | |
TASK 1.SubSubTask: Created. wait = 1 | |
====Frame 2==== | |
TASK 1.SubSubTask: Completed. | |
TASK 1.SubSubTask: Created. wait = 2 | |
TASK 2: Created. | |
TASK 2.SubTask: Created. wait = 3 | |
TASK 2.SubSubTask: Created. wait = 0 | |
TASK 2.SubSubTask: Completed. | |
TASK 2.SubSubTask: Created. wait = 1 | |
====Frame 3==== | |
FIBER.SubSubTask: Completed. | |
FIBER.SubTask: 2 | |
TASK 0.SubSubTask: Completed. | |
TASK 0.SubTask: 2 | |
TASK 2.SubSubTask: Completed. | |
TASK 2.SubSubTask: Created. wait = 2 | |
TASK 3: Created. | |
TASK 3.SubTask: Created. wait = 3 | |
TASK 3.SubSubTask: Created. wait = 0 | |
TASK 3.SubSubTask: Completed. | |
TASK 3.SubSubTask: Created. wait = 1 | |
====Frame 4==== | |
FIBER.SubTask: 1 | |
TASK 0.SubTask: 1 | |
TASK 1.SubSubTask: Completed. | |
TASK 1.SubTask: 2 | |
TASK 3.SubSubTask: Completed. | |
TASK 3.SubSubTask: Created. wait = 2 | |
====Frame 5==== | |
FIBER.SubTask: 0 | |
TASK 0.SubTask: 0 | |
TASK 1.SubTask: 1 | |
TASK 2.SubSubTask: Completed. | |
TASK 2.SubTask: 2 | |
====Frame 6==== | |
FIBER.SubTask: Completed. | |
TASK 0.SubTask: Completed. | |
TASK 1.SubTask: 0 | |
TASK 2.SubTask: 1 | |
TASK 3.SubSubTask: Completed. | |
TASK 3.SubTask: 2 | |
====Frame 7==== | |
TASK 1.SubTask: Completed. | |
TASK 2.SubTask: 0 | |
TASK 3.SubTask: 1 | |
====Frame 8==== | |
FIBER.Body: 0 | |
TASK 0.Body: 0 | |
TASK 2.SubTask: Completed. | |
TASK 3.SubTask: 0 | |
====Frame 9==== | |
FIBER.Body: 1 | |
TASK 0.Body: 1 | |
TASK 1.Body: 0 | |
TASK 3.SubTask: Completed. | |
====Frame 10==== | |
FIBER.Body: 2 | |
TASK 0.Body: 2 | |
TASK 1.Body: 1 | |
TASK 2.Body: 0 | |
====Frame 11==== | |
FIBER.Body: 3 | |
TASK 0.Body: 3 | |
TASK 1.Body: 2 | |
TASK 2.Body: 1 | |
TASK 3.Body: 0 | |
====Frame 12==== | |
FIBER.Body: 4 | |
TASK 0.Body: 4 | |
TASK 1.Body: 3 | |
TASK 2.Body: 2 | |
TASK 3.Body: 1 | |
====Frame 13==== | |
FIBER complete. result = 42 | |
TASK 1.Body: 4 | |
TASK 2.Body: 3 | |
TASK 3.Body: 2 | |
TASK 0 Complete: TASK 0 complete. result = 42 | |
====Frame 14==== | |
TASK 2.Body: 4 | |
TASK 3.Body: 3 | |
TASK 1 Complete: TASK 1 complete. result = 42 | |
====Frame 15==== | |
TASK 3.Body: 4 | |
TASK 2 Complete: TASK 2 complete. result = 42 | |
====Frame 16==== | |
TASK 3 Complete: TASK 3 complete. result = 42 | |
TASK 3 Complete: (void) | |
====Frame 17==== | |
====Frame 18==== | |
====Frame 19==== | |
Good-bye World! |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment