Skip to content

Instantly share code, notes, and snippets.

@YexuanXiao
Last active October 20, 2024 11:06
Show Gist options
  • Save YexuanXiao/abad460805c66eb66db883693d8b2f4d to your computer and use it in GitHub Desktop.
Save YexuanXiao/abad460805c66eb66db883693d8b2f4d to your computer and use it in GitHub Desktop.
coroutine
#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <chrono>
#include <compare>
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <semaphore>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
// clang-format off
#if !defined(__cpp_size_t_suffix) || __cpp_size_t_suffix < 202011L
// Fallback for toolchains that do not report the built-in `uz` / `z` literal
// suffixes through the __cpp_size_t_suffix feature-test macro.
//
// static_cast is used rather than brace-initialisation on purpose: `value` is
// a function parameter, not a constant expression, so `std::size_t{ value }`
// is an ill-formed narrowing conversion on every target where std::size_t is
// narrower than unsigned long long.

// `123uz` -> std::size_t.
inline constexpr auto operator""uz(unsigned long long const value)
{
    return static_cast<std::size_t>(value);
}
// `123z` -> std::ptrdiff_t.
inline constexpr auto operator""z(unsigned long long const value)
{
    return static_cast<std::ptrdiff_t>(value);
}
#endif
// clang-format on
namespace bizwen
{
class thread_pool
{
class mutex
{
    // Lock word: 0 means unlocked, 1 means locked.
    std::atomic<int> s_{};

public:
    // Blocks until the lock word could be flipped from 0 to 1 by this caller.
    void lock() noexcept
    {
        for (;;)
        {
            int const previous = s_.exchange(1, std::memory_order::acquire);
            if (previous == 0)
                return;
            // Somebody else holds the lock: sleep until the word stops being 1.
            s_.wait(1, std::memory_order::relaxed);
        }
    }
    // Single attempt; returns true when this caller obtained the lock.
    bool try_lock() noexcept
    {
        int const previous = s_.exchange(1, std::memory_order::acquire);
        return previous == 0;
    }
    // Clears the lock word and wakes one thread blocked in lock().
    void unlock() noexcept
    {
        s_.store(0, std::memory_order::release);
        s_.notify_one();
    }
};
class wthread
{
using vector = std::vector<std::coroutine_handle<>>;
std::counting_semaphore<> s_{ 0z }; // 信号量激活线程
std::jthread t_{}; // 运行线程
vector w_{}; // 待执行队列
mutex m_{}; // 保护队列
bool r_{}; // 是否需要还回就绪队列
public:
wthread() = default;
bool operator==(std::thread::id id)
{
return t_.get_id() == id;
}
void join() noexcept
{
t_.join();
}
// 线程安全的发送任务到线程
void push_back(std::coroutine_handle<> c)
{
{
std::lock_guard g{ m_ };
w_.push_back(c);
}
s_.release();
}
// 仅用于清理
void wake() noexcept
{
s_.release();
}
// 发起任务并告知已经从就绪队列中移除
void push_back_and_add(std::coroutine_handle<> c)
{
{
std::lock_guard g{ m_ };
w_.push_back(c);
r_ = true; // 必须出现在release前,push_back后
}
s_.release();
}
void consume(thread_pool& pool) noexcept
{
while (true)
{
s_.acquire();
if (pool.is_exit_())
break;
m_.lock();
auto t = w_.front();
std::copy(w_.begin() + 1, w_.end(), w_.begin());
w_.pop_back();
m_.unlock();
t();
if (r_) // 如果该线程从就绪队列移除了,在执行后重新添加它
{
r_ = false;
std::lock_guard lock{ pool.pending_mutex_ };
pool.pending_list_.push_back(this); // 不会失败
// 反向唤醒派发线程
pool.priority_waiter_.release();
}
}
}
void start(thread_pool& pool)
{
t_ = std::jthread([this, &pool] { consume(pool); });
}
};
// Entry of the priority queue: a coroutine handle and its priority.
// Ordering looks at `priority` only.
struct priority_task
{
    std::coroutine_handle<> handle;
    std::size_t priority;
    std::strong_ordering operator<=>(priority_task const& rhs) const noexcept
    {
        auto const order = priority <=> rhs.priority;
        return order;
    }
};
// Entry of the timer queue: a coroutine handle and the time point attached to it.
// The comparison is deliberately reversed (rhs against this), so an entry with
// an earlier time compares greater than one with a later time.
struct lazy_task
{
    std::coroutine_handle<> handle;
    std::chrono::steady_clock::time_point time;
    std::strong_ordering operator<=>(lazy_task const& rhs) const noexcept
    {
        auto const reversed = rhs.time <=> time;
        return reversed;
    }
};
// Worker objects; sized once in the constructor.
std::vector<wthread> work_threads_;
// Workers currently on the ready list (guarded by pending_mutex_).
std::vector<wthread*> pending_list_;
// Delayed tasks (guarded by lazy_mutex_).
std::priority_queue<lazy_task> lazy_queue_;
// Tasks waiting for a ready worker (guarded by priority_mutex_).
std::priority_queue<priority_task> priority_queue_;
// Thread running lazy_loop_().
std::jthread lazy_thread_;
// Thread running priority_loop_().
std::jthread priority_thread_;
// Wakes lazy_loop_().
std::counting_semaphore<> lazy_waiter_{ 0z };
// Wakes priority_loop_().
std::counting_semaphore<> priority_waiter_{ 0z };
mutex lazy_mutex_;
mutex priority_mutex_;
mutex pending_mutex_;
// Set by exit() and polled by every pool thread. It is atomic because it is
// written and read from different threads; plain assignment and implicit
// conversion to bool keep working on std::atomic<bool>.
std::atomic<bool> exit_flag_{};
// Tries to take a worker off the ready list and hand `handle` to it.
// Returns false when no worker is currently on the list.
bool try_run_(std::coroutine_handle<> handle)
{
    std::lock_guard guard{ pending_mutex_ };
    if (pending_list_.empty())
        return false;
    wthread* const idle = pending_list_.back();
    // Hand over the task first, then drop the worker from the list.
    idle->push_back_and_add(handle);
    pending_list_.pop_back();
    return true;
}
// Reports whether exit() has set the exit flag.
bool is_exit_() const noexcept
{
    bool const exiting{ exit_flag_ };
    return exiting;
}
// Dispatcher loop: each permit on priority_waiter_ triggers one attempt to
// move the top task of priority_queue_ onto a ready worker.
void priority_loop_()
{
    for (;;)
    {
        priority_waiter_.acquire();
        if (is_exit_())
            return;
        std::lock_guard guard{ priority_mutex_ };
        // Workers also release this semaphore, so a permit does not imply a
        // queued task; the task is popped only if a worker actually took it.
        if (!priority_queue_.empty() && try_run_(priority_queue_.top().handle))
            priority_queue_.pop();
    }
}
// Timer loop: waits for a permit on lazy_waiter_, then either dispatches the
// earliest task if its time has passed or sleeps until that time / until a new
// task is inserted.
void lazy_loop_()
{
    while (true)
    {
        lazy_waiter_.acquire();
        if (is_exit_())
            break;
        std::unique_lock lock{ lazy_mutex_ };
        // The semaphore can hold more permits than there are queued tasks
        // (every run_after() adds one and this loop re-releases as well), so
        // the queue may be empty here; top() on an empty queue is undefined.
        if (lazy_queue_.empty())
            continue;
        // Work on a copy of the earliest entry.
        // If its time has passed, run it right away.
        if (auto task{ lazy_queue_.top() }; task.time < decltype(task.time)::clock::now())
        {
            lazy_queue_.pop();
            // Only schedule another round if entries remain.
            if (!lazy_queue_.empty())
                lazy_waiter_.release();
            lock.unlock();
            // The task is due: give it to a ready worker, or queue it with
            // the highest priority when none is available.
            if (auto h = task.handle; !try_run_(h))
                run_once(h, std::size_t(-1));
        }
        else
        {
            // Not due yet: wait until its time or until notified.
            // Release the lock first so inserting tasks is not blocked.
            lock.unlock();
            // Returns early (true) when another permit arrives, i.e. a new
            // task was inserted; returns false on timeout.
            (void)lazy_waiter_.try_acquire_until(task.time);
            // Whatever the outcome, release one permit so the next iteration
            // runs immediately and re-examines the queue.
            lazy_waiter_.release();
        }
    }
}
public:
// Identifies a thread by its std::thread::id. Only thread_pool can create a
// non-default context or look inside one; user code can copy and compare.
class context
{
    friend thread_pool;
    std::thread::id tid_;
    context(std::thread::id tid) noexcept : tid_{ tid }
    {
    }
    std::thread::id id() const noexcept
    {
        return tid_;
    }
    bool operator==(std::thread::id tid) const noexcept
    {
        return id() == tid;
    }

public:
    context() = default;
    context(const context&) = default;
    context& operator=(const context&) = default;
    bool operator==(const context& c) const = default;
};
// Must only be called on a thread that belongs to the pool.
// Wraps the calling thread's id in a context.
static context capture_context() noexcept
{
    std::thread::id const self = std::this_thread::get_id();
    return context{ self };
}
// Queues `callback` with the given priority and wakes the dispatcher.
void run_once(std::coroutine_handle<> callback, std::size_t priority = 0uz)
{
    std::lock_guard guard{ priority_mutex_ };
    priority_queue_.push(priority_task{ callback, priority });
    priority_waiter_.release();
}
// Queues `callback` on the timer queue, due `duration` from now, and wakes the
// timer loop.
void run_after(std::coroutine_handle<> callback, std::chrono::milliseconds duration)
{
    // The due time is computed before the lock is taken.
    std::chrono::steady_clock::time_point const due = std::chrono::steady_clock::now() + duration;
    std::lock_guard guard{ lazy_mutex_ };
    lazy_queue_.push(lazy_task{ callback, due });
    lazy_waiter_.release();
}
// Queues `callback` on the worker whose thread id matches `ctx`.
// Aborts the process when no worker matches.
void run_in(std::coroutine_handle<> callback, context ctx)
{
    auto const target = std::find_if(work_threads_.begin(), work_threads_.end(),
                                     [&ctx](wthread& worker) { return worker == ctx.tid_; });
    if (target == work_threads_.end())
        std::abort();
    target->push_back(callback);
}
public:
// Requests shutdown: raises the exit flag, then gives every pool thread one
// permit so that a blocked acquire() returns and sees the flag.
void exit() noexcept
{
    exit_flag_ = true;
    lazy_waiter_.release();
    priority_waiter_.release();
    for (auto& worker : work_threads_)
    {
        worker.wake();
    }
}
// Creates max(hardware_concurrency, num, 2) workers, puts each of them on the
// ready list and starts it, then starts the timer and dispatcher threads.
thread_pool(std::size_t num = 0uz)
    : work_threads_(std::max<std::size_t>({ std::thread::hardware_concurrency(), num, 2uz }))
{
    for (auto& worker : work_threads_)
    {
        pending_list_.push_back(&worker);
        worker.start(*this);
    }
    lazy_thread_ = std::jthread{ [this] { lazy_loop_(); } };
    priority_thread_ = std::jthread{ [this] { priority_loop_(); } };
}
// Requests shutdown and waits for the dispatcher, the timer thread and every
// worker, in that order.
~thread_pool()
{
    exit();
    priority_thread_.join();
    lazy_thread_.join();
    for (auto& worker : work_threads_)
    {
        worker.join();
    }
}
};
} // namespace bizwen
// Global thread pool instance; the awaiters in bizwen::detail schedule
// coroutines through it (run_once / run_after / run_in).
bizwen::thread_pool pool{};
namespace bizwen
{
// Empty tag type thrown to report that a coroutine was canceled.
class canceled_coroutine {};
namespace detail
{
class cancelable_promise_base
{
enum class status
{
ready = 0,
canceled,
next,
done
};
std::atomic<unsigned int> rfcnt_{ 1u };
std::atomic<status> st_;
std::coroutine_handle<> next_;
std::exception_ptr exc_;
public:
// 积极启动
std::suspend_never initial_suspend() const noexcept
{
return {};
}
void unhandled_exception() noexcept
{
exc_ = std::current_exception();
}
void increase() noexcept
{
rfcnt_.fetch_add(1, std::memory_order::relaxed);
}
bool zero() noexcept
{
return rfcnt_.fetch_sub(1, std::memory_order::acq_rel) == 0u;
}
// 是否被取消
bool canceled() const noexcept
{
return st_.load(std::memory_order::relaxed) == status::canceled;
}
// 尝试取消当前协程,无论返回什么
bool cancel() noexcept
{
// 尝试将ready和next转换为canceled
auto st = status::ready;
if (st_.compare_exchange_strong(st, status::canceled, std::memory_order::acq_rel))
return true;
st = status::next;
if (st_.compare_exchange_strong(st, status::canceled, std::memory_order::acq_rel))
return true;
assert(st != status::canceled); // 不允许任务被取消两次
return false;
}
auto cancel_async() noexcept
{
struct cancel_awaiter : public std::suspend_always
{
cancelable_promise_base& p_;
bool await_suspend(std::coroutine_handle<> handle)
{
p_.next_ = handle;
auto st = status::ready;
if (p_.st_.compare_exchange_strong(st, status::canceled, std::memory_order::acq_rel))
return true;
assert(st != status::canceled); // 不允许任务被取消两次
assert(st != status::next); // 不允许同一任务被两次等待
return false;
}
};
return cancel_awaiter{ .p_ = *this };
}
bool is_done() const noexcept
{
return st_.load(std::memory_order::relaxed) == status::done;
}
// 被task_awaiter::await_suspend调用
// 返回false代表可以直接执行当前协程,也就是await_suspend返回false
bool next(std::coroutine_handle<> handle) noexcept
{
next_ = handle;
// 只有ready状态才可以被替换为next
auto st = status::ready;
if (st_.compare_exchange_strong(st, status::next, std::memory_order::acq_rel))
return true; // 当前协程会被暂停
assert(st != status::next); // 不允许多次co_await同一个task
return false;
}
// 被task_awaiter::await_resume调用
void rethrow_exception() const
{
// 优先抛出协程体的异常
if (exc_)
std::rethrow_exception(exc_);
auto s = st_.load(std::memory_order::relaxed);
// 然后测试协程是否被取消
if (s == status::canceled)
throw canceled_coroutine{};
}
auto final_suspend() noexcept
{
class final_awaiter : public std::suspend_always
{
cancelable_promise_base& p_;
friend cancelable_promise_base;
public:
final_awaiter(cancelable_promise_base& p) noexcept : p_(p)
{
}
bool await_suspend(std::coroutine_handle<>) const noexcept
{
auto st = status::ready;
p_.st_.compare_exchange_strong(st, status::done, std::memory_order::acq_rel);
// 立即执行或者发送到线程池
if (auto next = p_.next_)
next();
// 恢复协程或销毁
return !p_.zero();
}
};
return final_awaiter{ *this };
}
};
// Re-types a handle whose promise derives from cancelable_promise_base as a
// handle to that base, by round-tripping through the frame address.
template <typename T>
    requires std::derived_from<T, cancelable_promise_base>
auto to_cancelable(std::coroutine_handle<T> h) noexcept
{
    using base_handle = std::coroutine_handle<cancelable_promise_base>;
    void* const frame = h.address();
    return base_handle::from_address(frame);
}
// cancelable_promise_base plus a progress counter available to derived classes.
class progress_promise_base : public cancelable_promise_base
{
protected:
    std::size_t progress = 0;
};
struct cancellation_awaiter : std::suspend_always
{
private:
cancelable_promise_base* c_{};
public:
template <typename T>
bool await_suspend(std::coroutine_handle<T> h) noexcept
{
c_ = &to_cancelable(h).promise();
return false;
}
auto await_resume() const noexcept
{
class token
{
friend cancellation_awaiter;
cancelable_promise_base& c_;
token(cancelable_promise_base& c) noexcept : c_(c)
{
}
public:
bool canceled() const noexcept
{
return c_.canceled();
}
explicit operator bool() const noexcept
{
return canceled();
}
};
return token{ *c_ };
}
};
// CRTP base for awaiters. When awaited from a coroutine whose promise is
// accepted by to_cancelable(), it throws canceled_coroutine if that promise is
// already canceled; otherwise it forwards to U's await_suspend taking a
// type-erased handle.
template <typename U>
struct enable_cancellation : public std::suspend_always
{
    template <typename T>
    void await_suspend(std::coroutine_handle<T> handle) const
        requires requires { to_cancelable(handle); }
    {
        if (handle.promise().canceled())
            throw canceled_coroutine{};
        auto const& self = static_cast<const U&>(*this);
        self.await_suspend(std::coroutine_handle<>{ handle });
    }
};
// Awaiter behind `co_await <duration>`: resumes the coroutine on the pool
// after d_ has elapsed.
struct timer_awaiter : public enable_cancellation<timer_awaiter>
{
    std::chrono::milliseconds d_;
    // Nothing to wait for when the delay is zero or negative.
    bool await_ready() const noexcept
    {
        return !(d_ > std::chrono::milliseconds::zero());
    }
    using enable_cancellation::await_suspend;
    void await_suspend(std::coroutine_handle<> handle) const
    {
        pool.run_after(handle, d_);
    }
};
// Awaiter behind `co_await <context>`: continues the coroutine on the pool
// thread identified by c_.
struct apartment_awaiter : public enable_cancellation<apartment_awaiter>
{
    thread_pool::context c_;
    // No switch is needed for a default-constructed context or when the
    // current thread already is the requested one.
    bool await_ready() const noexcept
    {
        if (c_ == thread_pool::context{})
            return true;
        return thread_pool::capture_context() == c_;
    }
    using enable_cancellation::await_suspend;
    void await_suspend(std::coroutine_handle<> handle) const
    {
        pool.run_in(handle, c_);
    }
};
// Awaiter returned by resume_background(): queues the coroutine on the pool
// with the default priority.
struct background_awaiter : public enable_cancellation<background_awaiter>
{
    using enable_cancellation::await_suspend;
    void await_suspend(std::coroutine_handle<> handle) const
    {
        auto& scheduler = pool;
        scheduler.run_once(handle);
    }
};
} // namespace detail
auto resume_background() noexcept
{
return detail::background_awaiter{};
}
auto get_cancellation_token() noexcept
{
return detail::cancellation_awaiter{};
}
// Makes a thread_pool::context awaitable via apartment_awaiter.
auto operator co_await(thread_pool::context c) noexcept
{
    using awaiter_type = detail::apartment_awaiter;
    return awaiter_type{ .c_ = c };
}
// Captures the context of the calling thread.
auto capture_apartment() noexcept
{
    auto const current = thread_pool::capture_context();
    return current;
}
// Makes any std::chrono::duration awaitable (`co_await 1s`).
// The delay is rounded UP to whole milliseconds. A truncating conversion
// would turn a positive sub-millisecond delay into 0 ms, which
// timer_awaiter::await_ready() treats as "do not suspend at all".
template <class Rep, class Period> // breakline
auto operator co_await(std::chrono::duration<Rep, Period> d) noexcept
{
    return detail::timer_awaiter{ .d_ = std::chrono::ceil<std::chrono::milliseconds>(d) };
}
template <typename T = void>
class task
{
static_assert(std::same_as<std::remove_cvref_t<std::decay_t<T>>, T>);
public:
class promise_type;
private:
// 为了实现移动task,必须使用handle储存
std::coroutine_handle<promise_type> handle_{};
friend promise_type;
task(promise_type& p) noexcept : handle_(decltype(handle_)::from_promise(p))
{
}
public:
task() = default;
task(task&& rhs) noexcept
{
std::swap(rhs.handle_, handle_);
}
task& operator=(task&& rhs) noexcept
{
std::swap(rhs.handle_, handle_);
return *this;
}
task(task const& rhs) noexcept
{
if (&rhs == this)
return;
if (!rhs.handle_)
return;
handle_ = rhs.handle_;
handle_.promise().increase();
}
task& operator=(const task& rhs) noexcept
{
if (&rhs == this)
return *this;
task temp{};
temp.handle_ = handle_;
handle_ = rhs.handle_;
if (!rhs.handle_)
return *this;
handle_.promise().increase();
return *this;
}
class promise_type : public detail::cancelable_promise_base
{
friend task;
std::optional<T> result_;
public:
promise_type() noexcept
{
}
task get_return_object() noexcept
{
return { *this };
}
void return_value(T&& t) noexcept
{
result_ = std::move(t);
}
void return_value(T const& t)
{
result_ = t;
}
auto& result() noexcept
{
return result_.value();
}
};
void cancel() noexcept
{
assert(handle_);
handle_.promise().cancel();
}
auto cancel_async() noexcept
{
assert(handle_);
return handle_.promise().cancel_async();
}
~task()
{
if (!handle_)
return;
if (handle_.promise().zero())
handle_.destroy();
}
auto operator co_await() noexcept
{
assert(handle_);
struct task_awaiter
{
promise_type& p_;
bool await_ready() const noexcept
{
return p_.is_done();
}
bool await_suspend(std::coroutine_handle<> handle) const noexcept
{
return p_.next(handle);
}
T await_resume()
{
p_.rethrow_exception();
return std::move(static_cast<promise_type&>(p_).result());
}
};
return task_awaiter{ .p_ = handle_.promise() };
}
T sync_get()
{
std::atomic<int> flag{ 1 };
struct sync_awaiter : public std::suspend_always
{
decltype(handle_) h_;
auto await_suspend(std::coroutine_handle<> handle)
{
return h_.promise().next(handle);
}
};
[&flag, this]() noexcept -> task {
co_await sync_awaiter{ .h_ = handle_ };
flag.store(0, std::memory_order::release);
flag.notify_one();
}();
while (flag.exchange(1, std::memory_order::acquire))
flag.wait(1, std::memory_order::relaxed);
auto& p = handle_.promise();
p.rethrow_exception();
return std::move(p.result());
}
};
template <> // breakline
class task<void>
{
public:
class promise_type;
private:
// 为了实现移动task,必须使用handle储存
std::coroutine_handle<promise_type> handle_{};
friend promise_type;
task(promise_type& p) noexcept : handle_(decltype(handle_)::from_promise(p))
{
}
public:
task() = default;
task(task&& rhs) noexcept
{
std::swap(rhs.handle_, handle_);
}
task& operator=(task&& rhs) noexcept
{
std::swap(rhs.handle_, handle_);
return *this;
}
task(task const& rhs) noexcept
{
if (&rhs == this)
return;
if (!rhs.handle_)
return;
handle_ = rhs.handle_;
handle_.promise().increase();
}
task& operator=(const task& rhs) noexcept
{
if (&rhs == this)
return *this;
task temp{};
temp.handle_ = handle_;
handle_ = rhs.handle_;
if (!rhs.handle_)
return *this;
handle_.promise().increase();
return *this;
}
class promise_type : public detail::cancelable_promise_base
{
friend task;
public:
promise_type() noexcept
{
}
task get_return_object() noexcept
{
return { *this };
}
void return_void() const noexcept
{
}
};
void cancel() noexcept
{
assert(handle_);
handle_.promise().cancel();
}
auto cancel_async() noexcept
{
assert(handle_);
return handle_.promise().cancel_async();
}
~task()
{
if (!handle_)
return;
if (handle_.promise().zero())
handle_.destroy();
}
auto operator co_await() noexcept
{
assert(handle_);
struct task_awaiter
{
promise_type& p_;
bool await_ready() const noexcept
{
return p_.is_done();
}
bool await_suspend(std::coroutine_handle<> handle) const noexcept
{
return p_.next(handle);
}
void await_resume()
{
p_.rethrow_exception();
}
};
return task_awaiter{ .p_ = handle_.promise() };
}
void sync_get()
{
assert(handle_);
std::atomic<int> flag{ 1 };
struct sync_awaiter : public std::suspend_always
{
decltype(handle_) h_;
auto await_suspend(std::coroutine_handle<> handle) noexcept
{
return h_.promise().next(handle);
}
};
[&flag, this]() noexcept -> task {
co_await sync_awaiter{ .h_ = handle_ };
flag.store(0, std::memory_order::release);
flag.notify_one();
}();
while (flag.exchange(1, std::memory_order::acquire))
flag.wait(1, std::memory_order::relaxed);
handle_.promise().rethrow_exception();
}
};
// Return type for coroutines nobody waits for: starts eagerly, does not
// suspend at the end, and ignores exceptions escaping the body.
struct fire_and_forget
{
    struct promise_type
    {
    public:
        promise_type() noexcept
        {
        }
        fire_and_forget get_return_object() const noexcept
        {
            return fire_and_forget{};
        }
        std::suspend_never initial_suspend() const noexcept
        {
            return std::suspend_never{};
        }
        std::suspend_never final_suspend() const noexcept
        {
            return std::suspend_never{};
        }
        void return_void() const noexcept
        {
        }
        // Exceptions are dropped on purpose.
        void unhandled_exception() const noexcept
        {
        }
    };
    // Awaiting a fire_and_forget completes immediately.
    auto operator co_await() const noexcept
    {
        std::suspend_never immediate{};
        return immediate;
    }
};
} // namespace bizwen
#include <fast_io.h>
/* https://learn.microsoft.com/zh-cn/windows/uwp/cpp-and-winrt-apis/concurrency-2
IAsyncAction ExplicitCancelationAsync()
{
auto cancelation_token{ co_await winrt::get_cancellation_token() };
while (!cancelation_token())
{
std::cout << "ExplicitCancelationAsync: do some work for 1 second" << std::endl;
co_await 1s;
}
}
IAsyncAction MainCoroutineAsync()
{
auto explicit_cancelation{ ExplicitCancelationAsync() };
co_await 3s;
explicit_cancelation.Cancel();
}
*/
// Demo: prints a line and waits one second, repeatedly, until the coroutine's
// own cancellation token reports cancellation.
bizwen::task<> ExplicitCancelationAsync()
{
    using namespace bizwen;
    using namespace std::chrono_literals;
    auto const token = co_await bizwen::get_cancellation_token();
    for (;;)
    {
        if (token)
            co_return;
        fast_io::print(fast_io::err(), "ExplicitCancelationAsync: do some work for 1 second\n");
        co_await 1s;
    }
}
// Demo driver: starts ExplicitCancelationAsync, waits three seconds, then
// cancels it.
bizwen::task<> MainCoroutineAsync()
{
    using namespace bizwen;
    using namespace std::chrono_literals;
    auto worker = ExplicitCancelationAsync();
    co_await 3s;
    worker.cancel();
}
// Waits x seconds, then prints x on its own line.
bizwen::task<void> sleep(int x)
{
    using namespace std::chrono_literals;
    using namespace bizwen;
    auto const delay = x * 1s;
    co_await delay;
    fast_io::println(fast_io::out(), x);
}
// Starts one sleep() task per argument, then awaits them one after another.
bizwen::task<void> sleep_sort(auto... args)
{
    auto sleepers = std::array{ sleep(args)... };
    for (auto& sleeper : sleepers)
    {
        co_await sleeper;
    }
}
// Runs the cancellation demo to completion, then the sleep-sort demo.
int main()
{
    {
        auto cancel_demo = MainCoroutineAsync();
        cancel_demo.sync_get();
    }
    {
        auto sort_demo = sleep_sort(0, 9, 3, 4, 6, 1, 2, 8, 5, 7);
        sort_demo.sync_get();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment