Skip to content

Instantly share code, notes, and snippets.

@danlark1
Created July 26, 2020 12:16
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danlark1/b499bec5fd32d6c0951cce55f19011e1 to your computer and use it in GitHub Desktop.
Save danlark1/b499bec5fd32d6c0951cce55f19011e1 to your computer and use it in GitHub Desktop.
#include <cinttypes>
#include <atomic>
#include <memory>
#include <type_traits>
//start 00:20
// Tagged pointers: an ABA counter packed into otherwise-unused pointer bits.
// https://en.wikipedia.org/wiki/Tagged_pointer
template <class T>
class TaggedPtr {
    // Packed layout: pointer in the low 48 bits, 16-bit tag in the high bits.
    // Assumes pointer values fit in 48 bits (typical of x86-64 / AArch64
    // user-space addresses) — TODO confirm for the target platform.
    static constexpr unsigned TagShift = 48;
    static constexpr uint64_t PtrMask = (1ULL << TagShift) - 1;

    // Recovers the pointer from a packed word by clearing the tag bits.
    static T* ExtractPtr(uint64_t p) {
        return reinterpret_cast<T*>(p & PtrMask);
    }

    // Packs `p` (low 48 bits) and `tag` (high 16 bits) into one word.
    // Shift/mask arithmetic is well-defined and byte-order independent,
    // unlike reading an inactive union member.
    static uint64_t PackPtr(T* p, uint16_t tag) {
        return (static_cast<uint64_t>(tag) << TagShift) |
               (reinterpret_cast<uint64_t>(p) & PtrMask);
    }

public:
    // Null pointer with tag 0.
    TaggedPtr() = default;

    // Non-explicit so a bare T* (or nullptr) converts to a TaggedPtr with
    // tag 0, e.g. `std::atomic<TaggedPtr<X>> a{nullptr};`.
    TaggedPtr(T* p, uint16_t tag = 0)
        : ptr(PackPtr(p, tag))
    {}

    TaggedPtr& operator=(const TaggedPtr&) = default;

    // Pointer part, with the tag stripped.
    T* GetPtr() const {
        return ExtractPtr(ptr);
    }

    // Tag part (the high 16 bits).
    uint16_t GetTag() const {
        return static_cast<uint16_t>(ptr >> TagShift);
    }

    // Replaces the pointer while keeping the current tag.
    void SetPtr(T* p) {
        ptr = PackPtr(p, GetTag());
    }

    // Current tag + 1, wrapping from 0xFFFF back to 0.
    uint16_t GetNextTag() const {
        return static_cast<uint16_t>(GetTag() + 1u);
    }

    // Equality compares pointer AND tag, which is what makes the tag useful
    // as an ABA guard in compare-and-swap loops.
    bool operator==(const TaggedPtr& p) const {
        return ptr == p.ptr;
    }
    bool operator!=(const TaggedPtr& p) const {
        return ptr != p.ptr;
    }

    T* operator->() const {
        return GetPtr();
    }

    // True when the pointer part is non-null; the tag is ignored.
    operator bool() const {
        return GetPtr() != nullptr;
    }

private:
    // Packed pointer + tag. Deliberately not `volatile`: volatile gives no
    // inter-thread guarantees, and shared instances are wrapped in
    // std::atomic, which requires a trivially copyable T.
    uint64_t ptr = 0;
};
template <class T>
// Lock-free free list (a Treiber stack) of raw storage blocks for T, backed by
// std::allocator<T>. Released blocks stay on the list and are handed out again
// by allocate(); memory goes back to the allocator only in ~Pool().
class Pool : std::allocator<T> {
    using Alloc = std::allocator<T>;

    // Intrusive link written over the first bytes of a released block.
    struct Node {
        TaggedPtr<Node> Next;
    };

public:
    // Creates an empty pool.
    Pool() : Pool_{TaggedPtr<Node>(nullptr)} {}

    // Returns every block on the free list to std::allocator<T>.
    // Must not race with any other operation on this pool.
    ~Pool() {
        TaggedPtr<Node> head = Pool_.load();
        while (Node* ptr = head.GetPtr()) {
            head = ptr->Next;
            Alloc::deallocate(reinterpret_cast<T*>(ptr), 1);
        }
    }

    // Pushes the storage at `n` onto the free list. Does NOT run ~T().
    // A null `n` is ignored (like operator delete / std::free).
    void deallocate(T* n) {
        // The block is reused as a Node, so it must be big and aligned enough;
        // otherwise the link write below would overrun the allocation.
        static_assert(sizeof(T) >= sizeof(Node),
                      "Pool<T>: T is too small to hold the free-list link");
        static_assert(alignof(T) >= alignof(Node),
                      "Pool<T>: T is under-aligned for the free-list link");
        if (n == nullptr) {
            return;
        }
        Node* newPoolPtr = reinterpret_cast<Node*>(static_cast<void*>(n));
        TaggedPtr<Node> old = Pool_.load();
        while (true) {
            // Pushing keeps the head tag; allocate() bumps it on every pop.
            TaggedPtr<Node> newPool(newPoolPtr, old.GetTag());
            // SetPtr keeps whatever tag is already stored in the block's
            // first word and only rewrites the pointer part.
            newPool->Next.SetPtr(old.GetPtr());
            if (Pool_.compare_exchange_strong(old, newPool)) {
                return;
            }
        }
    }

    // Pops a recycled block if one is available, otherwise allocates a fresh
    // one from std::allocator<T>. Callers construct a T in it before use
    // (see Construct()).
    T* allocate() {
        TaggedPtr<Node> old = Pool_.load();
        while (true) {
            if (!old.GetPtr()) {
                return Alloc::allocate(1);
            }
            // `old` may already have been popped by another thread. Blocks
            // are never freed while the pool lives, so this read stays in
            // valid memory, and the tag bump makes a stale CAS fail.
            Node* newPoolPtr = old->Next.GetPtr();
            TaggedPtr<Node> newPool(newPoolPtr, old.GetNextTag());
            if (Pool_.compare_exchange_weak(old, newPool)) {
                return reinterpret_cast<T*>(old.GetPtr());
            }
        }
    }

    // Allocates a block and value-initializes a T in it.
    T* Construct() {
        T* node = allocate();
        if (node) {
            new(node) T();
        }
        return node;
    }

    // Allocates a block and constructs T(u) in it.
    template <class U>
    T* Construct(const U& u) {
        T* node = allocate();
        if (node) {
            new(node) T(u);
        }
        return node;
    }

private:
    // Top of the free list (pointer + ABA tag).
    std::atomic<TaggedPtr<Node>> Pool_;
};
template <class T>
// Michael-Scott style lock-free FIFO queue.
// Head_ always points at a dummy node; the first real element is Head_->Next.
// Tail_ points at the last node or lags by one, and any thread that notices
// the lag helps advance it. Nodes come from Pool_: a popped node goes back to
// the pool's free list and may be reused by a later push(). Its memory is
// released only when Pool_ is destroyed.
// All atomic operations use the default (seq_cst) memory order.
class TMichaelScottQueue {
// Nodes are recycled through Pool_ without running ~TNode, so T must not need
// a destructor.
static_assert(std::is_trivially_destructible<T>::value, "T must be trivially destructible");
// pop() copies Data out of a node that another thread may concurrently dequeue
// and recycle; presumably this is why T is required to be a plain byte copy.
static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
// Queue node. Next is the first member, so Pool's free-list link (Pool::Node)
// overlays exactly these bytes while the node sits on the free list.
struct TNode {
TNode() = default;
// NOTE(review): Next is not in the mem-initializer list, so its default member
// initializer `{nullptr}` runs before this body. `old` is therefore always
// (nullptr, tag 0) and every node starts with tag 1. The tag from the block's
// previous use is NOT carried over, which may weaken ABA protection on the
// Next CAS in push(). Verify.
TNode(const T& v) : Data(v) {
TaggedPtr<TNode> old = Next.load();
TaggedPtr<TNode> newNext(static_cast<TNode*>(nullptr), old.GetNextTag());
Next.store(newNext);
}
std::atomic<TaggedPtr<TNode>> Next{nullptr};
T Data;
};
TMichaelScottQueue(const TMichaelScottQueue&) = delete;
TMichaelScottQueue& operator=(const TMichaelScottQueue& queue) = delete;
public:
// Creates an empty queue: a single dummy node that both Head_ and Tail_ point
// to, with tag 0.
TMichaelScottQueue() : Pool_() {
TNode* node = Pool_.Construct();
TaggedPtr<TNode> dummy(node, 0);
Head_.store(dummy);
Tail_.store(dummy);
}
// Drains the queue, then returns the remaining dummy node to Pool_. Pool_ is
// declared last, so it is destroyed (freeing every node) after this body.
// Must not run concurrently with push()/pop().
~TMichaelScottQueue() {
while (pop().first) {}
Pool_.deallocate(Head_.load().GetPtr());
}
// Dequeues the oldest element.
// Returns {true, value} on success, or {false, T()} if the queue was observed
// empty.
std::pair<bool, T> pop() {
while (true) {
// Snapshot head, tail and the first real node.
TaggedPtr<TNode> head = Head_.load();
TNode* headPtr = head.GetPtr();
TaggedPtr<TNode> tail = Tail_.load();
// NOTE(review): headPtr may already be dequeued and recycled by another
// thread. Its memory stays valid (the pool never frees early), and the
// re-check below plus the head CAS discard such stale reads.
TaggedPtr<TNode> next = headPtr->Next.load();
TNode* nextPtr = next.GetPtr();
// Only act if head did not move while taking the snapshot.
if (head == Head_.load()) {
// Pointer-only comparison: the tags are deliberately ignored here.
if (head.GetPtr() == tail.GetPtr()) {
if (!nextPtr) {
// Dummy node has no successor: the queue is empty.
return {false, T()};
}
// Tail lags behind a node another push already linked; help
// advance it. Failure means someone else already did.
TaggedPtr<TNode> newTail(next.GetPtr(), tail.GetNextTag());
Tail_.compare_exchange_strong(tail, newTail);
} else {
// Copy the value BEFORE the CAS: once Head_ moves, another
// thread may pop and recycle nextPtr.
std::pair<bool, T> ret = {true, nextPtr->Data};
// nextPtr becomes the new dummy node.
TaggedPtr<TNode> newHead(next.GetPtr(), head.GetNextTag());
if (Head_.compare_exchange_weak(head, newHead)) {
// On success `head` still holds the old dummy, which is
// now unreachable from the queue; recycle it.
Pool_.deallocate(head.GetPtr());
return ret;
}
}
}
}
}
// Enqueues a copy of `value` at the tail.
void push(T value) {
TNode* node = Pool_.Construct(value);
TaggedPtr<TNode> tail, newTail;
while (true) {
// Snapshot the tail node and its successor.
tail = Tail_.load();
TNode* tailNode = tail.GetPtr();
TaggedPtr<TNode> next = tailNode->Next.load();
TNode* nextPtr = next.GetPtr();
// Only act if Tail_ did not move while taking the snapshot.
if (tail == Tail_.load()) {
if (!nextPtr) {
// tailNode really is last: try to link the new node after
// it, bumping the Next tag.
TaggedPtr<TNode> newTailNext(node, next.GetNextTag());
if (tailNode->Next.compare_exchange_strong(next, newTailNext)) {
newTail = TaggedPtr<TNode>(node, tail.GetNextTag());
break;
}
} else {
// Tail lags behind a node linked by another push; help
// advance it, then retry.
newTail = TaggedPtr<TNode>(nextPtr, tail.GetNextTag());
Tail_.compare_exchange_strong(tail, newTail);
}
}
}
// The node is linked. Try to swing Tail_ to it (using `tail` from the
// successful iteration). Failure is fine: another thread already helped
// advance Tail_.
Tail_.compare_exchange_strong(tail, newTail);
}
private:
// Head_: current dummy node. Tail_: last node, or one behind it.
std::atomic<TaggedPtr<TNode>> Head_{nullptr};
std::atomic<TaggedPtr<TNode>> Tail_{nullptr};
// Source and sink of all nodes. Declared last, so it is destroyed first
// among the members.
Pool<TNode> Pool_;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment