Created
May 30, 2017 05:49
-
-
Save DrPizza/e218fee053df5a974514c8213a91ffa4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
// Copyright(c) 2017 Peter Bright
//
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
// 3. This notice may not be removed or altered from any source distribution.
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <memory>
#include <new>
#include <optional>
#include <type_traits>
#include <utility>
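// Queue algorithm: Chase & Lev, "Dynamic Circular Work-Stealing Deque" (SPAA 2005),
//   http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.170.1097&rep=rep1&type=pdf
// Atomic memory orderings: Le, Pop, Cohen & Zappa Nardelli, "Correct and Efficient
//   Work-Stealing for Weak Memory Models" (PPoPP 2013),
//   http://www.di.ens.fr/~zappa/readings/ppopp13.pdf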
/// Lock-free, dynamically growing work-stealing deque (Chase-Lev design) using
/// explicit C++11 memory orderings.
///
/// Threading contract (inherent to the algorithm): push() and pop() operate on
/// the "bottom" end and must only be called by the single thread that owns the
/// queue; steal() operates on the "top" end and may be called concurrently from
/// any number of other threads.
///
/// @tparam T element type; must be DefaultConstructible and TriviallyCopyable
///           because elements are stored in std::atomic<T> slots.
template<typename T>
struct work_stealing_queue {
	static_assert(std::is_default_constructible_v<T>, "T must be DefaultConstructible");
	static_assert(std::is_trivially_copyable_v<T>, "T must be TriviallyCopyable");

	/// Creates an empty queue with room for 2^5 elements before the first growth.
	work_stealing_queue() : top(0), bottom(0), arr(circular_array::make_circular_array(initial_log_size)) {
	}

	// The queue owns a chain of raw buffers and atomics; it is neither copyable nor movable.
	work_stealing_queue(const work_stealing_queue&) = delete;
	work_stealing_queue& operator=(const work_stealing_queue&) = delete;

	/// Releases the current buffer and, through its `previous` chain, every
	/// buffer that was retired by an earlier growth.
	~work_stealing_queue() {
		circular_array::destroy_circular_array(arr.load(std::memory_order_relaxed));
	}

	/// Appends `elem` at the bottom of the deque, growing the buffer if it is full.
	/// Owner thread only.
	void push(const T& elem) {
		const std::size_t b = bottom.load(std::memory_order_relaxed);
		const std::size_t t = top.load(std::memory_order_acquire);
		circular_array* a = arr.load(std::memory_order_relaxed);
		if(b - t > a->size() - 1) {
			a = a->grow(t, b);
			// Release, so that a thief which acquire-loads the new pointer also
			// observes the (non-atomic) construction of the new slots and the
			// elements copied into them.
			arr.store(a, std::memory_order_release);
		}
		a->put(b, elem);
		// Make the element visible before the new bottom is published.
		std::atomic_thread_fence(std::memory_order_release);
		bottom.store(b + 1, std::memory_order_relaxed);
	}

	/// Removes and returns the most recently pushed element.
	/// Owner thread only.
	///
	/// @return the element, or std::nullopt if the deque is empty or the last
	///         remaining element was taken by a concurrent steal().
	std::optional<T> pop() {
		const std::size_t b = bottom.load(std::memory_order_relaxed) - 1;
		circular_array* const a = arr.load(std::memory_order_relaxed);
		// Reserve the bottom slot first, then look at top.
		bottom.store(b, std::memory_order_relaxed);
		std::atomic_thread_fence(std::memory_order_seq_cst);
		std::size_t t = top.load(std::memory_order_relaxed);

		// Signed distance: on an empty deque b is one *below* t, which for
		// unsigned indices wraps around (e.g. b == SIZE_MAX when both indices
		// are still 0) and must not be mistaken for a huge non-empty range.
		const signed_index last = signed_distance(t, b);
		if(last < 0) {
			// Empty: undo the reservation.
			bottom.store(b + 1, std::memory_order_relaxed);
			return std::nullopt;
		}

		std::optional<T> x = a->get(b);
		if(last == 0) {
			// Exactly one element left: race against thieves for it.
			if(!top.compare_exchange_strong(t, t + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
				x = std::nullopt;
			}
			// Won or lost, the deque is now empty; restore bottom to match top.
			bottom.store(b + 1, std::memory_order_relaxed);
		}
		return x;
	}

	/// Attempts to take the oldest element from the top of the deque.
	/// May be called from any thread.
	///
	/// @return {true, value}    an element was stolen;
	///         {true, nullopt}  the deque was observed empty;
	///         {false, nullopt} another thread won the race for the element
	///                          (the caller may retry).
	std::pair<bool, std::optional<T>> steal() {
		std::size_t t = top.load(std::memory_order_acquire);
		std::atomic_thread_fence(std::memory_order_seq_cst);
		const std::size_t b = bottom.load(std::memory_order_acquire);
		// Signed distance for the same reason as in pop(): while the owner is
		// inside pop() on an empty deque, bottom is transiently top - 1.
		if(signed_distance(t, b) <= 0) {
			return {true, std::nullopt};
		}
		circular_array* const a = arr.load(std::memory_order_acquire);
		// Read the candidate before the CAS; it only counts if the CAS succeeds.
		const T candidate = a->get(t);
		if(!top.compare_exchange_strong(t, t + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
			return {false, std::nullopt};
		}
		return {true, std::optional<T>(candidate)};
	}

private:
	using atomic_value = std::atomic<T>;
	using signed_index = std::make_signed_t<std::size_t>;

	// log2 of the initial capacity (2^5 == 32 slots).
	static constexpr std::size_t initial_log_size = 5;

	/// Wrap-around-safe signed value of `to - from` for two deque indices.
	static signed_index signed_distance(std::size_t from, std::size_t to) noexcept {
		return static_cast<signed_index>(to - from);
	}

	struct circular_array;

	/// Fixed-size header of a circular_array: its capacity (as a power of two)
	/// and ownership of the smaller array it replaced, if any.
	struct circular_array_data {
		circular_array_data(std::size_t log_size_, void(*deleter)(circular_array*)) : log_size(log_size_), previous(nullptr, deleter) {
		}

		std::size_t log_size;
		using chunk_ptr = std::unique_ptr<circular_array, void(*)(circular_array*)>;
		// Retired predecessor; kept alive so thieves holding a stale array
		// pointer never read freed memory.
		chunk_ptr previous;
	};

	/// Header immediately followed, in the same allocation, by 2^log_size
	/// std::atomic<T> slots. Must be created and destroyed only through
	/// make_circular_array() / destroy_circular_array().
	///
	/// The alignas specifiers make sizeof(circular_array) a multiple of
	/// alignof(atomic_value), so the slot array starting at `this + 1` is
	/// correctly aligned.
	struct
		alignas(atomic_value)
		alignas(circular_array_data)
		circular_array : circular_array_data {
		explicit circular_array(std::size_t log_size_) : circular_array_data(log_size_, &destroy_circular_array) {
			atomic_value* const slots = elements();
			const std::size_t count = size();
			for(std::size_t i = 0; i < count; ++i) {
				::new (static_cast<void*>(slots + i)) atomic_value(T{});
			}
		}

		~circular_array() {
			std::destroy_n(elements(), size());
		}

		/// Number of slots (always a power of two).
		std::size_t size() const {
			return std::size_t{1} << this->log_size;
		}

		/// Stores `v` in the slot for logical index `i` (taken modulo size()).
		void put(std::size_t i, const T& v) {
			elements()[i & (size() - 1)].store(v, std::memory_order_relaxed);
		}

		/// Loads the value in the slot for logical index `i` (taken modulo size()).
		T get(std::size_t i) {
			return elements()[i & (size() - 1)].load(std::memory_order_relaxed);
		}

		/// Allocates an array of twice the capacity, copies logical indices
		/// [t, b) into it and hands ownership of `this` to the new array.
		///
		/// @return the new array; the caller is responsible for publishing it.
		circular_array* grow(std::size_t t, std::size_t b) {
			circular_array* const bigger = make_circular_array(this->log_size + 1);
			bigger->previous.reset(this);
			for(std::size_t i = t; i != b; ++i) {
				bigger->put(i, get(i));
			}
			return bigger;
		}

		/// Bytes needed for the header plus 2^log_size slots.
		constexpr static std::size_t get_allocation_size(std::size_t log_size) {
			return sizeof(circular_array) + ((std::size_t{1} << log_size) * sizeof(atomic_value));
		}

		/// Allocates and constructs an array with 2^log_size default-valued slots.
		static circular_array* make_circular_array(std::size_t log_size) {
			// Aligned allocation: std::atomic<T> may be stricter than what the
			// plain ::operator new guarantees.
			void* const raw_memory = ::operator new(get_allocation_size(log_size), std::align_val_t{alignof(circular_array)});
			return ::new (raw_memory) circular_array(log_size);
		}

		/// Destroys an array created by make_circular_array() (and, via its
		/// `previous` member, every array it superseded). Null is ignored.
		static void destroy_circular_array(circular_array* c) {
			if(c == nullptr) {
				return;
			}
			c->~circular_array();
			::operator delete(static_cast<void*>(c), std::align_val_t{alignof(circular_array)});
		}

		/// Start of the slot storage that trails this header in the allocation.
		atomic_value* elements() {
			return reinterpret_cast<atomic_value*>(this + 1);
		}
	};

	std::atomic<std::size_t> top;
	std::atomic<std::size_t> bottom;
	std::atomic<circular_array*> arr;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment