-
-
Save lewissbaker/38ba1d8a13e4fb0906559b7aa1c413d3 to your computer and use it in GitHub Desktop.
Prototype std::task<T> implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// -*- C++ -*- | |
//===-------------------------------- task --------------------------------===//
// | |
// The LLVM Compiler Infrastructure | |
// | |
// This file is distributed under the University of Illinois Open Source | |
// License. See LICENSE.TXT for details. | |
// | |
//===----------------------------------------------------------------------===// | |
#ifndef _LIBCPP_EXPERIMENTAL_TASK | |
#define _LIBCPP_EXPERIMENTAL_TASK | |
#include <experimental/__config>
#include <experimental/coroutine>
#include <cstddef>
#include <exception>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) | |
#pragma GCC system_header | |
#endif | |
#ifdef _LIBCPP_HAS_NO_COROUTINES | |
#if defined(_LIBCPP_WARNING) | |
_LIBCPP_WARNING("<experimental/task> cannot be used with this compiler") | |
#else | |
#warning <experimental/task> cannot be used with this compiler | |
#endif | |
#endif | |
_LIBCPP_BEGIN_NAMESPACE_EXPERIMENTAL_COROUTINES | |
////// task<T> | |
struct __task_promise_final_awaitable { | |
_LIBCPP_INLINE_VISIBILITY | |
_LIBCPP_CONSTEXPR bool await_ready() const _NOEXCEPT { return false; } | |
template <typename _TaskPromise> | |
_LIBCPP_INLINE_VISIBILITY coroutine_handle<> | |
await_suspend(coroutine_handle<_TaskPromise> __coro) const _NOEXCEPT { | |
_LIBCPP_ASSERT( | |
__coro.promise().__continuation_, | |
"Coroutine completed without a valid continuation attached."); | |
return __coro.promise().__continuation_; | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
void await_resume() const _NOEXCEPT {} | |
}; | |
class _LIBCPP_TYPE_VIS __task_promise_base { | |
using _DeallocFunc = void(void* __ptr, size_t __size) _NOEXCEPT; | |
static _LIBCPP_CONSTEXPR size_t __pad(size_t __value, | |
size_t __alignment) _NOEXCEPT { | |
return (__value + (__alignment - 1u)) & ~(__alignment - 1u); | |
} | |
static _LIBCPP_CONSTEXPR size_t | |
__get_dealloc_func_offset(size_t __frameSize) _NOEXCEPT { | |
return __pad(__frameSize, alignof(_DeallocFunc*)); | |
} | |
static _LIBCPP_CONSTEXPR size_t | |
__get_padded_frame_size(size_t __frameSize) _NOEXCEPT { | |
return __get_dealloc_func_offset(__frameSize) + sizeof(_DeallocFunc*); | |
} | |
template <typename _Alloc> | |
static _LIBCPP_CONSTEXPR size_t | |
__get_allocator_offset(size_t __frameSize) _NOEXCEPT { | |
return __pad(__get_padded_frame_size(__frameSize), alignof(_Alloc)); | |
} | |
template <typename _Alloc> | |
static _LIBCPP_CONSTEXPR size_t | |
__get_padded_frame_size_with_allocator(size_t __frameSize) _NOEXCEPT { | |
return __get_allocator_offset<_Alloc>(__frameSize) + sizeof(_Alloc); | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
static _DeallocFunc*& __get_dealloc_func(void* __frameStart, | |
size_t __frameSize) _NOEXCEPT { | |
return *reinterpret_cast<_DeallocFunc**>( | |
static_cast<char*>(__frameStart) + | |
__get_dealloc_func_offset(__frameSize)); | |
} | |
template <typename _Alloc> | |
_LIBCPP_INLINE_VISIBILITY static _Alloc& | |
__get_allocator(void* __frameStart, size_t __frameSize) _NOEXCEPT { | |
return *reinterpret_cast<_Alloc*>( | |
static_cast<char*>(__frameStart) + | |
__get_allocator_offset<_Alloc>(__frameSize)); | |
} | |
public: | |
static void* operator new(size_t __size) { | |
// Allocate space for an extra pointer immediately after __size that holds | |
// the type-erased deallocation function. | |
void* __pointer = ::operator new(__get_padded_frame_size(__size)); | |
_DeallocFunc*& __deallocFunc = __get_dealloc_func(__pointer, __size); | |
__deallocFunc = [](void* __pointer, size_t __size) _NOEXCEPT { | |
::operator delete(__pointer, __get_padded_frame_size(__size)); | |
}; | |
return __pointer; | |
} | |
template <typename _Alloc, typename... _Args> | |
static void* operator new(size_t __size, allocator_arg_t, _Alloc& __allocator, | |
_Args&...) { | |
static_assert(is_same<typename _Alloc::value_type, char>::value, | |
"task<T> coroutine custom allocator must have a 'value_type' " | |
"of 'char'."); | |
void* __pointer = __allocator.allocate( | |
__get_padded_frame_size_with_allocator<_Alloc>(__size)); | |
_DeallocFunc*& __deallocFunc = __get_dealloc_func(__pointer, __size); | |
__deallocFunc = [](void* __pointer, size_t __size) _NOEXCEPT { | |
static_assert(is_nothrow_move_constructible<_Alloc>::value, | |
"task<T> coroutine custom allocator requires a noexcept " | |
"move constructor"); | |
_Alloc& __allocatorInFrame = __get_allocator<_Alloc>(__pointer, __size); | |
_Alloc __allocatorOnStack = move(__allocatorInFrame); | |
__allocatorInFrame.~_Alloc(); | |
size_t __paddedSize = | |
__get_padded_frame_size_with_allocator<_Alloc>(__size); | |
// Allocator requirements state that deallocate() must not throw. | |
// See [allocator.requirements] from C++ standard. | |
// We are relying on that here. | |
__allocatorOnStack.deallocate(static_cast<char*>(__pointer), | |
__paddedSize); | |
}; | |
// Copy the allocator into the heap frame. | |
static_assert(is_nothrow_copy_constructible<_Alloc>::value, | |
"task<T> coroutine custom allocator requires a noexcept copy " | |
"constructor"); | |
new (static_cast<void*>(_VSTD::addressof( | |
__get_allocator<_Alloc>(__pointer, __size)))) _Alloc(__allocator); | |
return __pointer; | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
static void operator delete(void* __pointer, size_t __size) { | |
__get_dealloc_func(__pointer, __size)(__pointer, __size); | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
suspend_always initial_suspend() const _NOEXCEPT { return {}; } | |
_LIBCPP_INLINE_VISIBILITY | |
__task_promise_final_awaitable final_suspend() _NOEXCEPT { return {}; } | |
_LIBCPP_INLINE_VISIBILITY | |
void __set_continuation(coroutine_handle<> __continuation) { | |
_LIBCPP_ASSERT(!__continuation_, "task already has a continuation"); | |
__continuation_ = __continuation; | |
} | |
private: | |
friend class __task_promise_final_awaitable; | |
coroutine_handle<> __continuation_; | |
}; | |
template <typename _Tp> | |
class _LIBCPP_TEMPLATE_VIS __task_promise final : public __task_promise_base { | |
using _Handle = coroutine_handle<__task_promise>; | |
public: | |
__task_promise() _NOEXCEPT {} | |
~__task_promise() { | |
switch (__state_) { | |
case _State::__value: | |
__value_.~_Tp(); | |
break; | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
case _State::__exception: | |
__exception_.~exception_ptr(); | |
break; | |
#endif | |
case _State::__no_value: | |
break; | |
}; | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
_Handle get_return_object() _NOEXCEPT { return _Handle::from_promise(*this); } | |
void unhandled_exception() _NOEXCEPT { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
new (static_cast<void*>(&__exception_)) exception_ptr(current_exception()); | |
__state_ = _State::__exception; | |
#else | |
_LIBCPP_ASSERT( | |
false, "task<T> coroutine unexpectedly called unhandled_exception()"); | |
#endif | |
} | |
// Only enable return_value() overload if _Tp is implicitly constructible from _Value | |
template <typename _Value, | |
enable_if_t<is_convertible<_Value, _Tp>::value, int> = 0> | |
void return_value(_Value&& __value) | |
_NOEXCEPT_((is_nothrow_constructible<_Tp, _Value>::value)) { | |
new (static_cast<void*>(_VSTD::addressof(__value_))) | |
_Tp(static_cast<_Value&&>(__value)); | |
// Only set __state_ after successfully constructing the value. | |
// If constructor throws then state will be updated by unhandled_exception(). | |
__state_ = _State::__value; | |
} | |
_Tp& __lvalue_result() { | |
__throw_if_exception(); | |
return __value_; | |
} | |
_Tp&& __rvalue_result() { | |
__throw_if_exception(); | |
return static_cast<_Tp&&>(__value_); | |
} | |
private: | |
void __throw_if_exception() { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
if (__state_ == _State::__exception) { | |
rethrow_exception(__exception_); | |
} | |
#endif | |
} | |
enum class _State { | |
__no_value, | |
__value | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
, | |
__exception | |
#endif | |
}; | |
_State __state_ = _State::__no_value; | |
union { | |
char __empty_; | |
_Tp __value_; | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
exception_ptr __exception_; | |
#endif | |
}; | |
}; | |
template <typename _Tp> | |
class __task_promise<_Tp&> final : public __task_promise_base { | |
using _Ptr = _Tp*; | |
using _Handle = coroutine_handle<__task_promise>; | |
public: | |
__task_promise() _NOEXCEPT {} | |
~__task_promise() { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
if (__has_exception_) { | |
__exception_.~exception_ptr(); | |
} | |
#endif | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
_Handle get_return_object() _NOEXCEPT { return _Handle::from_promise(*this); } | |
void unhandled_exception() _NOEXCEPT { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
new (static_cast<void*>(&__exception_)) exception_ptr(current_exception()); | |
__has_exception_ = true; | |
#else | |
_LIBCPP_ASSERT( | |
false, "task<T> coroutine unexpectedly called unhandled_exception()"); | |
#endif | |
} | |
void return_value(_Tp& __value) _NOEXCEPT { | |
new (static_cast<void*>(&__pointer_)) _Ptr(_VSTD::addressof(__value)); | |
} | |
_Tp& __lvalue_result() { | |
__throw_if_exception(); | |
return *__pointer_; | |
} | |
_Tp& __rvalue_result() { return __lvalue_result(); } | |
private: | |
void __throw_if_exception() { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
if (__has_exception_) { | |
rethrow_exception(__exception_); | |
} | |
#endif | |
} | |
union { | |
char __empty_; | |
_Ptr __pointer_; | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
exception_ptr __exception_; | |
#endif | |
}; | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
bool __has_exception_ = false; | |
#endif | |
}; | |
template <> | |
class __task_promise<void> final : public __task_promise_base { | |
using _Handle = coroutine_handle<__task_promise>; | |
public: | |
_Handle get_return_object() _NOEXCEPT { return _Handle::from_promise(*this); } | |
void return_void() _NOEXCEPT {} | |
void unhandled_exception() _NOEXCEPT { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
__exception_ = current_exception(); | |
#endif | |
} | |
void __lvalue_result() { __throw_if_exception(); } | |
void __rvalue_result() { __throw_if_exception(); } | |
private: | |
void __throw_if_exception() { | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
if (__exception_) { | |
rethrow_exception(__exception_); | |
} | |
#endif | |
} | |
#ifndef _LIBCPP_NO_EXCEPTIONS | |
exception_ptr __exception_; | |
#endif | |
}; | |
template <typename _Tp> | |
class _LIBCPP_TEMPLATE_VIS _LIBCPP_NODISCARD_AFTER_CXX17 task { | |
public: | |
using promise_type = __task_promise<_Tp>; | |
private: | |
using _Handle = coroutine_handle<__task_promise<_Tp> >; | |
class _AwaiterBase { | |
public: | |
_AwaiterBase(_Handle __coro) _NOEXCEPT : __coro_(__coro) {} | |
_LIBCPP_INLINE_VISIBILITY | |
bool await_ready() const { return __coro_.done(); } | |
_LIBCPP_INLINE_VISIBILITY | |
_Handle await_suspend(coroutine_handle<> __continuation) const { | |
__coro_.promise().__set_continuation(__continuation); | |
return __coro_; | |
} | |
protected: | |
_Handle __coro_; | |
}; | |
public: | |
_LIBCPP_INLINE_VISIBILITY | |
task(_Handle __coro) _NOEXCEPT : __coro_(__coro) {} | |
_LIBCPP_INLINE_VISIBILITY | |
task(task&& __other) _NOEXCEPT | |
: __coro_(_VSTD::exchange(__other.__coro_, {})) {} | |
task(const task&) = delete; | |
task& operator=(const task&) = delete; | |
_LIBCPP_INLINE_VISIBILITY | |
~task() { | |
if (__coro_) | |
__coro_.destroy(); | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
task& operator=(task __other) _NOEXCEPT { | |
swap(__other); | |
return *this; | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
void swap(task& __other) _NOEXCEPT { _VSTD::swap(__coro_, __other.__coro_); } | |
_LIBCPP_INLINE_VISIBILITY | |
auto operator co_await() & { | |
class _Awaiter : public _AwaiterBase { | |
public: | |
using _AwaiterBase::_AwaiterBase; | |
_LIBCPP_INLINE_VISIBILITY | |
decltype(auto) await_resume() { | |
return this->__coro_.promise().__lvalue_result(); | |
} | |
}; | |
_LIBCPP_ASSERT(__coro_, | |
"Undefined behaviour to co_await an invalid task<T>"); | |
return _Awaiter{__coro_}; | |
} | |
_LIBCPP_INLINE_VISIBILITY | |
auto operator co_await() && { | |
class _Awaiter : public _AwaiterBase { | |
public: | |
using _AwaiterBase::_AwaiterBase; | |
_LIBCPP_INLINE_VISIBILITY | |
decltype(auto) await_resume() { | |
return this->__coro_.promise().__rvalue_result(); | |
} | |
}; | |
_LIBCPP_ASSERT(__coro_, | |
"Undefined behaviour to co_await an invalid task<T>"); | |
return _Awaiter{__coro_}; | |
} | |
private: | |
_Handle __coro_; | |
}; | |
_LIBCPP_END_NAMESPACE_EXPERIMENTAL_COROUTINES | |
#endif | |
//////////////////////// | |
// Tests for task<T> | |
using namespace std::experimental; | |
task<void> f1() { co_return; } | |
task<void> f2() { co_await f1(); } | |
task<void> f3() { | |
task<void> t = f1(); | |
co_await t; | |
} | |
// Produces a fixed value.
task<int> g1() {
  co_return 123;
}

// Awaits g1() and returns its result offset by 100.
task<int> g2() {
  int base = co_await g1();
  co_return base + 100;
}
// Checks the reference category produced by co_await on task<int>.
task<int> g3() {
  // co_await rvalue returns rvalue reference to result.
  {
    // The awaited task is kept alive in a named variable. The reference
    // returned by co_await points into its coroutine frame, which the task
    // destroys. Awaiting a temporary would destroy the frame at the end of
    // the full-expression and leave `x` dangling.
    task<int> rvalueTask = g1();
    decltype(auto) x = co_await std::move(rvalueTask);
    static_assert(std::is_same<decltype(x), int&&>::value);
    (void)x;
  }
  // co_await lvalue returns lvalue reference to result
  task<int> t = g1();
  decltype(auto) y = co_await t;
  static_assert(std::is_same_v<decltype(y), int&>);
  co_return y * 2;
}
// Non-copyable value type: task<MoveOnly> must compile without ever copying
// the result. Members are declared only (compile-time test).
class MoveOnly {
public:
  MoveOnly();
  MoveOnly(MoveOnly&&) noexcept;
  MoveOnly(const MoveOnly&) = delete;
  ~MoveOnly();

  int get() const;

private:
  void* data_;
};
// Returns a MoveOnly, either from a prvalue or by moving a named local.
task<MoveOnly> h1(bool x) {
  if (!x) {
    co_return MoveOnly{};
  }
  MoveOnly value;
  co_return std::move(value);
}

// Awaits a temporary task<MoveOnly> and moves the result out of it.
task<int> h2() {
  MoveOnly result = co_await h1(true);
  co_return result.get();
}
static int x; | |
static int y; | |
task<int&> h3(bool cond) { co_return cond ? x : y; } | |
task<void> h3consumer() { | |
int& result = co_await h3(true); | |
result = 32; | |
} | |
// custom allocator tests | |
template <typename Allocator> | |
task<void> a1(std::allocator_arg_t, Allocator allocator, bool x) { | |
co_return; | |
} | |
// Minimal stateful char allocator, declared only (compile-time test).
// It satisfies the noexcept copy/move requirements that the task promise's
// allocator-aware operator new static_asserts on.
class my_allocator {
public:
  using value_type = char;

  my_allocator();
  my_allocator(my_allocator&&) noexcept;
  my_allocator(const my_allocator&) noexcept;

  char* allocate(size_t n);
  void deallocate(char* p, size_t n);

private:
  void* state_;
};
// Awaits a task whose frame is allocated with a custom allocator.
// The task is awaited rather than discarded. Discarding it triggers task's
// [[nodiscard]] (when enabled) and destroys the coroutine without ever
// running it.
task<void> a2() {
  my_allocator alloc;
  co_await a1(std::allocator_arg, alloc, true);
}
// Plain function (not a coroutine): hands a1's task, allocated with a custom
// allocator, back to the caller.
task<void> a2a() {
  my_allocator allocator;
  return a1(std::allocator_arg, allocator, true);
}

// Awaits a1 using the standard char allocator.
task<void> a3() {
  co_await a1(std::allocator_arg, std::allocator<char>{}, false);
}

// Plain function: returns a1's task (standard allocator) without awaiting it.
task<void> a3a() {
  return a1(std::allocator_arg, std::allocator<char>{}, false);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment