Skip to content

Instantly share code, notes, and snippets.

@MichaEiler
Created May 17, 2021 19:03
Show Gist options
  • Save MichaEiler/99c3ed529d4fd19c4289fd04672a1a7c to your computer and use it in GitHub Desktop.
Save MichaEiler/99c3ed529d4fd19c4289fd04672a1a7c to your computer and use it in GitHub Desktop.
Coroutine-based thread pool
#pragma once
#include <atomic>
/// One-shot event: wait() blocks until set() has been called on the same object.
/// Implemented on top of the C++20 wait/notify interface of std::atomic_flag.
struct fire_once_event
{
    /// Blocks the calling thread for as long as the flag still reads `false`.
    void wait()
    {
        m_flag.wait(false, std::memory_order_seq_cst);
    }

    /// Raises the flag and wakes every thread currently blocked in wait().
    void set()
    {
        m_flag.test_and_set(std::memory_order_seq_cst);
        m_flag.notify_all();
    }

private:
    // Default-constructed; since C++20 that leaves the flag in the clear state.
    std::atomic_flag m_flag;
};
#include "task.hpp"
#include "threadpool.hpp"
#include "sync_wait.hpp"
#include <iostream>
#include <thread>
/// Coroutine that reschedules itself through `pool` and then prints the id of
/// the thread it ends up running on.
task run_async_print(threadpool& pool)
{
    // Suspend here; execution continues on whichever thread the pool resumes us from.
    co_await pool.schedule();

    const auto current_id = std::this_thread::get_id();
    std::cout << "This is a hello from thread: " << current_id << "\n";
}
int main()
{
std::cout << "The main thread id is: " << std::this_thread::get_id() << "\n";
threadpool pool{8};
task t = run_async_print(pool);
sync_wait(t);
}
# Builds the coroutine thread-pool sample with GCC (C++20, -fcoroutines).
CXX=g++

# `all` does not name a file; mark it phony so a stray file called "all"
# can never make the build look up to date.
.PHONY: all
all:
	$(CXX) main.cpp -fcoroutines -std=c++20 -pthread -g -o coro-sample
#pragma once
#include <coroutine>
#include "fire_once_event.hpp"
#include "task.hpp"
struct sync_wait_task_promise;
struct [[nodiscard]] sync_wait_task
{
using promise_type = sync_wait_task_promise;
sync_wait_task(std::coroutine_handle<sync_wait_task_promise> coro)
: m_handle(coro)
{
}
~sync_wait_task()
{
if (m_handle)
{
m_handle.destroy();
}
}
void run(fire_once_event& event);
private:
std::coroutine_handle<sync_wait_task_promise> m_handle;
};
/// Promise type of sync_wait_task.
///
/// The coroutine starts suspended and, when it reaches its final suspend
/// point, signals the fire_once_event stored in m_event (if one was set).
struct sync_wait_task_promise
{
    /// The coroutine does not run until it is explicitly resumed.
    std::suspend_always initial_suspend() const noexcept { return {}; }

    /// Suspends at the end of the coroutine and signals the registered event.
    auto final_suspend() const noexcept
    {
        struct awaiter
        {
            bool await_ready() const noexcept { return false; }

            void await_suspend(std::coroutine_handle<sync_wait_task_promise> coro) const noexcept
            {
                // Copy the pointer to a local first: nothing in the coroutine
                // frame is accessed after the event has been set.
                fire_once_event *const event = coro.promise().m_event;
                if (event)
                {
                    event->set();
                }
            }

            void await_resume() noexcept {}
        };
        return awaiter();
    }

    /// Event to signal on completion; stays null until someone assigns it.
    fire_once_event *m_event = nullptr;

    sync_wait_task get_return_object() noexcept
    {
        return sync_wait_task{ std::coroutine_handle<sync_wait_task_promise>::from_promise(*this) };
    }

    /// Required so that a coroutine body may flow off its end: without a
    /// return_void() member on the promise that is undefined behaviour.
    void return_void() noexcept {}

    /// Any exception escaping the coroutine body ends the process with status 1.
    void unhandled_exception() noexcept { exit(1); }
};
/// Stores the address of `event` in the promise, then resumes the coroutine.
inline void sync_wait_task::run(fire_once_event& event)
{
    sync_wait_task_promise& promise = m_handle.promise();
    promise.m_event = &event;   // picked up again in final_suspend()
    m_handle.resume();
}
// Wraps `t` in a sync_wait_task coroutine. The wrapper is created suspended
// (its promise's initial_suspend() returns std::suspend_always); once resumed
// it awaits `t` and then runs off the end of its body.
// `t` is captured by reference, so it must stay alive until the wrapper is done.
inline sync_wait_task make_sync_wait_task(task& t)
{
co_await t;
}
/// Starts `t` through a wrapper coroutine and blocks the calling thread until
/// that wrapper has reached its final suspend point.
inline void sync_wait(task& t)
{
    // Declared first so that it outlives the wrapper coroutine that signals it.
    fire_once_event finished;

    sync_wait_task waiter = make_sync_wait_task(t);
    waiter.run(finished);
    finished.wait();
}
#pragma once
#include <coroutine>
#include <exception>
#include <iostream>
#include <utility>
class [[nodiscard]] task;

/// Promise type of `task`.
///
/// The coroutine starts suspended; when it finishes, control is handed to the
/// continuation stored via set_continuation() (a no-op coroutine by default).
struct task_promise
{
    /// Awaiter used at the final suspend point: always suspends and returns
    /// the stored continuation as the next coroutine to resume.
    struct final_awaitable
    {
        bool await_ready() const noexcept { return false; }

        std::coroutine_handle<> await_suspend(std::coroutine_handle<task_promise> coro) noexcept
        {
            task_promise& finished = coro.promise();
            return finished.m_continuation;
        }

        void await_resume() noexcept {}
    };

    task get_return_object() noexcept;

    std::suspend_always initial_suspend() const noexcept { return std::suspend_always{}; }

    final_awaitable final_suspend() const noexcept { return final_awaitable{}; }

    void return_void() noexcept {}

    /// Any exception escaping the coroutine body ends the process with status 1.
    void unhandled_exception() noexcept { exit(1); }

    /// Records the coroutine to resume once this one has completed.
    void set_continuation(std::coroutine_handle<> continuation) noexcept
    {
        m_continuation = continuation;
    }

private:
    std::coroutine_handle<> m_continuation = std::noop_coroutine();
};
class [[nodiscard]] task
{
public:
using promise_type = task_promise;
explicit task(std::coroutine_handle<task_promise> handle)
: m_handle(handle)
{
}
~task()
{
if (m_handle)
{
m_handle.destroy();
}
}
auto operator co_await() noexcept
{
struct awaiter
{
bool await_ready() const noexcept { return !m_coro || m_coro.done(); }
std::coroutine_handle<> await_suspend( std::coroutine_handle<> awaiting_coroutine) noexcept {
m_coro.promise().set_continuation(awaiting_coroutine);
return m_coro;
}
void await_resume() noexcept {}
std::coroutine_handle<task_promise> m_coro;
};
return awaiter{m_handle};
}
private:
std::coroutine_handle<task_promise> m_handle;
};
/// Creates the `task` that owns the coroutine this promise belongs to.
inline task task_promise::get_return_object() noexcept
{
    const auto self = std::coroutine_handle<task_promise>::from_promise(*this);
    return task{ self };
}
#pragma once
#include <condition_variable>
#include <coroutine>
#include <cstdint>
#include <list>
#include <mutex>
#include <queue>
#include <thread>
class threadpool
{
public:
explicit threadpool(const std::size_t threadCount)
{
for (std::size_t i = 0; i < threadCount; ++i)
{
std::thread worker_thread([this]() {
this->thread_loop();
});
m_threads.push_back(std::move(worker_thread));
}
}
~threadpool()
{
shutdown();
}
auto schedule()
{
struct awaiter
{
threadpool* m_threadpool;
constexpr bool await_ready() const noexcept { return false; }
constexpr void await_resume() const noexcept { }
void await_suspend(std::coroutine_handle<> coro) const noexcept {
m_threadpool->enqueue_task(coro);
}
};
return awaiter{this};
}
private:
std::list<std::thread> m_threads;
std::mutex m_mutex;
std::condition_variable m_cond;
std::queue<std::coroutine_handle<>> m_coros;
bool m_stop_thread = false;
void thread_loop()
{
while (!m_stop_thread)
{
std::unique_lock<std::mutex> lock(m_mutex);
while (!m_stop_thread && m_coros.size() == 0)
{
m_cond.wait_for(lock, std::chrono::microseconds(100));
}
if (m_stop_thread)
{
break;
}
auto coro = m_coros.front();
m_coros.pop();
coro.resume();
}
}
void enqueue_task(std::coroutine_handle<> coro) noexcept {
std::unique_lock<std::mutex> lock(m_mutex);
m_coros.emplace(coro);
m_cond.notify_one();
}
void shutdown()
{
m_stop_thread = true;
while (m_threads.size() > 0)
{
std::thread& thread = m_threads.back();
if (thread.joinable())
{
thread.join();
}
m_threads.pop_back();
}
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment