Created
January 30, 2022 18:43
-
-
Save stormouse/abda9e9ed062ef0f69d30346b09cc5d6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Code mostly from https://lewissbaker.github.io/2017/11/17/understanding-operator-co-await | |
#include <atomic>
#include <chrono>
#include <exception>
#include <experimental/coroutine>
#include <iostream>
#include <optional>
#include <thread>
// awaitable ManualResetEvent | |
// usage: co_await event; | |
class ManualResetEvent { | |
public: | |
ManualResetEvent(bool initiallySet = false) noexcept | |
: state_{ initiallySet ? this : nullptr } {} | |
bool isSet() const noexcept { | |
return state_.load(std::memory_order_acquire) == this; | |
} | |
struct awaiter; | |
awaiter operator co_await() const noexcept; | |
void set() noexcept; | |
void reset() noexcept { | |
void *oldValue = this; | |
state_.compare_exchange_strong(oldValue, nullptr, std::memory_order_acquire); | |
} | |
private: | |
friend struct awaiter; | |
mutable std::atomic<void*> state_; | |
private: | |
ManualResetEvent(const ManualResetEvent&) = delete; | |
ManualResetEvent(ManualResetEvent&&) = delete; | |
ManualResetEvent& operator=(const ManualResetEvent&) = delete; | |
ManualResetEvent& operator=(ManualResetEvent&&) = delete; | |
}; | |
// Type-erased coroutine handle; the awaiter only ever needs to resume().
using Handle = std::experimental::coroutine_handle<void>;

// Awaiter produced for `co_await event`.
//
// When it has to suspend, it links itself onto the head of the event's
// waiter list. ManualResetEvent::set() later walks that list and resumes
// each stored coroutine.
struct ManualResetEvent::awaiter {
    // Deliberately not explicit: Task::promise_type::await_transform takes a
    // `const ManualResetEvent::awaiter&`, and `co_await event` reaches it via
    // this implicit conversion.
    awaiter(const ManualResetEvent& event) noexcept
        : event_(event) {}

    // Already set: continue without suspending.
    bool await_ready() const noexcept {
        return event_.isSet();
    }

    // Tries to push this awaiter onto the waiter list.
    // Returns false (do not suspend) if the event is observed as set;
    // returns true once the push has been published.
    bool await_suspend(Handle awaitingCoroutine) noexcept {
        awaitingCoroutine_ = awaitingCoroutine;
        const void* const setMarker = &event_;
        void* observed = event_.state_.load(std::memory_order_acquire);
        for (;;) {
            if (observed == setMarker) {
                return false;
            }
            next_ = static_cast<awaiter*>(observed);
            // success/release: publishes awaitingCoroutine_ and next_ to set().
            // failure/acquire: if the new value means "set", the setter's
            // earlier writes become visible before we resume.
            if (event_.state_.compare_exchange_weak(
                    observed,
                    this,
                    std::memory_order_release,
                    std::memory_order_acquire)) {
                // Touch no members from here on: another thread may already
                // be resuming the coroutine that owns this awaiter.
                return true;
            }
        }
    }

    // co_await on the event produces no value.
    void await_resume() noexcept {}

    const ManualResetEvent& event_;

private:
    friend ManualResetEvent;
    Handle awaitingCoroutine_;  // coroutine to resume; written before the push
    awaiter* next_;             // older waiter, or nullptr at the end of the list
};
void ManualResetEvent::set() noexcept { | |
void* oldValue = state_.exchange(this, std::memory_order_acq_rel); | |
if (oldValue != this) { | |
auto* waiters = static_cast<awaiter*>(oldValue); | |
while (waiters != nullptr) { | |
auto* next = waiters->next_; | |
waiters->awaitingCoroutine_.resume(); | |
waiters = next; | |
} | |
} | |
} | |
// Makes the event directly awaitable: yields an awaiter bound to this event.
ManualResetEvent::awaiter ManualResetEvent::operator co_await() const noexcept {
    return awaiter(*this);
}
// Fire-and-forget coroutine return type.
//
// The body starts running immediately (initial_suspend is suspend_never), and
// the coroutine frame destroys itself when the body finishes (final_suspend is
// suspend_never). As a result coro_ dangles once the coroutine has completed;
// Task never uses it, and it must not be resumed or destroyed through here.
struct Task {
    struct promise_type {
        using Handle = std::experimental::coroutine_handle<promise_type>;

        Task get_return_object() {
            return Task {Handle::from_promise(*this)};
        }

        std::experimental::suspend_never initial_suspend() { return {}; }
        std::experimental::suspend_never final_suspend() noexcept { return {}; }

        // Every `co_await x` in a Task body goes through this hook, so x must
        // convert to ManualResetEvent::awaiter (a ManualResetEvent does so via
        // the awaiter's non-explicit constructor). Logs the thread the await
        // starts on, then returns a fresh awaiter bound to the same event.
        ManualResetEvent::awaiter await_transform(const ManualResetEvent::awaiter& a) {
            std::cout << "await_transform on " <<
                std::this_thread::get_id() <<
                std::endl;
            return {a.event_};
        }

        void return_void() {}

        // A Task has no way to report failure, and with a suspend_never
        // final_suspend the frame (and any stored exception) is destroyed as
        // soon as this returns. Swallowing the exception would hide the error
        // entirely, so fail loudly instead.
        void unhandled_exception() noexcept { std::terminate(); }
    };

    explicit Task(promise_type::Handle coro) : coro_{coro} {}

private:
    promise_type::Handle coro_;  // dangling after the coroutine completes; do not use
};
int value; | |
ManualResetEvent event; | |
void producer() { | |
value = 10; | |
std::cout << "produced: " << value << | |
" on " << std::this_thread::get_id() | |
<< std::endl; | |
event.set(); | |
} | |
// Waits for `event`, then prints the published value together with the
// thread the coroutine is running on after the await.
Task consumer() {
    co_await event;
    const std::thread::id resumedOn = std::this_thread::get_id();
    std::cout << "consumed: " << value << " on " << resumedOn << std::endl;
}
using namespace std::chrono_literals; | |
int main() { | |
auto c1 = consumer(); | |
auto c2 = consumer(); | |
std::this_thread::sleep_for(100ms); | |
auto producerThread = std::thread(producer); | |
producerThread.join(); | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the coroutine demo with clang++ and libc++ (Coroutines TS headers).
#
# -pthread replaces -lpthread. It is valid for both the compile and the link
# step, and links the threads library correctly. A bare -lpthread in FLAGS
# ended up *before* coroutine.o on the link line, where linkers running with
# --as-needed may discard it. It was also an unused linker input during -c.
FLAGS = -std=c++20 -stdlib=libc++ -fcoroutines-ts -pthread
INC = -I/usr/lib/llvm-13/include/c++/v1

.PHONY: clean

coroutine: coroutine.o
	clang++ $(FLAGS) coroutine.o -o coroutine

coroutine.o: coroutine.cpp
	clang++ $(FLAGS) $(INC) -c coroutine.cpp -o coroutine.o

# -f: do not fail when there is nothing to remove.
clean:
	rm -f *.o coroutine
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment