Skip to content

Instantly share code, notes, and snippets.

@lewissbaker lewissbaker/task.hpp Secret
Last active May 1, 2018

Embed
What would you like to do?
Prototype std::task<T> implementation
// -*- C++ -*-
//===------------------------------- task --------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#ifndef _LIBCPP_EXPERIMENTAL_TASK
#define _LIBCPP_EXPERIMENTAL_TASK
#include <experimental/__config>
#include <experimental/coroutine>
#include <exception>
#include <utility>
#include <type_traits>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
#pragma GCC system_header
#endif
#ifdef _LIBCPP_HAS_NO_COROUTINES
#if defined(_LIBCPP_WARNING)
_LIBCPP_WARNING("<experimental/task> cannot be used with this compiler")
#else
#warning <experimental/task> cannot be used with this compiler
#endif
#endif
_LIBCPP_BEGIN_NAMESPACE_EXPERIMENTAL_COROUTINES
////// task<T>
// Awaitable returned by __task_promise_base::final_suspend().
//
// It never completes synchronously: the finished task coroutine is always
// suspended, and execution is handed (via the coroutine_handle returned from
// await_suspend) to the continuation stored in the promise.
struct __task_promise_final_awaitable {
  _LIBCPP_INLINE_VISIBILITY
  _LIBCPP_CONSTEXPR bool await_ready() const _NOEXCEPT { return false; }

  // Returns the coroutine to resume next: the one that co_awaited this task.
  template <typename _TaskPromise>
  _LIBCPP_INLINE_VISIBILITY coroutine_handle<>
  await_suspend(coroutine_handle<_TaskPromise> __finished) const _NOEXCEPT {
    coroutine_handle<> __next = __finished.promise().__continuation_;
    _LIBCPP_ASSERT(
        __next,
        "Coroutine completed without a valid continuation attached.");
    return __next;
  }

  // Nothing to produce; the coroutine is never resumed past final_suspend.
  _LIBCPP_INLINE_VISIBILITY
  void await_resume() const _NOEXCEPT {}
};
// Base class shared by every __task_promise<T> specialization.
//
// It provides:
//  * operator new / operator delete for the coroutine frame. Every frame is
//    over-allocated: a type-erased deallocation function pointer is stored
//    just past the (padded) compiler-requested frame size, and, for the
//    allocator_arg_t overload, a copy of the allocator is stored after that.
//    operator delete reads the pointer back, so a frame is always released
//    the same way it was allocated.
//  * initial_suspend()/final_suspend(): a task is created suspended, and when
//    it finishes __task_promise_final_awaitable transfers control to the
//    stored continuation.
//  * Storage for that continuation (set once via __set_continuation()).
//
// Frame layout produced by operator new (offsets come from the helpers below):
//   [ frame (__size bytes) | pad | _DeallocFunc* | pad | _Alloc (optional) ]
class _LIBCPP_TYPE_VIS __task_promise_base {
  // Signature of the type-erased function that releases a frame, given its
  // start address and the compiler-requested frame size.
  using _DeallocFunc = void(void* __ptr, size_t __size) _NOEXCEPT;

  // Rounds __value up to a multiple of __alignment (a power of two, as every
  // alignof value is).
  static _LIBCPP_CONSTEXPR size_t __pad(size_t __value,
                                        size_t __alignment) _NOEXCEPT {
    return (__value + (__alignment - 1u)) & ~(__alignment - 1u);
  }

  // Offset of the _DeallocFunc* slot that follows a frame of __frameSize bytes.
  static _LIBCPP_CONSTEXPR size_t
  __get_dealloc_func_offset(size_t __frameSize) _NOEXCEPT {
    return __pad(__frameSize, alignof(_DeallocFunc*));
  }

  // Bytes needed for the frame plus the _DeallocFunc* slot.
  static _LIBCPP_CONSTEXPR size_t
  __get_padded_frame_size(size_t __frameSize) _NOEXCEPT {
    return __get_dealloc_func_offset(__frameSize) + sizeof(_DeallocFunc*);
  }

  // Offset of the _Alloc copy that follows the _DeallocFunc* slot.
  template <typename _Alloc>
  static _LIBCPP_CONSTEXPR size_t
  __get_allocator_offset(size_t __frameSize) _NOEXCEPT {
    return __pad(__get_padded_frame_size(__frameSize), alignof(_Alloc));
  }

  // Bytes needed for the frame, the _DeallocFunc* slot and an _Alloc copy.
  template <typename _Alloc>
  static _LIBCPP_CONSTEXPR size_t
  __get_padded_frame_size_with_allocator(size_t __frameSize) _NOEXCEPT {
    return __get_allocator_offset<_Alloc>(__frameSize) + sizeof(_Alloc);
  }

  // The _DeallocFunc* slot of the frame starting at __frameStart.
  _LIBCPP_INLINE_VISIBILITY
  static _DeallocFunc*& __get_dealloc_func(void* __frameStart,
                                           size_t __frameSize) _NOEXCEPT {
    return *reinterpret_cast<_DeallocFunc**>(
        static_cast<char*>(__frameStart) +
        __get_dealloc_func_offset(__frameSize));
  }

  // The _Alloc copy stored in the frame starting at __frameStart.
  template <typename _Alloc>
  _LIBCPP_INLINE_VISIBILITY static _Alloc&
  __get_allocator(void* __frameStart, size_t __frameSize) _NOEXCEPT {
    return *reinterpret_cast<_Alloc*>(
        static_cast<char*>(__frameStart) +
        __get_allocator_offset<_Alloc>(__frameSize));
  }

public:
  // Default frame allocation, through the global ::operator new.
  static void* operator new(size_t __size) {
    // Allocate space for an extra pointer immediately after __size that holds
    // the type-erased deallocation function.
    void* __pointer = ::operator new(__get_padded_frame_size(__size));
    _DeallocFunc*& __deallocFunc = __get_dealloc_func(__pointer, __size);
    __deallocFunc = [](void* __pointer, size_t __size) _NOEXCEPT {
      ::operator delete(__pointer, __get_padded_frame_size(__size));
    };
    return __pointer;
  }

  // Frame allocation used when the coroutine's arguments are
  // (allocator_arg_t, _Alloc, ...). The frame is obtained from __allocator and
  // a copy of the allocator is kept in the frame so it can free it later.
  template <typename _Alloc, typename... _Args>
  static void* operator new(size_t __size, allocator_arg_t, _Alloc& __allocator,
                            _Args&...) {
    static_assert(is_same<typename _Alloc::value_type, char>::value,
                  "task<T> coroutine custom allocator must have a 'value_type' "
                  "of 'char'.");
    void* __pointer = __allocator.allocate(
        __get_padded_frame_size_with_allocator<_Alloc>(__size));
    _DeallocFunc*& __deallocFunc = __get_dealloc_func(__pointer, __size);
    __deallocFunc = [](void* __pointer, size_t __size) _NOEXCEPT {
      static_assert(is_nothrow_move_constructible<_Alloc>::value,
                    "task<T> coroutine custom allocator requires a noexcept "
                    "move constructor");
      // Move the allocator out of the frame before freeing the frame's memory,
      // since the in-frame copy lives inside the block being deallocated.
      _Alloc& __allocatorInFrame = __get_allocator<_Alloc>(__pointer, __size);
      // Qualified call: an unqualified 'move' would also be looked up in the
      // allocator's namespace (ADL) and could select an unrelated function.
      _Alloc __allocatorOnStack = _VSTD::move(__allocatorInFrame);
      __allocatorInFrame.~_Alloc();
      size_t __paddedSize =
          __get_padded_frame_size_with_allocator<_Alloc>(__size);
      // Allocator requirements state that deallocate() must not throw.
      // See [allocator.requirements] from C++ standard.
      // We are relying on that here.
      __allocatorOnStack.deallocate(static_cast<char*>(__pointer),
                                    __paddedSize);
    };
    // Copy the allocator into the heap frame.
    static_assert(is_nothrow_copy_constructible<_Alloc>::value,
                  "task<T> coroutine custom allocator requires a noexcept copy "
                  "constructor");
    new (static_cast<void*>(_VSTD::addressof(
        __get_allocator<_Alloc>(__pointer, __size)))) _Alloc(__allocator);
    return __pointer;
  }

  // Releases a frame with the deallocation function recorded by whichever
  // operator new overload allocated it.
  _LIBCPP_INLINE_VISIBILITY
  static void operator delete(void* __pointer, size_t __size) {
    __get_dealloc_func(__pointer, __size)(__pointer, __size);
  }

  // Tasks are lazy: the body does not start until the task is awaited.
  _LIBCPP_INLINE_VISIBILITY
  suspend_always initial_suspend() const _NOEXCEPT { return {}; }

  // On completion, transfer control to the stored continuation.
  _LIBCPP_INLINE_VISIBILITY
  __task_promise_final_awaitable final_suspend() _NOEXCEPT { return {}; }

  // Records the coroutine to resume when this task completes; may only be
  // set once.
  _LIBCPP_INLINE_VISIBILITY
  void __set_continuation(coroutine_handle<> __continuation) {
    _LIBCPP_ASSERT(!__continuation_, "task already has a continuation");
    __continuation_ = __continuation;
  }

private:
  // Same class-key as the declaration of __task_promise_final_awaitable.
  friend struct __task_promise_final_awaitable;
  coroutine_handle<> __continuation_;
};
// Promise type for task<_Tp> where _Tp is neither a reference nor void.
//
// The coroutine's outcome lives in an anonymous union whose active member is
// tracked by __state_: nothing yet (__no_value), a constructed _Tp (__value),
// or, when exceptions are enabled, a captured exception_ptr (__exception).
template <typename _Tp>
class _LIBCPP_TEMPLATE_VIS __task_promise final : public __task_promise_base {
  using _Handle = coroutine_handle<__task_promise>;

public:
  // Starts with no active union member (__state_ is __no_value).
  __task_promise() _NOEXCEPT {}

  // Destroys whichever union member is active.
  ~__task_promise() { __reset(); }

  _LIBCPP_INLINE_VISIBILITY
  _Handle get_return_object() _NOEXCEPT { return _Handle::from_promise(*this); }

  // Captures the in-flight exception so the awaiter can rethrow it.
  void unhandled_exception() _NOEXCEPT {
#ifndef _LIBCPP_NO_EXCEPTIONS
    // An exception may escape after return_value() has already constructed the
    // result -- e.g. thrown by the destructor of a local variable, which runs
    // between co_return and final_suspend(). Destroy any stored value first so
    // the exception_ptr is never constructed on top of a live _Tp.
    __reset();
    new (static_cast<void*>(&__exception_)) exception_ptr(current_exception());
    __state_ = _State::__exception;
#else
    _LIBCPP_ASSERT(
        false, "task<T> coroutine unexpectedly called unhandled_exception()");
#endif
  }

  // Stores the operand of `co_return expr;`.
  // Only enabled if _Tp is implicitly constructible from _Value. The default
  // template argument lets `co_return {...};` work: a braced-init-list cannot
  // be deduced, so _Value falls back to _Tp and the list initializes a _Tp.
  template <typename _Value = _Tp,
            enable_if_t<is_convertible<_Value, _Tp>::value, int> = 0>
  void return_value(_Value&& __value)
      _NOEXCEPT_((is_nothrow_constructible<_Tp, _Value>::value)) {
    new (static_cast<void*>(_VSTD::addressof(__value_)))
        _Tp(static_cast<_Value&&>(__value));
    // Only set __state_ after successfully constructing the value.
    // If constructor throws then state will be updated by unhandled_exception().
    __state_ = _State::__value;
  }

  // Result for `co_await` on an lvalue task: rethrows a captured exception,
  // otherwise returns the stored value by lvalue reference.
  _Tp& __lvalue_result() {
    __throw_if_exception();
    _LIBCPP_ASSERT(__state_ == _State::__value,
                   "task<T> result requested but the coroutine did not "
                   "produce a value");
    return __value_;
  }

  // Result for `co_await` on an rvalue task: as above, but by rvalue reference
  // so the caller can move the value out.
  _Tp&& __rvalue_result() {
    __throw_if_exception();
    _LIBCPP_ASSERT(__state_ == _State::__value,
                   "task<T> result requested but the coroutine did not "
                   "produce a value");
    return static_cast<_Tp&&>(__value_);
  }

private:
  void __throw_if_exception() {
#ifndef _LIBCPP_NO_EXCEPTIONS
    if (__state_ == _State::__exception) {
      rethrow_exception(__exception_);
    }
#endif
  }

  // Destroys the active union member (if any) and marks the union empty.
  void __reset() {
    switch (__state_) {
    case _State::__value:
      __value_.~_Tp();
      break;
#ifndef _LIBCPP_NO_EXCEPTIONS
    case _State::__exception:
      __exception_.~exception_ptr();
      break;
#endif
    case _State::__no_value:
      break;
    }
    __state_ = _State::__no_value;
  }

  enum class _State {
    __no_value,
    __value
#ifndef _LIBCPP_NO_EXCEPTIONS
    ,
    __exception
#endif
  };

  _State __state_ = _State::__no_value;
  union {
    char __empty_;
    _Tp __value_;
#ifndef _LIBCPP_NO_EXCEPTIONS
    exception_ptr __exception_;
#endif
  };
};
template <typename _Tp>
class __task_promise<_Tp&> final : public __task_promise_base {
using _Ptr = _Tp*;
using _Handle = coroutine_handle<__task_promise>;
public:
__task_promise() _NOEXCEPT {}
~__task_promise() {
#ifndef _LIBCPP_NO_EXCEPTIONS
if (__has_exception_) {
__exception_.~exception_ptr();
}
#endif
}
_LIBCPP_INLINE_VISIBILITY
_Handle get_return_object() _NOEXCEPT { return _Handle::from_promise(*this); }
void unhandled_exception() _NOEXCEPT {
#ifndef _LIBCPP_NO_EXCEPTIONS
new (static_cast<void*>(&__exception_)) exception_ptr(current_exception());
__has_exception_ = true;
#else
_LIBCPP_ASSERT(
false, "task<T> coroutine unexpectedly called unhandled_exception()");
#endif
}
void return_value(_Tp& __value) _NOEXCEPT {
new (static_cast<void*>(&__pointer_)) _Ptr(_VSTD::addressof(__value));
}
_Tp& __lvalue_result() {
__throw_if_exception();
return *__pointer_;
}
_Tp& __rvalue_result() { return __lvalue_result(); }
private:
void __throw_if_exception() {
#ifndef _LIBCPP_NO_EXCEPTIONS
if (__has_exception_) {
rethrow_exception(__exception_);
}
#endif
}
union {
char __empty_;
_Ptr __pointer_;
#ifndef _LIBCPP_NO_EXCEPTIONS
exception_ptr __exception_;
#endif
};
#ifndef _LIBCPP_NO_EXCEPTIONS
bool __has_exception_ = false;
#endif
};
template <>
class __task_promise<void> final : public __task_promise_base {
using _Handle = coroutine_handle<__task_promise>;
public:
_Handle get_return_object() _NOEXCEPT { return _Handle::from_promise(*this); }
void return_void() _NOEXCEPT {}
void unhandled_exception() _NOEXCEPT {
#ifndef _LIBCPP_NO_EXCEPTIONS
__exception_ = current_exception();
#endif
}
void __lvalue_result() { __throw_if_exception(); }
void __rvalue_result() { __throw_if_exception(); }
private:
void __throw_if_exception() {
#ifndef _LIBCPP_NO_EXCEPTIONS
if (__exception_) {
rethrow_exception(__exception_);
}
#endif
}
#ifndef _LIBCPP_NO_EXCEPTIONS
exception_ptr __exception_;
#endif
};
// task<_Tp>: a lazily-started, move-only coroutine result type.
//
// A task owns its coroutine frame and destroys it in ~task(). The coroutine
// does not run until the task is co_awaited. Awaiting it does two things:
//  * registers the awaiting coroutine as the continuation;
//  * transfers control to the task's coroutine, which resumes that
//    continuation from final_suspend().
// Awaiting an lvalue task produces the promise's __lvalue_result(); awaiting
// an rvalue task produces its __rvalue_result().
template <typename _Tp>
class _LIBCPP_TEMPLATE_VIS _LIBCPP_NODISCARD_AFTER_CXX17 task {
public:
  using promise_type = __task_promise<_Tp>;

private:
  using _Handle = coroutine_handle<__task_promise<_Tp> >;

  // Shared part of the lvalue and rvalue awaiters returned by
  // operator co_await; the derived awaiters add await_resume().
  class _AwaiterBase {
  public:
    _LIBCPP_INLINE_VISIBILITY
    _AwaiterBase(_Handle __coro) _NOEXCEPT : __coro_(__coro) {}

    // A task that has already completed is consumed without suspending.
    _LIBCPP_INLINE_VISIBILITY
    bool await_ready() const { return __coro_.done(); }

    // Records the awaiting coroutine as the continuation, then symmetrically
    // transfers to the task's coroutine.
    _LIBCPP_INLINE_VISIBILITY
    _Handle await_suspend(coroutine_handle<> __continuation) const {
      __coro_.promise().__set_continuation(__continuation);
      return __coro_;
    }

  protected:
    _Handle __coro_;
  };

public:
  // Takes ownership of __coro. Implicit because the promise's
  // get_return_object() returns a _Handle that must convert to task.
  _LIBCPP_INLINE_VISIBILITY
  task(_Handle __coro) _NOEXCEPT : __coro_(__coro) {}

  // Steals __other's coroutine, leaving __other empty.
  _LIBCPP_INLINE_VISIBILITY
  task(task&& __other) _NOEXCEPT
      : __coro_(_VSTD::exchange(__other.__coro_, {})) {}

  task(const task&) = delete;
  task& operator=(const task&) = delete;

  // Destroys the owned coroutine frame, if any.
  _LIBCPP_INLINE_VISIBILITY
  ~task() {
    if (__coro_)
      __coro_.destroy();
  }

  // Move assignment: releases the currently owned coroutine (if any) and
  // takes ownership of __other's.
  // Declared with task&& rather than by value. A by-value operator=(task) is
  // itself a copy-assignment operator, so every assignment from an rvalue
  // would be ambiguous with the deleted operator=(const task&).
  // Self-move-assignment leaves *this unchanged.
  _LIBCPP_INLINE_VISIBILITY
  task& operator=(task&& __other) _NOEXCEPT {
    task(_VSTD::move(__other)).swap(*this);
    return *this;
  }

  _LIBCPP_INLINE_VISIBILITY
  void swap(task& __other) _NOEXCEPT { _VSTD::swap(__coro_, __other.__coro_); }

  // Awaiting an lvalue task: the result is returned as an lvalue and stays
  // owned by the task.
  _LIBCPP_INLINE_VISIBILITY
  auto operator co_await() & {
    class _Awaiter : public _AwaiterBase {
    public:
      using _AwaiterBase::_AwaiterBase;
      _LIBCPP_INLINE_VISIBILITY
      decltype(auto) await_resume() {
        return this->__coro_.promise().__lvalue_result();
      }
    };
    _LIBCPP_ASSERT(__coro_,
                   "Undefined behaviour to co_await an invalid task<T>");
    return _Awaiter{__coro_};
  }

  // Awaiting an rvalue task: the result is returned as an rvalue so it can be
  // moved out before the temporary task is destroyed.
  _LIBCPP_INLINE_VISIBILITY
  auto operator co_await() && {
    class _Awaiter : public _AwaiterBase {
    public:
      using _AwaiterBase::_AwaiterBase;
      _LIBCPP_INLINE_VISIBILITY
      decltype(auto) await_resume() {
        return this->__coro_.promise().__rvalue_result();
      }
    };
    _LIBCPP_ASSERT(__coro_,
                   "Undefined behaviour to co_await an invalid task<T>");
    return _Awaiter{__coro_};
  }

private:
  _Handle __coro_;
};
_LIBCPP_END_NAMESPACE_EXPERIMENTAL_COROUTINES
#endif
////////////////////////
// Tests for task<T>
using namespace std::experimental;
// Smallest task<void> coroutine: finishes immediately without a value.
task<void> f1() {
  co_return;
}
// Awaits a temporary task<void>, exercising the rvalue operator co_await.
task<void> f2() {
  co_await f1();
  co_return;
}
// Awaits a task<void> held in a named variable, exercising the
// lvalue-qualified operator co_await.
task<void> f3() {
  auto pending = f1();
  co_await pending;
  co_return;
}
// Produces a task<int> holding 123.
task<int> g1() {
  const int answer = 123;
  co_return answer;
}
// Awaits g1() and returns its result plus 100.
task<int> g2() {
  const int base = co_await g1();
  co_return base + 100;
}
// Compile-time checks on the reference category produced by co_await on a
// task<int>, followed by a real use of the lvalue form.
task<int> g3() {
  // co_await on an rvalue task<int> yields int&& referring to the result
  // stored in the awaited coroutine's promise. The temporary task (and with it
  // that promise) is destroyed at the end of the full-expression, so x would
  // dangle. co_await may not appear in an unevaluated operand such as
  // decltype(), so the check is wrapped in `if (false)`: it is still
  // type-checked, but never executed.
  if (false) {
    decltype(auto) x = co_await g1();
    static_assert(std::is_same_v<decltype(x), int&&>);
  }
  // co_await on an lvalue task<int> yields int& into the result owned by t.
  task<int> t = g1();
  decltype(auto) y = co_await t;
  static_assert(std::is_same_v<decltype(y), int&>);
  co_return y * 2;
}
// Test-only type that can be moved but not copied; used to check that a
// task<T> can produce and hand back a move-only result (see h1/h2).
// Its members are only declared in this file, so the tests using it appear to
// be compile-only -- TODO confirm.
class MoveOnly {
public:
MoveOnly();
MoveOnly(const MoveOnly&) = delete; // copying is disallowed
MoveOnly(MoveOnly&&) noexcept; // moving is allowed and declared non-throwing
~MoveOnly();
int get() const;
private:
void* _data; // opaque state; its meaning is not visible in this file
};
// Returns a MoveOnly from either of two co_return statements: a moved-from
// named local when x is true, otherwise a prvalue.
task<MoveOnly> h1(bool x) {
  if (!x) {
    co_return MoveOnly{};
  }
  MoveOnly value;
  co_return std::move(value);
}
// Moves the MoveOnly result out of an rvalue task and returns its get() value.
task<int> h2() {
  MoveOnly result = co_await h1(true);
  co_return result.get();
}
// Internal-linkage ints with static storage duration; h3() returns a reference
// to one of them, so the referenced object outlives any task<int&> result.
static int x;
static int y;
// Returns a reference to x when cond is true, otherwise to y.
task<int&> h3(bool cond) {
  int& chosen = cond ? x : y;
  co_return chosen;
}
// Awaits a task<int&> and writes through the returned reference.
task<void> h3consumer() {
  int& target = co_await h3(true);
  target = 32;
  co_return;
}
// custom allocator tests
// Coroutine whose leading (allocator_arg_t, Allocator) parameters make the
// promise's allocator-aware operator new the one considered for its frame.
// The body itself uses neither argument.
template <typename Allocator>
task<void> a1(std::allocator_arg_t, Allocator allocator, bool x) {
  static_cast<void>(allocator);
  static_cast<void>(x);
  co_return;
}
// Minimal custom allocator for exercising the allocator_arg_t overload of
// __task_promise_base::operator new. It is shaped to satisfy the
// static_asserts there:
//  * value_type is char;
//  * the copy and move constructors are noexcept.
// Members are only declared in this file.
class my_allocator {
public:
using value_type = char;
my_allocator();
my_allocator(const my_allocator&) noexcept;
my_allocator(my_allocator&&) noexcept;
char* allocate(size_t n);
// NOTE(review): the frame-deallocation path calls this from a noexcept
// lambda, so a throw here would terminate; consider declaring it noexcept.
void deallocate(char* p, size_t n);
private:
void* state;
};
// Creates a task whose frame is allocated through my_allocator and discards it
// without ever awaiting it (so a1's body never runs). The discard is explicit
// because task is marked nodiscard in newer language modes.
task<void> a2() {
  my_allocator alloc;
  static_cast<void>(a1(std::allocator_arg, alloc, true));
  co_return;
}
// Non-coroutine that forwards a task created with a custom allocator. The
// allocator is copied into a1's frame, so the local going out of scope is fine.
task<void> a2a() {
  my_allocator alloc;
  task<void> inner = a1(std::allocator_arg, alloc, true);
  return inner;
}
// Awaits a task whose frame is allocated through std::allocator<char>.
task<void> a3() {
  std::allocator<char> alloc{};
  co_await a1(std::allocator_arg, alloc, false);
}
// Non-coroutine returning a task allocated through std::allocator<char>.
task<void> a3a() {
  using CharAllocator = std::allocator<char>;
  return a1(std::allocator_arg, CharAllocator{}, false);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.