Skip to content

Instantly share code, notes, and snippets.

@GaZaTu
Created January 30, 2020 16:40
Show Gist options
  • Save GaZaTu/00db67a99f0ea9609cd573f48d4ad308 to your computer and use it in GitHub Desktop.
Save GaZaTu/00db67a99f0ea9609cd573f48d4ad308 to your computer and use it in GitHub Desktop.
#pragma once
#include <atomic>
#include <coroutine>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>
template<typename T>
class future {
public:
class promise_type {
public:
class final_awaitable {
public:
bool await_ready() {
std::cout << "future<T>::promise_type::final_awaitable::await_ready" << std::endl;
return false;
}
template<typename PROMISE>
void await_suspend(std::coroutine_handle<PROMISE> coro) noexcept {
std::cout << "future<T>::promise_type::final_awaitable::await_suspend" << std::endl;
promise_type& promise = coro.promise();
// Use 'release' memory semantics in case we finish before the
// awaiter can suspend so that the awaiting thread sees our
// writes to the resulting value.
// Use 'acquire' memory semantics in case the caller registered
// the continuation before we finished. Ensure we see their write
// to m_continuation.
if (promise._continuation_state.exchange(true, std::memory_order_acq_rel)) {
promise._continuation.resume();
}
}
void await_resume() {
std::cout << "future<T>::promise_type::final_awaitable::await_resume" << std::endl;
}
};
T _value = nullptr;
std::exception_ptr _exception;
bool _failed = false;
promise_type() noexcept : _continuation_state(false) { }
~promise_type() {
if (_failed) {
_exception.~exception_ptr();
} else {
_value.~T();
}
}
auto get_return_object() {
return coro_handle::from_promise(*this);
}
// custom
auto get_awaitable() {
return future(get_return_object());
}
auto initial_suspend() noexcept {
return std::suspend_always();
}
auto final_suspend() noexcept {
std::cout << "future<T>::promise_type::final_suspend" << std::endl;
return final_awaitable();
}
void set_continuation(std::coroutine_handle<> continuation) noexcept {
_continuation = continuation;
}
void unhandled_exception() {
::new (static_cast<void*>(std::addressof(_exception))) std::exception_ptr(std::current_exception());
_failed = true;
}
template<
typename VALUE,
typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>>
void return_value(VALUE&& value) noexcept(std::is_nothrow_constructible_v<T, VALUE&&>) {
std::cout << "future<T>::promise_type::return_void" << std::endl;
::new (static_cast<void*>(std::addressof(_value))) T(std::forward<VALUE>(value));
}
T& result() & {
std::cout << "future<T>::promise_type::result" << std::endl;
if (_failed) {
std::rethrow_exception(_exception);
} else {
return _value;
}
}
private:
std::coroutine_handle<> _continuation;
// Initially false. Set to true when either a continuation is registered
// or when the coroutine has run to completion. Whichever operation
// successfully transitions from false->true got there first.
std::atomic<bool> _continuation_state;
};
using coro_handle = std::coroutine_handle<promise_type>;
future(coro_handle handle) : _handle(handle) { }
future(future&& other) : _handle(other._handle) {
other._handle = nullptr;
}
future(const future&) = delete;
future& operator=(const future&) = delete;
~future() {
_handle.destroy();
}
bool done() {
return _handle.done();
}
// bool resume() {
// if (!_handle.done())
// _handle.resume();
// return !_handle.done();
// }
T& result() & {
std::cout << "future<T>::result" << std::endl;
return _handle.promise().result();
}
private:
coro_handle _handle;
};
/// Specialization of `future` for coroutines that produce no value.
///
/// Lazily started: the body runs only once the future is awaited. The future
/// owns the coroutine frame and destroys it when it is destroyed or
/// move-assigned over. A moved-from future holds a null handle.
template<>
class future<void> {
public:
  class promise_type {
  public:
    /// Awaitable returned from `final_suspend()`. Resumes the continuation
    /// registered via `try_set_continuation()` if it was registered before
    /// the coroutine finished.
    class final_awaitable {
    public:
      // The await_* members are noexcept because the expression
      // `co_await promise.final_suspend()` is required to be non-throwing.
      bool await_ready() noexcept {
        std::cout << "future<void>::promise_type::final_awaitable::await_ready" << std::endl;
        return false;
      }
      template<typename PROMISE>
      void await_suspend(std::coroutine_handle<PROMISE> coro) noexcept {
        std::cout << "future<void>::promise_type::final_awaitable::await_suspend" << std::endl;
        promise_type& promise = coro.promise();
        // 'release': if we finish before the awaiter registers its
        // continuation, the awaiting thread must see our writes.
        // 'acquire': if the awaiter registered first, we must see its write
        // of `_continuation`.
        // The exchange returns true only when a continuation was registered first.
        if (promise._continuation_state.exchange(true, std::memory_order_acq_rel)) {
          promise._continuation.resume();
        }
      }
      void await_resume() noexcept {
        std::cout << "future<void>::promise_type::final_awaitable::await_resume" << std::endl;
      }
    };
    // Exception that escaped the coroutine body (null if none).
    std::exception_ptr _exception;
    bool _failed = false;
    promise_type() noexcept : _continuation_state(false) {}
    // `_exception` is an ordinary member and is destroyed automatically;
    // destroying it by hand as well would destroy it twice.
    ~promise_type() = default;
    auto get_return_object() {
      return coro_handle::from_promise(*this);
    }
    // custom
    auto get_awaitable() {
      return future(get_return_object());
    }
    auto initial_suspend() noexcept {
      return std::suspend_always();
    }
    auto final_suspend() noexcept {
      std::cout << "future<void>::promise_type::final_suspend" << std::endl;
      return final_awaitable();
    }
    /// Registers `continuation` to be resumed when the coroutine finishes.
    /// Returns false if the coroutine already finished (continuation is then
    /// not resumed and the caller should continue by itself).
    bool try_set_continuation(std::coroutine_handle<> continuation) {
      _continuation = continuation;
      return !_continuation_state.exchange(true, std::memory_order_acq_rel);
    }
    void unhandled_exception() {
      _exception = std::current_exception();
      _failed = true;
    }
    void return_void() {
      std::cout << "future<void>::promise_type::return_void" << std::endl;
    }
    /// Rethrows the exception that escaped the coroutine body, if any.
    void result() {
      std::cout << "future<void>::promise_type::result" << std::endl;
      if (_failed) {
        std::rethrow_exception(_exception);
      }
    }
  private:
    std::coroutine_handle<> _continuation;
    // Initially false. Set to true when either a continuation is registered
    // or when the coroutine has run to completion. Whichever operation
    // successfully transitions from false->true got there first.
    std::atomic<bool> _continuation_state;
  };
  using coro_handle = std::coroutine_handle<promise_type>;

  /// Awaiter logic shared by both `operator co_await` overloads.
  class awaitable_base {
  public:
    coro_handle _coroutine;
    awaitable_base(coro_handle coroutine) noexcept : _coroutine(coroutine) {}
    bool await_ready() const noexcept {
      return !_coroutine || _coroutine.done();
    }
    bool await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept {
      // The bool-returning form avoids stack overflow when a coroutine awaits
      // many synchronously-completing tasks in a loop: start the task, then
      // attach the continuation only if it has not already completed.
      // Returning false resumes the awaiting coroutine immediately. The
      // atomic in the promise arbitrates a race with completion on another
      // thread.
      _coroutine.resume();
      return _coroutine.promise().try_set_continuation(awaitingCoroutine);
    }
  };

  future(coro_handle handle) : _handle(handle) { }
  future(future&& other) noexcept : _handle(std::exchange(other._handle, nullptr)) { }
  /// Takes over `other`'s coroutine, destroying the one currently owned.
  future& operator=(future&& other) noexcept {
    if (this != &other) {
      if (_handle) {
        _handle.destroy();
      }
      _handle = std::exchange(other._handle, nullptr);
    }
    return *this;
  }
  future(const future&) = delete;
  future& operator=(const future&) = delete;
  ~future() {
    // A moved-from future holds a null handle, which must not be destroyed.
    if (_handle) {
      _handle.destroy();
    }
  }
  /// True when the coroutine has finished (or this future owns no coroutine).
  bool done() {
    return !_handle || _handle.done();
  }
  void result() {
    std::cout << "future<void>::result" << std::endl;
    _handle.promise().result();
  }
  auto operator co_await() const & noexcept {
    class awaitable : public awaitable_base {
    public:
      using awaitable_base::awaitable_base;
      decltype(auto) await_resume() {
        if (!this->_coroutine) {
          throw std::logic_error("broken promise");
        }
        return this->_coroutine.promise().result();
      }
    };
    return awaitable(_handle);
  }
  auto operator co_await() const && noexcept {
    class awaitable : public awaitable_base {
    public:
      using awaitable_base::awaitable_base;
      decltype(auto) await_resume() {
        if (!this->_coroutine) {
          throw std::logic_error("broken promise");
        }
        return std::move(this->_coroutine.promise()).result();
      }
    };
    return awaitable(_handle);
  }
private:
  coro_handle _handle;
};
/// Coroutine wrapper that reports completion through a callback.
///
/// The coroutine is lazily started by `start(callback)`. When the coroutine
/// executes `co_yield value` (or reaches final suspension), `callback` is
/// invoked while the coroutine is suspended, and the yielded value is then
/// available through `result()`. The wrapper owns the coroutine frame; a
/// moved-from wrapper holds a null handle.
template<typename T>
class callback_future {
public:
  class promise_type {
  public:
    /// Awaiter that invokes the stored callback once the coroutine is suspended.
    class completion_awaitable {
    public:
      bool await_ready() const noexcept {
        return false;
      }
      void await_suspend(std::coroutine_handle<promise_type> coroutine) const noexcept {
        coroutine.promise()._callback();
      }
      void await_resume() noexcept {}
    };
    promise_type() noexcept { }
    /// Stores `callback` and runs the coroutine until its first suspension.
    void start(std::function<void()> callback) {
      _callback = std::move(callback);
      coro_handle::from_promise(*this).resume();
    }
    auto get_return_object() {
      return coro_handle::from_promise(*this);
    }
    // custom
    auto get_awaitable() {
      return callback_future(get_return_object());
    }
    auto initial_suspend() noexcept {
      return std::suspend_always();
    }
    auto final_suspend() noexcept {
      return completion_awaitable();
    }
    /// Records the address of the yielded object and fires the callback.
    /// Only the address is stored, so the object must outlive the use of
    /// `result()`; a temporary yielded by `co_yield` lives until the
    /// coroutine resumes past that `co_yield`.
    auto yield_value(T&& result) noexcept {
      _result = std::addressof(result);
      return final_suspend();
    }
    void return_void() noexcept {
      // The coroutine should have either yielded a value or thrown
      // an exception in which case it should have bypassed return_void().
    }
    void unhandled_exception() {
      _exception = std::current_exception();
    }
    /// Rethrows the captured exception; otherwise returns the yielded value.
    /// Throws std::logic_error if the coroutine has not yielded anything.
    T&& result() {
      if (_exception) {
        std::rethrow_exception(_exception);
      }
      if (_result == nullptr) {
        throw std::logic_error("callback_future: coroutine has not yielded a value");
      }
      return static_cast<T&&>(*_result);
    }
  private:
    // Null until the coroutine executes `co_yield`.
    std::remove_reference_t<T>* _result = nullptr;
    std::exception_ptr _exception;
    std::function<void()> _callback;
  };
  using coro_handle = std::coroutine_handle<promise_type>;
  callback_future(coro_handle handle) : _handle(handle) { }
  callback_future(callback_future&& other) noexcept : _handle(std::exchange(other._handle, nullptr)) { }
  /// Takes over `other`'s coroutine, destroying the one currently owned.
  callback_future& operator=(callback_future&& other) noexcept {
    if (this != &other) {
      if (_handle) {
        _handle.destroy();
      }
      _handle = std::exchange(other._handle, nullptr);
    }
    return *this;
  }
  callback_future(const callback_future&) = delete;
  callback_future& operator=(const callback_future&) = delete;
  ~callback_future() {
    // A moved-from wrapper holds a null handle, which must not be destroyed.
    if (_handle) {
      _handle.destroy();
    }
  }
  void start(std::function<void()> callback) noexcept {
    _handle.promise().start(std::move(callback));
  }
  decltype(auto) result() {
    return _handle.promise().result();
  }
private:
  coro_handle _handle;
};
/// Specialization of `callback_future` for coroutines that produce no value.
///
/// `start(callback)` runs the coroutine; `callback` is invoked once the
/// coroutine reaches final suspension. The wrapper owns the coroutine frame;
/// a moved-from wrapper holds a null handle.
template<>
class callback_future<void> {
public:
  class promise_type {
  public:
    /// Awaiter that invokes the stored callback once the coroutine is suspended.
    class completion_awaitable {
    public:
      bool await_ready() const noexcept {
        return false;
      }
      void await_suspend(std::coroutine_handle<promise_type> coroutine) const noexcept {
        coroutine.promise()._callback();
      }
      void await_resume() noexcept {}
    };
    promise_type() noexcept { }
    /// Stores `callback` and runs the coroutine until its first suspension.
    void start(std::function<void()> callback) {
      _callback = std::move(callback);
      coro_handle::from_promise(*this).resume();
    }
    auto get_return_object() {
      return coro_handle::from_promise(*this);
    }
    // custom
    auto get_awaitable() {
      return callback_future(get_return_object());
    }
    auto initial_suspend() noexcept {
      return std::suspend_always();
    }
    auto final_suspend() noexcept {
      return completion_awaitable();
    }
    void return_void() noexcept { }
    void unhandled_exception() {
      _exception = std::current_exception();
    }
    /// Rethrows the exception that escaped the coroutine body, if any.
    void result() {
      if (_exception) {
        std::rethrow_exception(_exception);
      }
    }
  private:
    std::exception_ptr _exception;
    std::function<void()> _callback;
  };
  using coro_handle = std::coroutine_handle<promise_type>;
  callback_future(coro_handle handle) : _handle(handle) { }
  callback_future(callback_future&& other) noexcept : _handle(std::exchange(other._handle, nullptr)) { }
  /// Takes over `other`'s coroutine, destroying the one currently owned
  /// (previously that frame was leaked).
  callback_future& operator=(callback_future&& other) noexcept {
    if (this != &other) {
      if (_handle) {
        _handle.destroy();
      }
      _handle = std::exchange(other._handle, nullptr);
    }
    return *this;
  }
  callback_future(const callback_future&) = delete;
  callback_future& operator=(const callback_future&) = delete;
  ~callback_future() {
    if (_handle) {
      _handle.destroy();
    }
  }
  void start(std::function<void()> callback) noexcept {
    std::cout << "callback_future<void>::start" << std::endl;
    _handle.promise().start(std::move(callback));
  }
  decltype(auto) result() {
    return _handle.promise().result();
  }
private:
  coro_handle _handle;
};

namespace detail {
/// Coroutine body for `make_callback_future`. `AWAITABLE` is passed
/// explicitly: for an lvalue argument it is `X&` (the frame keeps a
/// reference), for an rvalue argument it is `X` (the frame owns a moved-in
/// copy, so a temporary awaitable cannot dangle while the coroutine is
/// suspended).
template<typename AWAITABLE>
callback_future<void> make_callback_future_impl(AWAITABLE awaitable) {
  co_await static_cast<AWAITABLE&&>(awaitable);
}
} // namespace detail

/// Wraps `awaitable` in a lazily-started `callback_future<void>` that awaits
/// it. Rvalue awaitables are moved into the coroutine frame; lvalue
/// awaitables are referenced and must outlive the returned future.
template<
  typename AWAITABLE,
  typename RESULT = void,
  std::enable_if_t<std::is_void_v<RESULT>, int> = 0>
callback_future<void> make_callback_future(AWAITABLE&& awaitable) {
  return detail::make_callback_future_impl<AWAITABLE>(std::forward<AWAITABLE>(awaitable));
}

/// Starts awaiting `awaitable` and invokes `callback` when it completes.
///
/// The wrapping task is heap-allocated and deletes itself from inside the
/// completion callback (the coroutine is suspended at that point, so its
/// frame may be destroyed). If the awaited operation failed, `result()`
/// rethrows from a noexcept context, which terminates the program.
template<typename AWAITABLE>
void callback_wait(AWAITABLE&& awaitable, std::function<void()> callback) {
  std::cout << "callback_wait1" << std::endl;
  // Properly construct the object: the previous code move-assigned into raw
  // `operator new` storage that was never constructed.
  auto* task = new callback_future<void>(make_callback_future(std::forward<AWAITABLE>(awaitable)));
  std::cout << "callback_wait2" << std::endl;
  std::cout << "callback_wait3" << std::endl;
  task->start([task, callback = std::move(callback)]() {
    callback();
    task->result();
    // Destroys the frame that owns this lambda; nothing may touch captures after this.
    delete task;
  });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment