-
-
Save MichaEiler/99c3ed529d4fd19c4289fd04672a1a7c to your computer and use it in GitHub Desktop.
Coroutine-based thread pool
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>
/// One-shot, thread-safe signal: once set() has been called, every current
/// and future call to wait() returns.
///
/// Why a mutex + condition variable instead of std::atomic_flag::wait /
/// notify_all: with the atomic version the setter did `test_and_set()` and
/// *then* `notify_all()`. A waiter could observe the flag in between, return,
/// and destroy the event (e.g. an event living on the waiter's stack) while
/// the setter was still about to call notify_all() on it. Here the flag is
/// written and the notification is issued while holding the mutex, so wait()
/// cannot return before set() has finished touching the condition variable.
/// The standard explicitly allows destroying a mutex as soon as another
/// thread has unlocked it.
struct fire_once_event
{
    /// Marks the event as signalled and wakes every waiting thread.
    /// Calling set() more than once is harmless.
    void set()
    {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_is_set = true;
        // Notify under the lock: see the class comment.
        m_cond.notify_all();
    }

    /// Blocks the calling thread until set() has been called.
    /// Returns immediately if the event is already set.
    void wait()
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cond.wait(lock, [this] { return m_is_set; });
    }

private:
    std::mutex m_mutex;
    std::condition_variable m_cond;
    bool m_is_set = false;
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "task.hpp" | |
#include "threadpool.hpp" | |
#include "sync_wait.hpp" | |
#include <iostream> | |
#include <thread> | |
/// Moves itself onto a worker thread of `pool`, then prints that thread's id.
task run_async_print(threadpool& pool)
{
    // Everything after this co_await runs on a thread owned by `pool`.
    co_await pool.schedule();

    const auto worker_id = std::this_thread::get_id();
    std::cout << "This is a hello from thread: " << worker_id << "\n";
}
int main() | |
{ | |
std::cout << "The main thread id is: " << std::this_thread::get_id() << "\n"; | |
threadpool pool{8}; | |
task t = run_async_print(pool); | |
sync_wait(t); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
CXX=g++
# GCC 10 requires -fcoroutines; later GCC releases enable coroutines with -std=c++20 alone.
CXXFLAGS=-std=c++20 -fcoroutines -pthread -g -Wall -Wextra

.PHONY: all clean

all: coro-sample

# Rebuild whenever main.cpp or any header it (transitively) includes changes.
coro-sample: main.cpp task.hpp threadpool.hpp sync_wait.hpp fire_once_event.hpp
	$(CXX) $(CXXFLAGS) main.cpp -o $@

clean:
	rm -f coro-sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <coroutine> | |
#include "fire_once_event.hpp" | |
#include "task.hpp" | |
struct sync_wait_task_promise; | |
struct [[nodiscard]] sync_wait_task | |
{ | |
using promise_type = sync_wait_task_promise; | |
sync_wait_task(std::coroutine_handle<sync_wait_task_promise> coro) | |
: m_handle(coro) | |
{ | |
} | |
~sync_wait_task() | |
{ | |
if (m_handle) | |
{ | |
m_handle.destroy(); | |
} | |
} | |
void run(fire_once_event& event); | |
private: | |
std::coroutine_handle<sync_wait_task_promise> m_handle; | |
}; | |
struct sync_wait_task_promise | |
{ | |
std::suspend_always initial_suspend() const noexcept { return {}; } | |
auto final_suspend() const noexcept | |
{ | |
struct awaiter | |
{ | |
bool await_ready() const noexcept { return false; } | |
void await_suspend(std::coroutine_handle<sync_wait_task_promise> coro) const noexcept | |
{ | |
fire_once_event *const event = coro.promise().m_event; | |
if (event) | |
{ | |
event->set(); | |
} | |
} | |
void await_resume() noexcept {} | |
}; | |
return awaiter(); | |
} | |
fire_once_event *m_event = nullptr; | |
sync_wait_task get_return_object() noexcept | |
{ | |
return sync_wait_task{ std::coroutine_handle<sync_wait_task_promise>::from_promise(*this) }; | |
} | |
void unhandled_exception() noexcept { exit(1); } | |
}; | |
/// Attaches `event` to the coroutine's promise, then resumes the coroutine
/// from its initial suspension point. The promise's final awaiter signals
/// `event` once the coroutine body completes.
inline void sync_wait_task::run(fire_once_event& event)
{
    sync_wait_task_promise& promise = m_handle.promise();
    promise.m_event = &event;
    m_handle.resume();
}
/// Helper coroutine for sync_wait(): awaits `awaited`, then runs to its end.
/// Its promise's final awaiter is what signals sync_wait's event.
inline sync_wait_task make_sync_wait_task(task& awaited)
{
    co_await awaited;
}
/// Blocks the calling thread until `t` has run to completion.
///
/// A helper coroutine awaits `t` and signals a local event when done.
/// The event is declared before the helper, so the helper (and its coroutine
/// frame) is destroyed first when this function returns.
inline void sync_wait(task& t)
{
    fire_once_event finished;
    sync_wait_task driver = make_sync_wait_task(t);
    driver.run(finished);
    finished.wait();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <coroutine> | |
#include <exception> | |
#include <iostream> | |
#include <utility> | |
class [[nodiscard]] task; | |
struct task_promise | |
{ | |
struct final_awaitable | |
{ | |
bool await_ready() const noexcept { return false; } | |
std::coroutine_handle<> await_suspend(std::coroutine_handle<task_promise> coro) noexcept | |
{ | |
return coro.promise().m_continuation; | |
} | |
void await_resume() noexcept {} | |
}; | |
task get_return_object() noexcept; | |
std::suspend_always initial_suspend() const noexcept { return {}; } | |
auto final_suspend() const noexcept { return final_awaitable(); } | |
void return_void() noexcept {} | |
void unhandled_exception() noexcept { exit(1); } | |
void set_continuation(std::coroutine_handle<> continuation) noexcept | |
{ | |
m_continuation = continuation; | |
} | |
private: | |
std::coroutine_handle<> m_continuation = std::noop_coroutine(); | |
}; | |
class [[nodiscard]] task | |
{ | |
public: | |
using promise_type = task_promise; | |
explicit task(std::coroutine_handle<task_promise> handle) | |
: m_handle(handle) | |
{ | |
} | |
~task() | |
{ | |
if (m_handle) | |
{ | |
m_handle.destroy(); | |
} | |
} | |
auto operator co_await() noexcept | |
{ | |
struct awaiter | |
{ | |
bool await_ready() const noexcept { return !m_coro || m_coro.done(); } | |
std::coroutine_handle<> await_suspend( std::coroutine_handle<> awaiting_coroutine) noexcept { | |
m_coro.promise().set_continuation(awaiting_coroutine); | |
return m_coro; | |
} | |
void await_resume() noexcept {} | |
std::coroutine_handle<task_promise> m_coro; | |
}; | |
return awaiter{m_handle}; | |
} | |
private: | |
std::coroutine_handle<task_promise> m_handle; | |
}; | |
/// Wraps this promise's coroutine frame in the `task` handed back to the caller.
inline task task_promise::get_return_object() noexcept
{
    const auto self = std::coroutine_handle<task_promise>::from_promise(*this);
    return task{self};
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <condition_variable> | |
#include <coroutine> | |
#include <cstdint> | |
#include <list> | |
#include <mutex> | |
#include <queue> | |
#include <thread> | |
class threadpool | |
{ | |
public: | |
explicit threadpool(const std::size_t threadCount) | |
{ | |
for (std::size_t i = 0; i < threadCount; ++i) | |
{ | |
std::thread worker_thread([this]() { | |
this->thread_loop(); | |
}); | |
m_threads.push_back(std::move(worker_thread)); | |
} | |
} | |
~threadpool() | |
{ | |
shutdown(); | |
} | |
auto schedule() | |
{ | |
struct awaiter | |
{ | |
threadpool* m_threadpool; | |
constexpr bool await_ready() const noexcept { return false; } | |
constexpr void await_resume() const noexcept { } | |
void await_suspend(std::coroutine_handle<> coro) const noexcept { | |
m_threadpool->enqueue_task(coro); | |
} | |
}; | |
return awaiter{this}; | |
} | |
private: | |
std::list<std::thread> m_threads; | |
std::mutex m_mutex; | |
std::condition_variable m_cond; | |
std::queue<std::coroutine_handle<>> m_coros; | |
bool m_stop_thread = false; | |
void thread_loop() | |
{ | |
while (!m_stop_thread) | |
{ | |
std::unique_lock<std::mutex> lock(m_mutex); | |
while (!m_stop_thread && m_coros.size() == 0) | |
{ | |
m_cond.wait_for(lock, std::chrono::microseconds(100)); | |
} | |
if (m_stop_thread) | |
{ | |
break; | |
} | |
auto coro = m_coros.front(); | |
m_coros.pop(); | |
coro.resume(); | |
} | |
} | |
void enqueue_task(std::coroutine_handle<> coro) noexcept { | |
std::unique_lock<std::mutex> lock(m_mutex); | |
m_coros.emplace(coro); | |
m_cond.notify_one(); | |
} | |
void shutdown() | |
{ | |
m_stop_thread = true; | |
while (m_threads.size() > 0) | |
{ | |
std::thread& thread = m_threads.back(); | |
if (thread.joinable()) | |
{ | |
thread.join(); | |
} | |
m_threads.pop_back(); | |
} | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment