Lock-free slist
#include <functional>
#include <algorithm>
#include <numeric>
#include <memory>
#include <atomic>
#include <thread>
#include <iostream>
#include <sstream>
#include <random>
#include <array>
#include <cstdint>
#include <chrono>
/**
* \struct tagged_ptr<> special pointer for lock-free algorithms.
* \brief tagged_ptr<> embeds an ABA tag counter into the pointer.
* \note tagged_ptr<> has strict alignment requirements:
* #) std::atomic< tagged_ptr< X > > objects should be aligned to the CPU cache-line width.
* #) All tagged_ptr< X > objects must be aligned to at least sizeof(tagged_ptr< X >).
*/
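/*
 * Layout note, assuming a little-endian target such as x86/x86-64:
 * - with 64-bit pointers, the 16-bit tag overlaps the top two bytes of the
 *   pointer itself, relying on user-space addresses keeping those bits zero;
 * - with 32-bit pointers, the 32-bit tag lives in a separate word after the ptr.
 * Reading .tag after writing .ptr is union type punning: a widely supported
 * compiler extension (GCC/Clang/MSVC), not strictly portable C++.
 */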
template<typename T>
struct alignas(16) tagged_ptr {
    typedef typename std::conditional<sizeof(T*) == sizeof(uint32_t), uint32_t, uint16_t>::type tag_type;
    union {
        T* ptr;
        struct {
            char _padd[sizeof(T*) - (sizeof(T*) == sizeof(uint32_t) ? 0 : sizeof(uint16_t))];
            tag_type tag;
        };
    };
    // tagged_ptr must be trivially copyable, so don't define any ctors.
    // Get the pointer value with the tag bits cleared.
    T * get() const {
        alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> tmp;
        tmp.ptr = ptr;
        tmp.tag = 0U;
        return tmp.ptr;
    }
    // The tagged_ptr can be implicitly converted to a regular pointer.
    operator T*() const {
        return get();
    }
    // Check if the tagged_ptr is null.
    operator bool() const {
        return get() != nullptr;
    }
    // The tagged_ptr can also be dereferenced...
    T & operator*() const {
        return *get();
    }
    T * operator->() const {
        return get();
    }
    // ...and array indexed.
    T & operator[](size_t nth) const {
        return get()[nth];
    }
    // Compare tagged pointers; the tag bits take part in the comparison.
    bool operator==(const tagged_ptr<T> & rhs) const {
        return this->ptr == rhs.ptr;
    }
    bool operator!=(const tagged_ptr<T> & rhs) const {
        return this->ptr != rhs.ptr;
    }
    // Increment the ABA tag and store a new pointer.
    tagged_ptr<T> operator()(T * p) const {
        alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> aba;
        aba.ptr = p;
        aba.tag = tag + 1;
        return aba;
    }
    // Increment the ABA tag, keeping the pointer held by 'aba'.
    tagged_ptr<T> operator()(tagged_ptr<T> aba) const {
        aba.tag = tag + 1;
        return aba;
    }
    // Increment the ABA tag, keeping the stored pointer.
    tagged_ptr<T> operator()() const {
        alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> aba;
        aba.ptr = ptr;
        aba.tag = tag + 1;
        return aba;
    }
    // Return the tag value.
    unsigned int get_tag() const {
        return tag;
    }
    // Check if all address bits are set
    // (a.k.a. a void/invalid ptr), ignoring the tag bits.
    bool is_void() const {
        alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> c, v;
        c.ptr = ptr;
        c.tag = 0U;
        v.ptr = reinterpret_cast<T*>(~std::uintptr_t(0));
        v.tag = 0U;
        return c.ptr == v.ptr;
    }
    // Set all address bits (make a void/invalid ptr), keeping the tag.
    tagged_ptr<T> set_void() const {
        alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> aba;
        aba.ptr = reinterpret_cast<T*>(~std::uintptr_t(0));
        aba.tag = tag;
        return aba;
    }
};
static_assert(sizeof(tagged_ptr<int>) >= sizeof(void*), "tagged_ptr<X> is miscompiled!");
/**
* make tagged_ptr<T> from ptr with zero as tag value.
*/
template<typename T>
inline tagged_ptr<T> ptr_tag(T * ptr)
{
    alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> tp;
    tp.ptr = ptr;
    tp.tag = 0U;
    return tp;
}
/**
* make tagged_ptr<T> from ptr with specified tag value.
*/
template<typename T>
inline tagged_ptr<T> ptr_tag(T * ptr, typename tagged_ptr<T>::tag_type tag)
{
    alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> tp;
    tp.ptr = ptr;
    tp.tag = tag;
    return tp;
}
/**
* make a null tagged_ptr<T> with zero tag.
*/
template<typename T>
inline tagged_ptr<T> nullptr_tag()
{
    return ptr_tag<T>(nullptr);
}
/**
* make void/invalid tagged_ptr<T> with zero tag
*/
template<typename T>
inline tagged_ptr<T> voidptr_tag()
{
    return ptr_tag<T>(reinterpret_cast<T*>(~std::uintptr_t(0)));
}
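/*
 * A minimal sketch of how the tag machinery behaves (demo_tagged_ptr is a
 * made-up name, not part of the list implementation): the call operator
 * bumps the ABA tag while get() still yields the same raw pointer, and
 * set_void()/is_void() round-trip the invalid-pointer marker.
 */
inline void demo_tagged_ptr() {
    static int value = 42;
    auto tp = ptr_tag(&value);   // ptr = &value, tag = 0
    auto bumped = tp(&value);    // same ptr, tag = 1
    std::cout << (tp.get() == bumped.get())         // 1: get() strips the tag
              << " " << bumped.get_tag()            // 1
              << " " << bumped.set_void().is_void() // 1
              << "\n";
}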
/**
 * Lock-free singly-linked list that allows pushing at the front
 * and removal at any position.
 * Assumes the list user handles allocation
 * and safe reclamation of list nodes.
 */
class lockfree_slist {
public:
    struct list_node {
        std::atomic<tagged_ptr<list_node>> next;
        int data;
    };
    // List head ptr.
    std::atomic<tagged_ptr<list_node>> head;
    lockfree_slist() : head(nullptr_tag<list_node>()) {}
    // Push a new node at the list front.
    void push_front(list_node * node) {
        auto hdr = head.load(std::memory_order_acquire);
        do {
            node->next.store(hdr);
        } while(!head.compare_exchange_weak(hdr, hdr(node)));
    }
    // Try to find the first node matching the predicate.
    // Predicate arguments are the before-ptr and the dereferenced node ptr.
    list_node* find(const std::function<bool(std::atomic<tagged_ptr<list_node>>*,
                                             tagged_ptr<list_node>)> & func) {
        auto ptr = &head;
        auto itr = ptr->load(std::memory_order_acquire);
        while(itr) {
            if(itr.is_void()) {
                // Another thread interfered with the iteration: restart from head.
                ptr = &head;
                itr = ptr->load(std::memory_order_acquire);
            } else if(func(ptr, itr)) {
                return itr.get();
            } else {
                // Advance to the next node.
                ptr = &itr->next;
                itr = ptr->load(std::memory_order_acquire);
            }
        }
        return nullptr;
    }
    // Get the approximate size of the list.
    // Restarts counting from zero if another thread
    // interferes with the iteration.
    size_t count() {
        size_t count = 0;
        auto ptr = &head;
        auto itr = ptr->load(std::memory_order_acquire);
        while(itr) {
            if(itr.is_void()) {
                // Another thread interfered: restart from head.
                ptr = &head;
                itr = ptr->load(std::memory_order_acquire);
                count = 0;
            } else {
                ++count;
                // Advance to the next node.
                ptr = &itr->next;
                itr = ptr->load(std::memory_order_acquire);
            }
        }
        return count;
    }
    // Pop the node at the list front.
    list_node * pop_front() {
        auto node = head.load(std::memory_order_acquire);
        while(node) {
            auto after = node->next.load(std::memory_order_acquire);
            if(after.is_void()) {
                // The front node was erased concurrently: reload head.
                node = head.load(std::memory_order_acquire);
            } else {
                // Attempt the pop.
                list_node * tmp = node.get();
                if(head.compare_exchange_weak(node, node(after))) {
                    // Mark the popped node's next ptr void so concurrent
                    // iterators know to restart from head.
                    while(!tmp->next.compare_exchange_weak(after, after().set_void()))
                        ;
                    return tmp;
                }
            }
        }
        return nullptr;
    }
    // Try to erase a node from the list.
    // @return false if the node was not linked in the list.
    bool erase(list_node * node) {
        std::atomic<tagged_ptr<list_node>>* before;
        tagged_ptr<list_node> itr, after;
        for(;;) {
            // Find the before-node-ptr (head, or the previous node's next ptr).
            before = &head;
            itr = before->load(std::memory_order_acquire);
            while(itr) {
                if(itr.get() == node) {
                    break;
                } else if(itr.is_void()) {
                    // Another thread interfered with the iteration: restart from head.
                    before = &head;
                    itr = before->load(std::memory_order_acquire);
                } else {
                    // Advance to the next node.
                    before = &itr->next;
                    itr = before->load(std::memory_order_acquire);
                }
            }
            after = node->next.load(std::memory_order_acquire);
            if(after.is_void() || !itr) {
                return false;
            }
            // List now: before --> node/itr --> after --> ... --> null.
            // Point before-ptr to after (set head or the previous node's next ptr).
            if(before->compare_exchange_strong(itr, itr(after))) {
                // List now: before --> after --> ... --> null, node --> after --> ... --> null.
                // Set node->next to an invalid ptr.
                while(!node->next.compare_exchange_weak(after, after().set_void()))
                    ;
                // List now: before --> after --> ... --> null, node --> invalid.
                return true;
            }
            // If *before changed while trying to update it to after, retry the search.
        }
    }
};
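/*
 * A minimal single-threaded usage sketch (demo_slist is a made-up name).
 * Like the test below, it assumes the caller owns the node storage and
 * initializes each node's next ptr to the void marker, the "not linked"
 * state that erase() reports on.
 */
inline void demo_slist() {
    lockfree_slist l;
    lockfree_slist::list_node a, b;
    a.next.store(voidptr_tag<lockfree_slist::list_node>());
    b.next.store(voidptr_tag<lockfree_slist::list_node>());
    a.data = 1;
    b.data = 2;
    l.push_front(&a);   // list: a
    l.push_front(&b);   // list: b -> a
    l.erase(&a);        // list: b
    std::cout << l.count() << " " << l.pop_front()->data << "\n"; // "1 2"
}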
/**
 * Test code
 */
// Compare list node value with k.
bool cmpval(std::atomic<tagged_ptr<lockfree_slist::list_node>>* ptr,
            tagged_ptr<lockfree_slist::list_node> ref, int k)
{
    return ref->data == k;
}
std::atomic<int> insertions(0);
std::atomic<int> insertion_latest(0);
std::atomic<int> erasures(0);
std::atomic<int> erasures_latest(0);
// Lock-free list instance.
lockfree_slist list;
int main() {
    const int MAX_LIST_SIZE = 32;
    // List nodes are simply stored in an array.
    std::array<lockfree_slist::list_node,MAX_LIST_SIZE> nodes;
    // The pusher and remover recycle nodes through this array.
    std::array<std::atomic<lockfree_slist::list_node*>,MAX_LIST_SIZE> ptrs;
    std::atomic<bool> run(true);
    for(size_t i = 0; i < nodes.size(); ++i) {
        ptrs[i] = &nodes[i];
        nodes[i].next.store(voidptr_tag<lockfree_slist::list_node>());
        nodes[i].data = static_cast<int>(i);
    }
    // The pusher tries to maintain a set of ints
    // that has no gaps.
    auto pusher = [&run,&nodes,&ptrs]() {
        while(run) {
            lockfree_slist::list_node * nn = nullptr;
            for(size_t i = 0; i < nodes.size(); ++i) {
                nn = ptrs[i].exchange(nullptr);
                if(nn) {
                    break;
                }
            }
            if(!nn) {
                // All nodes are currently linked into the list; try again.
                continue;
            }
            ++insertions;
            insertion_latest = nn->data;
            list.push_front(nn);
        }
    };
    // The remover mixes things up if the list gets too long by erasing
    // nodes at random points.
    auto remover = [&run,&ptrs]() {
        std::random_device rd;
        std::mt19937 rng;
        rng.seed(rd());
        while(run) {
            // Find a random item in the list.
            int val = std::uniform_int_distribution<int>(0, (int)ptrs.size() - 1)(rng);
            auto node = list.find(std::bind(cmpval, std::placeholders::_1, std::placeholders::_2, val));
            // Try to erase the node.
            if(node && list.erase(node)) {
                ++erasures;
                erasures_latest = node->data;
                // Recycle the node.
                lockfree_slist::list_node * nn;
                for(size_t i = 0; i < ptrs.size(); ++i) {
                    nn = nullptr;
                    if(ptrs[i].compare_exchange_strong(nn, node))
                        break;
                }
            }
        }
    };
    std::thread push_thr(pusher);
    std::thread push_thr2(pusher);
    std::thread eraser_thr(remover);
    std::thread eraser_thr2(remover);
    // Monitor progress for a while.
    // If the program begins to print the same values again and again,
    // the lockfree_slist is broken.
    for(int iter = 0; iter < 60; ++iter) {
        std::stringstream out;
        out << "Insertions: " << insertions.exchange(0);
        out << " Last:" << insertion_latest;
        out << " Erasures:" << erasures.exchange(0);
        out << " Last:" << erasures_latest;
        std::cout << out.str() << std::endl;
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }
    run = false;
    push_thr.join();
    push_thr2.join();
    eraser_thr.join();
    eraser_thr2.join();
    return 0;
}