Last active
June 21, 2019 14:49
-
-
Save JATothrim/97aef4c46f5e4391eb07886732493078 to your computer and use it in GitHub Desktop.
Lock-free slist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <type_traits>
/**
 * \struct tagged_ptr<> special pointer for lock-free algorithms.
 * \brief tagged_ptr<> embeds an ABA tag counter into the ptr.
 *
 * Layout (see the union below):
 *  - when sizeof(T*) == 4 the tag is a separate 32-bit word placed right
 *    after the pointer;
 *  - otherwise the tag is a 16-bit field that overlays the last two bytes
 *    of the pointer's storage, so the stored address must not need them.
 *    NOTE(review): this presumably targets little-endian 64-bit platforms
 *    whose user-space addresses leave the top 16 bits clear -- confirm.
 *
 * \note tagged_ptr<> has strict alignment requirements:
 *  #) std::atomic< tagged_ptr< X > > objects should be aligned to cpu cache-line width.
 *  #) All tagged_ptr< X > objects must be aligned to a minimum of sizeof(tagged_ptr< X >).
 * \note Reading `ptr` after writing `tag` (and vice versa) relies on
 *       union type punning and on anonymous structs, both compiler extensions.
 */
template<typename T>
struct alignas(16) tagged_ptr {
    using tag_type = typename std::conditional<sizeof(T*) == sizeof(uint32_t), uint32_t, uint16_t>::type;

    union {
        T* ptr;
        struct {
            // Skips the bytes of `ptr` that are not shared with the tag.
            char _padd[sizeof(T*) - (sizeof(T*) == sizeof(uint32_t) ? 0 : sizeof(uint16_t))];
            tag_type tag;
        };
    };

    // Tagged ptr must be trivially copyable so don't define any ctors.

    // Get the pointer value with the tag bits cleared.
    T * get() const {
        tagged_ptr<T> tmp = *this;
        tmp.tag = 0U;
        return tmp.ptr;
    }

    // The tagged_ptr can be implicitly converted to a regular pointer.
    operator T*() const {
        return get();
    }

    // Check if tagged_ptr is non-null (tag bits are ignored).
    operator bool() const {
        return get() != nullptr;
    }

    // Tagged ptr can also be dereferenced.
    T & operator*() const {
        return *get();
    }

    T * operator->() const {
        return get();
    }

    // Array indexing through the untagged pointer.
    T & operator[](size_t nth) const {
        return get()[nth];
    }

    // Compare the raw `ptr` members (which include the tag when the tag
    // overlays the pointer storage).
    bool operator==(const tagged_ptr<T> & rhs) const {
        return this->ptr == rhs.ptr;
    }

    bool operator!=(const tagged_ptr<T> & rhs) const {
        return this->ptr != rhs.ptr;
    }

    // Return a tagged_ptr holding `new_ptr` with this object's tag + 1.
    tagged_ptr<T> operator()(T * new_ptr) const {
        tagged_ptr<T> aba;
        aba.ptr = new_ptr;
        // Written after ptr because the tag may share storage with it.
        aba.tag = static_cast<tag_type>(tag + 1);
        return aba;
    }

    // Return `aba`'s pointer combined with this object's tag + 1.
    tagged_ptr<T> operator()(tagged_ptr<T> aba) const {
        aba.tag = static_cast<tag_type>(tag + 1);
        return aba;
    }

    // Return a copy of this tagged_ptr with the tag incremented.
    tagged_ptr<T> operator()() const {
        tagged_ptr<T> aba = *this;
        aba.tag = static_cast<tag_type>(tag + 1);
        return aba;
    }

    // Return tag value.
    unsigned int get_tag() const {
        return tag;
    }

    // Check if all address bits are set
    // (a.k.a. void/invalid ptr) ignoring the tag bits.
    bool is_void() const {
        tagged_ptr<T> self = *this;
        self.tag = 0U;
        tagged_ptr<T> all_ones;
        all_ones.ptr = reinterpret_cast<T*>(~0UL);
        all_ones.tag = 0U;
        return self.ptr == all_ones.ptr;
    }

    // Return a void/invalid tagged_ptr (all address bits set) that keeps
    // this object's tag. *this is not modified.
    tagged_ptr<T> set_void() const {
        tagged_ptr<T> aba;
        aba.ptr = reinterpret_cast<T*>(~0UL);
        aba.tag = tag;
        // Return the freshly built value; returning *this would hand back
        // the unchanged (non-void) pointer.
        return aba;
    }
};
static_assert(sizeof(tagged_ptr<int>) >= sizeof(void*), "tagged_ptr<X> is miscompled!");
static_assert(std::is_trivially_copyable<tagged_ptr<int>>::value,
              "tagged_ptr<X> must be trivially copyable to be used with std::atomic");
/** | |
* make tagged_ptr<T> from ptr with zero as tag value. | |
*/ | |
template<typename T> | |
inline tagged_ptr<T> ptr_tag(T * ptr) | |
{ | |
alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> tp; | |
tp.ptr = ptr; | |
tp.tag = 0U; | |
return tp; | |
} | |
/** | |
* make tagged_ptr<T> from ptr with specified tag value. | |
*/ | |
template<typename T> | |
inline tagged_ptr<T> ptr_tag(T * ptr, typename tagged_ptr<T>::tag_type tag) | |
{ | |
alignas(sizeof(tagged_ptr<T>)) tagged_ptr<T> tp; | |
tp.ptr = ptr; | |
tp.tag = tag; | |
return tp; | |
} | |
/** | |
* make a null tagged_ptr<T> with zero tag. | |
*/ | |
template<typename T> | |
inline tagged_ptr<T> nullptr_tag() | |
{ | |
return ptr_tag<T>(nullptr); | |
} | |
/** | |
* make void/invalid tagged_ptr<T> with zero tag | |
*/ | |
template<typename T> | |
inline tagged_ptr<T> voidptr_tag() | |
{ | |
return ptr_tag<T>(reinterpret_cast<T*>(~0UL)); | |
} | |
/** | |
* Lock-free list that allows push at front | |
* and removal at any position. | |
* Assumes list user handles allocation | |
* and safe reclamation of list nodes. | |
*/ | |
class lockfree_slist { | |
public: | |
struct list_node { | |
std::atomic<tagged_ptr<list_node>> next; | |
int data; | |
}; | |
// List head ptr | |
std::atomic<tagged_ptr<list_node>> head; | |
lockfree_slist() : head(nullptr_tag<list_node>()) {} | |
// Push new value at list front. | |
void push_front(list_node * node) { | |
auto hdr = head.load(std::memory_order_acquire); | |
do { | |
node->next.store(hdr); | |
} while(!head.compare_exchange_weak(hdr, hdr(node))); | |
} | |
// Try Find first node with criteria | |
// predicate arguments are before-ptr and dereferenced-ptr | |
list_node* find(const std::function<bool(std::atomic<tagged_ptr<list_node>>*, | |
tagged_ptr<list_node>)> & func) { | |
auto ptr = &head; | |
auto itr = ptr->load(std::memory_order_acquire); | |
while(itr) { | |
if(itr.is_void()) { | |
// Thread interfered iteration. | |
ptr = &head; | |
itr = ptr->load(std::memory_order_acquire); | |
} else if(func(ptr, itr)) { | |
return itr.get(); | |
} else { | |
// Access next ptr | |
ptr = &itr->next; | |
itr = ptr->load(std::memory_order_acquire); | |
} | |
} | |
return nullptr; | |
} | |
// Get approximate size of the list. | |
// return std::numeric_limits<size_t>::max() if thread | |
// interfered. | |
size_t count() { | |
size_t count = 0; | |
auto ptr = &head; | |
auto itr = ptr->load(std::memory_order_acquire); | |
while(itr) { | |
if(itr.is_void()) { | |
// Thread interfered iteration. | |
ptr = &head; | |
itr = ptr->load(std::memory_order_acquire); | |
count = 0; | |
} else { | |
++count; | |
// Access next ptr | |
ptr = &itr->next; | |
itr = ptr->load(std::memory_order_acquire); | |
} | |
} | |
return count; | |
} | |
// Pop item at list front. | |
list_node * pop_front() { | |
auto node = head.load(std::memory_order_acquire); | |
while(node) { | |
auto after = node->next.load(std::memory_order_acquire); | |
if(after.is_void()) { | |
node = head.load(std::memory_order_acquire); | |
} else { | |
// Attempt pop | |
list_node * tmp = node.get(); | |
if(head.compare_exchange_weak(node, node(after))) { | |
while(!tmp->next.compare_exchange_weak(after, after().set_void())) | |
; | |
return tmp; | |
} | |
} | |
} | |
return nullptr; | |
} | |
// Try erase node from the list | |
// @return false if the node was not linked in the list. | |
bool erase(list_node * node) { | |
std::atomic<tagged_ptr<list_node>>* before; | |
tagged_ptr<list_node> itr, after; | |
for(;;) { | |
// Find previous (or head) before-node-ptr | |
before = &head; | |
itr = before->load(std::memory_order_acquire); | |
while(itr) { | |
if(itr.get() == node) { | |
break; | |
} else if(itr.is_void()) { | |
// Thread interfered iteration. | |
before = &head; | |
itr = before->load(std::memory_order_acquire); | |
} else { | |
// Access next ptr | |
before = &itr->next; | |
itr = before->load(std::memory_order_acquire); | |
} | |
} | |
after = node->next.load(std::memory_order_acquire); | |
if(after.is_void() || !itr) { | |
return false; | |
} | |
// list now:before --> node/itr --> after --> ... --> null | |
// Point before-ptr to after. (set head or previous node's next ptr) | |
if(before->compare_exchange_strong(itr, itr(after))) { | |
// list now:before --> after --> ... --> null, node --> after --> ... --> null | |
// Set node->next to invalid ptr. | |
while(!node->next.compare_exchange_weak(after, after().set_void())) | |
; | |
// list now:before --> after --> ... --> null, node --> invalid | |
return true; | |
} | |
// If *before changed while trying to update it to after, retry search. | |
} | |
} | |
}; | |
/** | |
* Test code | |
*/ | |
// Predicate for lockfree_slist::find(): compare the list node value with k.
// The before-pointer argument `ptr` is not used.
bool cmpval(std::atomic<tagged_ptr<lockfree_slist::list_node>>* ptr,
            tagged_ptr<lockfree_slist::list_node> ref, int k)
{
    const lockfree_slist::list_node * const current = ref.get();
    return k == current->data;
}
// Statistics shared between the worker threads and the monitor loop:
// number of pushes / erases and the data value of the latest one.
std::atomic<int> insertions{0};
std::atomic<int> insertion_latest{0};
std::atomic<int> erasures{0};
std::atomic<int> erasures_latest{0};
// The single lock-free list instance exercised by the test threads.
lockfree_slist list{};
int main() { | |
const int MAX_LIST_SIZE = 32; | |
// List nodes are simply stored in array | |
std::array<lockfree_slist::list_node,MAX_LIST_SIZE> nodes; | |
// pusher and remover recycle nodes using this array. | |
std::array<std::atomic<lockfree_slist::list_node*>,MAX_LIST_SIZE> ptrs; | |
std::atomic<bool> run(true); | |
for(int i = 0; i < nodes.size(); ++i) { | |
ptrs[i] = &nodes[i]; | |
nodes[i].next.store(voidptr_tag<lockfree_slist::list_node>()); | |
nodes[i].data = i; | |
} | |
// Pusher tries to maintain set of ints | |
// that has no gaps. | |
auto pusher = [&run,&nodes,&ptrs]() { | |
while(run) { | |
lockfree_slist::list_node * nn = nullptr; | |
for(int i = 0; i < nodes.size(); ++i) { | |
nn = ptrs[i].exchange(nullptr); | |
if(nn) { | |
break; | |
} | |
} | |
++insertions; | |
insertion_latest = nn->data; | |
list.push_front(nn); | |
} | |
}; | |
// Remover mixes if the list gets too long by erasing | |
// nodes at random points. | |
auto remover = [&run,&ptrs]() { | |
std::random_device rd; | |
std::mt19937 rng; | |
rng.seed(rd()); | |
while(run) { | |
// Find random item from the list. | |
int val = std::uniform_int_distribution<int>(0,ptrs.size())(rng); | |
auto node = list.find(std::bind(cmpval,std::placeholders::_1, std::placeholders::_2, val)); | |
// Try erase the node. | |
if(node && list.erase(node)) { | |
++erasures; | |
erasures_latest = node->data; | |
// Recycle node. | |
lockfree_slist::list_node * nn; | |
for(int i = 0; i < ptrs.size(); ++i) { | |
nn = nullptr; | |
if(ptrs[i].compare_exchange_strong(nn, node)) | |
break; | |
} | |
} | |
} | |
}; | |
std::thread push_thr(pusher); | |
std::thread push_thr2(pusher); | |
std::thread eraser_thr(remover); | |
std::thread eraser_thr2(remover); | |
// Monitor progress. | |
// If programs begins to print same values again and again | |
// the lockfree_slist is broken. | |
std::string prev; | |
while(1) { | |
std::stringstream out; | |
out << "Insertions: "<<insertions.exchange(0); | |
out <<" Last:"<<insertion_latest; | |
out <<" Erasures:"<<erasures.exchange(0); | |
out <<" Last:"<<erasures_latest; | |
std::cout<<out.str()<<std::endl; | |
std::this_thread::sleep_for(std::chrono::seconds(1)); | |
prev = out.str(); | |
} | |
run = false; | |
push_thr.join(); | |
push_thr2.join(); | |
eraser_thr.join(); | |
eraser_thr2.join(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment