This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <cinttypes>
#include <cstdint>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
// Started at 00:20.
// Tagged pointers (a neat trick): https://en.wikipedia.org/wiki/Tagged_pointer
// A pointer and a 16-bit tag packed into a single 64-bit word.
//
// Because the pair fits in one word, it can be stored in a std::atomic and
// read, written and compare_exchange'd as a unit. Lock-free code bumps the
// tag whenever it updates a shared slot. A slot that went A -> B -> A then
// still compares unequal, which protects against the ABA problem.
// See https://en.wikipedia.org/wiki/Tagged_pointer.
//
// Layout: the pointer is in bits [0, 48) and the tag is in bits [48, 64).
// The top 16 bits of a stored pointer are discarded, so the class assumes
// addresses fit in 48 bits.
// NOTE(review): this holds for user-space addresses under 48-bit virtual
// addressing (e.g. x86-64 4-level paging). It does not hold under 57-bit
// (5-level) addressing or hardware pointer tagging.
template <class T>
class TaggedPtr {
    static_assert(sizeof(T*) == sizeof(uint64_t), "TaggedPtr requires 64-bit pointers");

    static constexpr unsigned TagShift = 48;
    static constexpr uint64_t PtrMask = (uint64_t(1) << TagShift) - 1;

    // Recovers the pointer part of a packed word.
    static T* ExtractPtr(uint64_t word) {
        return reinterpret_cast<T*>(static_cast<std::uintptr_t>(word & PtrMask));
    }

    // Packs pointer and tag with shifts and masks rather than a union.
    // Reading an inactive union member is undefined behaviour in C++.
    // Indexing 16-bit halves also put the tag in the top bits only on
    // little-endian machines.
    static uint64_t PackPtr(T* p, uint16_t tag) {
        const uint64_t addr = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(p));
        return (addr & PtrMask) | (static_cast<uint64_t>(tag) << TagShift);
    }

public:
    // A default-constructed value is a null pointer with tag 0.
    TaggedPtr() = default;

    // Intentionally not explicit: a raw pointer (or nullptr) converts implicitly.
    TaggedPtr(T* p, uint16_t tag = 0)
        : ptr(PackPtr(p, tag))
    {}

    TaggedPtr& operator=(const TaggedPtr&) = default;

    T* GetPtr() const {
        return ExtractPtr(ptr);
    }

    uint16_t GetTag() const {
        return static_cast<uint16_t>(ptr >> TagShift);
    }

    // Replaces the pointer and keeps the current tag.
    void SetPtr(T* p) {
        ptr = PackPtr(p, GetTag());
    }

    // The current tag plus one, wrapping from 0xFFFF to 0.
    uint16_t GetNextTag() const {
        return static_cast<uint16_t>(GetTag() + 1u);
    }

    // Equality compares the whole word: pointer AND tag.
    bool operator==(const TaggedPtr& p) const {
        return ptr == p.ptr;
    }
    bool operator!=(const TaggedPtr& p) const {
        return ptr != p.ptr;
    }

    T* operator->() const {
        return GetPtr();
    }

    // True when the pointer part is non-null. The tag is ignored.
    operator bool() const {
        return GetPtr() != nullptr;
    }

private:
    // A plain word, zero-initialised so that default construction yields null/0.
    // It is deliberately not `volatile`: volatile gives no inter-thread
    // synchronisation. Cross-thread access goes through
    // std::atomic<TaggedPtr>, which needs a trivially copyable T.
    uint64_t ptr = 0;
};
template <class T> | |
class Pool : std::allocator<T> { | |
using Alloc = std::allocator<T>; | |
struct Node { | |
TaggedPtr<Node> Next; | |
}; | |
public: | |
void deallocate(T* n) { | |
void* node = n; | |
TaggedPtr<Node> old = Pool_.load(); | |
Node* newPoolPtr = reinterpret_cast<Node*>(node); | |
while (true) { | |
TaggedPtr<Node> newPool(newPoolPtr, old.GetTag()); | |
newPool->Next.SetPtr(old.GetPtr()); | |
if (Pool_.compare_exchange_strong(old, newPool)) { | |
return; | |
} | |
} | |
} | |
Pool() : Pool_{TaggedPtr<Node>(nullptr)} {} | |
~Pool() { | |
TaggedPtr<Node> head = Pool_.load(); | |
while (head) { | |
Node* ptr = head.GetPtr(); | |
if (ptr) { | |
head = ptr->Next; | |
} | |
Alloc::deallocate((T*)ptr, 1); | |
} | |
} | |
T* allocate() { | |
TaggedPtr<Node> old = Pool_.load(); | |
while (true) { | |
if (!old.GetPtr()) { | |
return Alloc::allocate(1); | |
} | |
Node* newPoolPtr = old->Next.GetPtr(); | |
TaggedPtr<Node> newPool(newPoolPtr, old.GetNextTag()); | |
if (Pool_.compare_exchange_weak(old, newPool)) { | |
return reinterpret_cast<T*>(old.GetPtr()); | |
} | |
} | |
} | |
T* Construct() { | |
T* node = allocate(); | |
if (node) { | |
new(node) T(); | |
} | |
return node; | |
} | |
template <class U> | |
T* Construct(const U& u) { | |
T* node = allocate(); | |
if (node) { | |
new(node) T(u); | |
} | |
return node; | |
} | |
private: | |
std::atomic<TaggedPtr<Node>> Pool_; | |
}; | |
template <class T> | |
class TMichaelScottQueue { | |
static_assert(std::is_trivially_destructible<T>::value, "T must be trivially destructible"); | |
static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable"); | |
struct TNode { | |
TNode() = default; | |
TNode(const T& v) : Data(v) { | |
TaggedPtr<TNode> old = Next.load(); | |
TaggedPtr<TNode> newNext(static_cast<TNode*>(nullptr), old.GetNextTag()); | |
Next.store(newNext); | |
} | |
std::atomic<TaggedPtr<TNode>> Next{nullptr}; | |
T Data; | |
}; | |
TMichaelScottQueue(const TMichaelScottQueue&) = delete; | |
TMichaelScottQueue& operator=(const TMichaelScottQueue& queue) = delete; | |
public: | |
TMichaelScottQueue() : Pool_() { | |
TNode* node = Pool_.Construct(); | |
TaggedPtr<TNode> dummy(node, 0); | |
Head_.store(dummy); | |
Tail_.store(dummy); | |
} | |
~TMichaelScottQueue() { | |
while (pop().first) {} | |
Pool_.deallocate(Head_.load().GetPtr()); | |
} | |
std::pair<bool, T> pop() { | |
while (true) { | |
TaggedPtr<TNode> head = Head_.load(); | |
TNode* headPtr = head.GetPtr(); | |
TaggedPtr<TNode> tail = Tail_.load(); | |
TaggedPtr<TNode> next = headPtr->Next.load(); | |
TNode* nextPtr = next.GetPtr(); | |
if (head == Head_.load()) { | |
if (head.GetPtr() == tail.GetPtr()) { | |
if (!nextPtr) { | |
return {false, T()}; | |
} | |
TaggedPtr<TNode> newTail(next.GetPtr(), tail.GetNextTag()); | |
Tail_.compare_exchange_strong(tail, newTail); | |
} else { | |
std::pair<bool, T> ret = {true, nextPtr->Data}; | |
TaggedPtr<TNode> newHead(next.GetPtr(), head.GetNextTag()); | |
if (Head_.compare_exchange_weak(head, newHead)) { | |
Pool_.deallocate(head.GetPtr()); | |
return ret; | |
} | |
} | |
} | |
} | |
} | |
void push(T value) { | |
TNode* node = Pool_.Construct(value); | |
TaggedPtr<TNode> tail, newTail; | |
while (true) { | |
tail = Tail_.load(); | |
TNode* tailNode = tail.GetPtr(); | |
TaggedPtr<TNode> next = tailNode->Next.load(); | |
TNode* nextPtr = next.GetPtr(); | |
if (tail == Tail_.load()) { | |
if (!nextPtr) { | |
TaggedPtr<TNode> newTailNext(node, next.GetNextTag()); | |
if (tailNode->Next.compare_exchange_strong(next, newTailNext)) { | |
newTail = TaggedPtr<TNode>(node, tail.GetNextTag()); | |
break; | |
} | |
} else { | |
newTail = TaggedPtr<TNode>(nextPtr, tail.GetNextTag()); | |
Tail_.compare_exchange_strong(tail, newTail); | |
} | |
} | |
} | |
Tail_.compare_exchange_strong(tail, newTail); | |
} | |
private: | |
std::atomic<TaggedPtr<TNode>> Head_{nullptr}; | |
std::atomic<TaggedPtr<TNode>> Tail_{nullptr}; | |
Pool<TNode> Pool_; | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment