Skip to content

Instantly share code, notes, and snippets.

@jcdickinson
Last active March 26, 2023 04:15
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save jcdickinson/ec2b93f78afc4c72ae74 to your computer and use it in GitHub Desktop.
Save jcdickinson/ec2b93f78afc4c72ae74 to your computer and use it in GitHub Desktop.
C++ Lock-Free Work Stealing Stack
#pragma once
#include <atomic>
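// A lock-free, fixed-capacity work-stealing stack.
// Push  = single producer: only the owning thread may call it.
// Pop   = single consumer: must run on the same thread as Push; takes the
//         most recently pushed item (LIFO end).
// Steal = multiple consumers: any thread may call it; takes the oldest
//         item (FIFO end).
// No method blocks. Push returns false when the buffer is full; Pop and Steal
// return false when they obtained no item (the stack was empty, or another
// thread won the race for the item). Re-issue the request (e.g. spin-wait)
// if that happens.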
template<class T, size_t capacity = 131072>
class WorkStealingStack {
    static_assert(capacity > 0, "WorkStealingStack capacity must be non-zero");

public:
    // Both indices start at 1 (not 0) so that Pop's speculative decrement of
    // _bottom on an empty stack can never wrap around below zero.
    WorkStealingStack() : _top(1), _bottom(1) {}

    WorkStealingStack(const WorkStealingStack&) = delete;
    WorkStealingStack& operator=(const WorkStealingStack&) = delete;

    // Owner thread only. Appends `item` at the bottom (owner) end.
    // Returns false, storing nothing, when capacity - 1 items are already held;
    // the caller decides whether to retry.
    bool Push(const T& item) {
        // Acquire pairs with the seq_cst CAS that advances _top in Steal/Pop:
        // once we see a slot released, the thief's read of that slot happened
        // before any overwrite below, so it can never observe the new value.
        const size_t top = _top.load(std::memory_order_acquire);
        const size_t bottom = _bottom.load(std::memory_order_relaxed);
        // size_t is unsigned: only subtract once bottom > top is known.
        if (bottom > top && bottom - top >= capacity - 1) {
            return false;
        }
        _values[bottom % capacity].store(item, std::memory_order_relaxed);
        // Release publishes the slot written above to threads acquiring _bottom.
        _bottom.store(bottom + 1, std::memory_order_release);
        return true;
    }

    // Owner thread only. Takes the most recently pushed item (LIFO end).
    // Returns false when the stack is empty, or when a concurrent Steal won the
    // race for the last remaining item. `result` is written only on success.
    bool Pop(T& result) {
        const size_t oldbottom = _bottom.load(std::memory_order_relaxed);
        const size_t newbottom = oldbottom - 1;
        // Speculatively reserve the bottom item before looking at _top.
        _bottom.store(newbottom, std::memory_order_release);
        // Full fence: the store above must be ordered before the load of _top
        // below (StoreLoad), which release/acquire alone does not guarantee.
        // Pairs with the fence in Steal so that the owner and a thief can
        // never both take the same item.
        std::atomic_thread_fence(std::memory_order_seq_cst);
        size_t top = _top.load(std::memory_order_acquire);

        if (top > newbottom) {
            // Nothing to take (top >= oldbottom). Level _bottom with _top so
            // the stack reads as empty again.
            _bottom.store(top, std::memory_order_release);
            return false;
        }

        const T item = _values[newbottom % capacity].load(std::memory_order_relaxed);
        if (top == newbottom) {
            // Exactly one item left and thieves may be after it: claim it by
            // advancing _top. Win or lose, the stack ends up empty with
            // top == bottom == oldbottom.
            const bool won = _top.compare_exchange_strong(
                top, top + 1, std::memory_order_seq_cst, std::memory_order_relaxed);
            _bottom.store(oldbottom, std::memory_order_release);
            if (!won) {
                return false;
            }
        }
        result = item;
        return true;
    }

    // Any thread. Takes the oldest item (FIFO end).
    // Returns false when the stack looks empty or another thread claimed the
    // item first; callers may simply retry. `result` is written only on success.
    bool Steal(T& result) {
        size_t top = _top.load(std::memory_order_acquire);
        // Pairs with the fence in Pop (see the comment there).
        std::atomic_thread_fence(std::memory_order_seq_cst);
        // Acquire pairs with the release stores of _bottom, so the slot
        // contents published by Push are visible to the read below.
        const size_t bottom = _bottom.load(std::memory_order_acquire);
        if (bottom <= top) {
            return false;
        }
        // Read the slot BEFORE claiming it: once _top moves past this index
        // the owner may reuse the slot for a later Push.
        const T item = _values[top % capacity].load(std::memory_order_relaxed);
        if (!_top.compare_exchange_strong(
                top, top + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
            return false;
        }
        result = item;
        return true;
    }

private:
    std::atomic<T> _values[capacity];  // circular buffer, indexed modulo capacity
    std::atomic<size_t> _top;          // index Steal takes next (oldest item)
    std::atomic<size_t> _bottom;       // index Push fills next; Pop takes _bottom - 1
};
#include "stdafx.h"
#include <atomic>
#include <chrono>
#include <functional>
#include <memory>
#include <thread>
#include <vector>
#include "workstealingstack.h"
#include "catch.h"
using namespace std;
TEST_CASE("Work stealing stack: Single-threaded push and pop", "[wss][serial]") {
    // Pop hands items back newest-first; a fifth Pop after four Pushes must fail.
    auto stack = make_unique<WorkStealingStack<int>>();
    for (int item = 100; item <= 400; item += 100) {
        stack->Push(item);
    }

    int got[5];
    bool ok[5];
    for (int k = 0; k < 5; ++k) {
        ok[k] = stack->Pop(got[k]);
    }

    for (int k = 0; k < 4; ++k) {
        REQUIRE(ok[k]);
    }
    REQUIRE_FALSE(ok[4]);
    for (int k = 0; k < 4; ++k) {
        REQUIRE(got[k] == 400 - 100 * k);
    }
}
TEST_CASE("Work stealing stack: Single-threaded push and steal", "[wss][serial]") {
    // Steal hands items back oldest-first; a fifth Steal after four Pushes must fail.
    auto stack = make_unique<WorkStealingStack<int>>();
    for (int item = 100; item <= 400; item += 100) {
        stack->Push(item);
    }

    int got[5];
    bool ok[5];
    for (int k = 0; k < 5; ++k) {
        ok[k] = stack->Steal(got[k]);
    }

    for (int k = 0; k < 4; ++k) {
        REQUIRE(ok[k]);
    }
    REQUIRE_FALSE(ok[4]);
    for (int k = 0; k < 4; ++k) {
        REQUIRE(got[k] == 100 * (k + 1));
    }
}
TEST_CASE("Work stealing stack: Single-threaded push, pop and steal", "[wss][serial]") {
    // Interleaves owner Pops (newest end) with Steals (oldest end) on a single
    // thread and checks which item each call receives.
    auto stack = make_unique<WorkStealingStack<int>>();
    int got[5];
    bool ok[5];

    stack->Push(100);
    stack->Push(200);
    ok[0] = stack->Pop(got[0]);
    stack->Push(300);
    ok[1] = stack->Steal(got[1]);
    stack->Push(400);
    ok[2] = stack->Steal(got[2]);
    ok[3] = stack->Pop(got[3]);
    ok[4] = stack->Steal(got[4]);  // nothing is left to take

    for (int k = 0; k < 4; ++k) {
        REQUIRE(ok[k]);
    }
    REQUIRE_FALSE(ok[4]);

    const int expected[] = { 200, 100, 300, 400 };
    for (int k = 0; k < 4; ++k) {
        REQUIRE(got[k] == expected[k]);
    }
}
TEST_CASE("Work stealing stack: Mulithreaded one consumer one producer", "[wss][concurrent]") {
    // The owner pushes 1..10000 while a single thief steals everything.
    auto wss = make_unique<WorkStealingStack<int, 200>>();
    // Shared between threads, so it must be atomic: a plain bool written by
    // one thread and read by another is a data race (undefined behavior).
    atomic<bool> done{ false };
    // Only the consumer touches these until join(), which synchronizes them.
    auto result = 0;
    auto count = 0;
    thread consumer([&]() {
        for (;;) {
            // Sample the flag BEFORE draining: if it was already set, every
            // Push is visible, so this drain empties the stack for good.
            // This replaces relying on a fixed sleep being "long enough".
            const bool finished = done.load();
            int val;
            while (wss->Steal(val)) {
                count++;
                result += val % 101;
            }
            if (finished) {
                break;
            }
        }
    });
    thread producer([&]() {
        for (auto i = 1; i <= 10000; i++) {
            while (!wss->Push(i)) {}
        }
        done = true;
    });
    consumer.join();
    producer.join();
    REQUIRE(count == 10000);
    REQUIRE(result == 499951);
}
TEST_CASE("Work stealing stack: Mulithreaded one consumer one producer large iteration", "[wss][concurrent]") {
    // The owner pushes 100000 items through a small (200-slot) buffer while a
    // single thief counts every item it steals.
    auto wss = make_unique<WorkStealingStack<int, 200>>();
    // Written by the producer and read by the consumer: must be atomic.
    atomic<bool> done{ false };
    atomic<int> result{ 0 };
    thread consumer([&]() {
        for (;;) {
            // Read the flag first; a drain that starts after seeing it set is
            // guaranteed to find every pushed item, so no extra spin count
            // (the old "++i < 10000" heuristic) is needed.
            const bool finished = done.load();
            int val;
            while (wss->Steal(val)) {
                result.fetch_add(1);
            }
            if (finished) {
                break;
            }
        }
    });
    thread producer([&]() {
        for (auto i = 1; i <= 100000; i++) {
            while (!wss->Push(i)) {}
        }
        done = true;
    });
    consumer.join();
    producer.join();
    auto r = result.load();
    REQUIRE(r == 100000);
}
TEST_CASE("Work stealing stack: Mulithreaded many consumers one producer", "[wss][concurrent]") {
    // The owner pushes 1..10000 while 50 thieves compete to steal them.
    auto wss = make_unique<WorkStealingStack<int, 200>>();
    // Written by the producer and read by every consumer: must be atomic.
    atomic<bool> done{ false };
    atomic<int> result{ 0 };
    atomic<int> count{ 0 };
    // vector<thread> owns the threads (the previous `new thread[50]` leaked).
    vector<thread> consumers;
    consumers.reserve(50);
    for (auto i = 0; i < 50; i++) {
        consumers.emplace_back([&]() {
            for (;;) {
                // Sample the flag before draining. A Steal that loses a CAS race
                // may return false while items remain, but the thief that won
                // keeps draining, so once every thief has finished a drain that
                // began after seeing `done`, the stack is empty.
                const bool finished = done.load();
                int val;
                while (wss->Steal(val)) {
                    count.fetch_add(1);
                    result.fetch_add(val % 101);
                }
                if (finished) {
                    break;
                }
            }
        });
    }
    thread producer([&]() {
        for (auto i = 1; i <= 10000; i++) {
            while (!wss->Push(i)) {}
        }
        done = true;
    });
    for (auto& consumer : consumers) {
        consumer.join();
    }
    producer.join();
    auto c = count.load();
    auto r = result.load();
    REQUIRE(c == 10000);
    REQUIRE(r == 499951);
}
TEST_CASE("Work stealing stack: Mulithreaded many consumers one consuming producer", "[wss][concurrent]") {
    // The owner pushes 0..9999 and periodically pops its own items while 11
    // thieves steal concurrently. Every item must be consumed exactly once.
    auto wss = make_unique<WorkStealingStack<int>>();
    // Written by the producer and read by every consumer: must be atomic.
    atomic<bool> done{ false };
    // Written only by the producer and read after join(), so a plain bool is fine.
    auto selfConsumed = false;
    atomic<int> result{ 0 };
    atomic<int> count1{ 0 };  // items taken by thieves (Steal)
    atomic<int> count2{ 0 };  // items taken by the owner (Pop)
    // How many times each item was taken; every entry must end at exactly 1.
    // Replaces the leaked, uninitialized `new int[10000]` that was never checked.
    vector<atomic<int>> taken(10000);
    // vector<thread> owns the threads (the previous `new thread[11]` leaked).
    vector<thread> consumers;
    consumers.reserve(11);
    for (auto i = 0; i < 11; i++) {
        consumers.emplace_back([&]() {
            for (;;) {
                // Sample the flag before draining so that the last drain starts
                // only after every Push/Pop by the owner is visible.
                const bool finished = done.load();
                int val;
                while (wss->Steal(val)) {
                    taken[val].fetch_add(1);
                    count1.fetch_add(1);
                    result.fetch_add(val % 101);
                }
                if (finished) {
                    break;
                }
            }
        });
    }
    thread producer([&]() {
        for (auto i = 0; i < 10000; i++) {
            while (!wss->Push(i)) {}
            if (i % 7 == 0) {
                int val;
                while (wss->Pop(val)) {
                    taken[val].fetch_add(1);
                    selfConsumed = true;
                    count2.fetch_add(1);
                    result.fetch_add(val % 101);
                }
            }
        }
        done = true;
    });
    for (auto& consumer : consumers) {
        consumer.join();
    }
    producer.join();
    if (!selfConsumed) {
        WARN("The producer did not consume any of its own items, test result may be compromised.");
    }
    auto c = count1.load() + count2.load();
    auto r = result.load();
    auto notExactlyOnce = 0;
    for (auto& hits : taken) {
        if (hits.load() != 1) {
            notExactlyOnce++;
        }
    }
    REQUIRE(c == 10000);
    REQUIRE(r == 499950);
    REQUIRE(notExactlyOnce == 0);
}
@Chaosvex
Copy link

Chaosvex commented Jul 6, 2016

Don't suppose you'd consider making this code available under a permissive license (MIT/BSD)? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment