Last active
January 13, 2024 01:45
-
-
Save GavinRay97/c4a3e8bf5b20045f367b42a33c9f3bc2 to your computer and use it in GitHub Desktop.
[C++] Atomic Tagged Pointer (uint16 version counter + 3-bit tag/storage)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic> | |
#include <cassert> | |
#include <cstdint> | |
#include <cstdio> | |
#include <optional> | |
// A word-aligned, atomic tagged pointer.
//
// Packs a version counter, an address and a small tag into one 64-bit word so
// all three are read and replaced together by a single atomic operation:
//
//   bit 63      48 47                                   3 2     0
//       +---------+--------------------------------------+-------+
//       | counter |        address (bits 47..3)          |  tag  |
//       +---------+--------------------------------------+-------+
//
//   - counter (16 bits): version number, incremented by every successful
//     compare_exchange(). It wraps from 65535 back to 0, so it narrows but
//     does not eliminate the ABA window.
//   - address (48 bits): assumes the pointer value fits in the low 48 bits
//     (typical user-space addresses with 4-level paging); the upper 16 bits
//     are reused for the counter.
//   - tag (3 bits): up to 8 distinct tag values, stored in the low bits that
//     are zero for 8-byte aligned addresses.
//
// Requirements on stored pointers (checked with assert() in debug builds):
//   - the address must be 8-byte aligned, otherwise its low bits collide with
//     the tag. alignof(T) may be smaller than 8 (e.g. int), so this depends on
//     how the object was allocated, not only on T;
//   - the address must fit in 48 bits.
//
// All atomic operations use the default (sequentially consistent) ordering.
template <typename T>
struct AtomicTaggedPointer
{
    static_assert(sizeof(T*) == 8, "T* must be 8 bytes");
    // NOTE: this checks the pointer type itself, not the pointee's alignment;
    // pointee alignment is checked at runtime in pack().
    static_assert(alignof(T*) == 8, "T* must be 8-byte aligned");

    // Raw layout of the packed word. Relies on bit-fields being allocated from
    // the least significant bit (little-endian GCC/Clang/MSVC) and on
    // anonymous structs inside unions (a widely supported compiler extension).
    union Pointer {
        uintptr_t value;
        struct
        {
            uintptr_t pointer : 48;  // address | tag
            uint16_t counter : 16;   // version
        };
    };
    static_assert(sizeof(Pointer) == 8, "Pointer must be 8 bytes");

private:
    // Low bits of an 8-byte aligned address that hold the tag.
    static constexpr uintptr_t kTagMask = 0x7;
    // Bits of a pointer value that fit in the 48-bit `pointer` field.
    static constexpr uintptr_t kAddressMask = (uintptr_t{1} << 48) - 1;

    // Builds a packed word from an address, a tag (low 3 bits kept) and a
    // version counter. Asserts the address can be stored without loss.
    static Pointer pack(T* pointer, uint8_t tag, uint16_t counter) noexcept
    {
        const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
        assert((address & kTagMask) == 0 && "pointer must be 8-byte aligned to carry a 3-bit tag");
        assert((address & ~kAddressMask) == 0 && "pointer must fit in 48 bits");
        Pointer packed{};  // zero every bit before filling the fields
        packed.pointer = address | (tag & kTagMask);
        packed.counter = counter;
        return packed;
    }

    std::atomic<Pointer> m_pointer;

public:
    // Stores `pointer` with `tag` (only the low 3 bits are used) and version 0.
    AtomicTaggedPointer(T* pointer = nullptr, uint8_t tag = 0)
        : m_pointer(pack(pointer, tag, 0))
    {
    }

    // Returns the stored address. If `tag` is non-null, also writes the tag
    // from the same atomic snapshot, so address and tag are consistent.
    T* get(uint8_t* tag = nullptr) const noexcept
    {
        const Pointer snapshot = m_pointer.load();
        if (tag != nullptr)
        {
            *tag = static_cast<uint8_t>(snapshot.pointer & kTagMask);
        }
        return reinterpret_cast<T*>(static_cast<uintptr_t>(snapshot.pointer & ~kTagMask));
    }

    // Returns the current version counter (number of successful
    // compare_exchange() calls, modulo 2^16).
    uint16_t get_version() const noexcept
    {
        return m_pointer.load().counter;
    }

    // Returns the current tag (0..7). Separate calls to get()/get_tag() read
    // separate snapshots; use get(&tag) for a consistent pair.
    uint8_t get_tag() const noexcept
    {
        return static_cast<uint8_t>(m_pointer.load().pointer & kTagMask);
    }

    // Single attempt to replace the stored address with `desired` and bump the
    // version by one.
    //
    // desired_tag: new tag (only the low 3 bits are used). When omitted, the
    //              current tag is preserved.
    //
    // Returns true if the word was replaced. Returns false if another thread
    // changed the address, tag or version between this call's internal load
    // and its CAS; in that case nothing is written. There is no retry loop.
    [[nodiscard]] bool compare_exchange(T* desired, std::optional<uint8_t> desired_tag = std::nullopt) noexcept
    {
        Pointer expected = m_pointer.load();
        const uint8_t current_tag = static_cast<uint8_t>(expected.pointer & kTagMask);
        const uint8_t tag = desired_tag.value_or(current_tag);
        const Pointer replacement = pack(desired, tag, static_cast<uint16_t>(expected.counter + 1));
        return m_pointer.compare_exchange_strong(expected, replacement);
    }
};
int | |
main() | |
{ | |
AtomicTaggedPointer<int> p; | |
assert(p.get() == nullptr); | |
AtomicTaggedPointer<int> p2 = AtomicTaggedPointer<int>(new int(42)); | |
assert(*p2.get() == 42); | |
uint8_t tag; | |
assert(p2.get(&tag) == p2.get()); | |
assert(tag == 0); | |
int* desired = new int(43); | |
assert(p2.compare_exchange(desired, 1)); | |
assert(*p2.get() == 43); | |
assert(p2.get(&tag) == desired); | |
assert(tag == 1); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment