Skip to content

Instantly share code, notes, and snippets.

@GavinRay97
Last active January 13, 2024 01:45
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GavinRay97/c4a3e8bf5b20045f367b42a33c9f3bc2 to your computer and use it in GitHub Desktop.
Save GavinRay97/c4a3e8bf5b20045f367b42a33c9f3bc2 to your computer and use it in GitHub Desktop.
[C++] Atomic Tagged Pointer (uint16 version counter + 3-bit tag/storage)
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <optional>
// An atomic tagged pointer for 64-bit targets.
//
// The 64-bit word is split into three fields (bit 63 on the left):
//
//   63          48 47                                      3 2     0
//   +------------+-----------------------------------------+-------+
//   | counter 16 |            address bits                 | tag 3 |
//   +------------+-----------------------------------------+-------+
//                |<------- `pointer` bit-field (48 bits) -------->|
//
// - counter: a 16-bit version number, incremented by every successful
//   compare_exchange(). It wraps around at 65536, so it narrows -- but does
//   not eliminate -- the ABA window.
// - pointer: the address plus the tag, stored in the low 48 bits. Only
//   addresses that fit in 48 bits can be stored (checked with assert()).
// - tag: the low 3 bits of the address, giving 8 distinct tags. They are only
//   free when the pointee address is 8-byte aligned (checked with assert()).
//
// Every access goes through std::atomic with the default (sequentially
// consistent) memory order.
//
// NOTE: bit-field layout is implementation-defined. This relies on the
// compiler packing `pointer : 48` and `counter : 16` into one 8-byte unit
// (GCC/Clang do) and on anonymous structs inside unions (a compiler
// extension in C++). Compilers that start a new allocation unit for a
// bit-field of a different type will trip the sizeof static_assert below.
template <typename T>
struct AtomicTaggedPointer
{
    static_assert(sizeof(T*) == 8, "T* must be 8 bytes");
    // This checks the alignment of the pointer object itself, NOT of the
    // pointee; whether the low 3 address bits are free is checked at runtime
    // when a pointer is packed.
    static_assert(alignof(T*) == 8, "T* must be 8-byte aligned");

    union Pointer {
        uintptr_t value;
        // Little endian: on GCC/Clang `pointer` occupies bits 0-47 of `value`
        // and `counter` occupies bits 48-63.
        struct
        {
            uintptr_t pointer : 48;
            uint16_t counter : 16;
        };
    };
    static_assert(sizeof(Pointer) == 8, "Pointer must be 8 bytes");

private:
    // Low address bits reserved for the tag.
    static constexpr uintptr_t kTagMask = 0x7;
    // Bits that fit in the 48-bit `pointer` field.
    static constexpr uintptr_t kAddressMask = (uintptr_t{1} << 48) - 1;

    // Packs an address, a tag (low 3 bits kept) and a version counter into
    // one Pointer word. Built field by field so no designated initializers
    // (a C++20 feature) are needed.
    static Pointer make_pointer(T* pointer, uint8_t tag, uint16_t counter)
    {
        const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
        // The tag is OR-ed into the low bits; they must be clear or the
        // stored address would be corrupted.
        assert((address & kTagMask) == 0 && "pointer must be 8-byte aligned");
        // The upper 16 bits are replaced by the counter; losing them would
        // silently change the address.
        assert((address & ~kAddressMask) == 0 && "pointer must fit in 48 bits");

        Pointer packed{};  // zero all 64 bits before filling the fields
        packed.pointer = address | (tag & kTagMask);
        packed.counter = counter;
        return packed;
    }

    std::atomic<Pointer> m_pointer;

public:
    // Creates a tagged pointer holding `pointer` and `tag` (only the low
    // 3 bits of `tag` are kept), with version 0.
    AtomicTaggedPointer(T* pointer = nullptr, uint8_t tag = 0)
        : m_pointer(make_pointer(pointer, tag, 0))
    {
    }

    // Returns the stored pointer with the tag bits cleared. If `tag` is
    // non-null, the current tag is also written to *tag. Pointer and tag come
    // from the same atomic load, so they are mutually consistent.
    T* get(uint8_t* tag = nullptr) const
    {
        const Pointer snapshot = m_pointer.load();
        const uintptr_t bits = snapshot.pointer;
        if (tag != nullptr)
        {
            *tag = static_cast<uint8_t>(bits & kTagMask);
        }
        return reinterpret_cast<T*>(bits & ~kTagMask);
    }

    // Returns the current version counter (separate atomic load).
    uint16_t get_version() const
    {
        return m_pointer.load().counter;
    }

    // Returns the current 3-bit tag (separate atomic load).
    uint8_t get_tag() const
    {
        return static_cast<uint8_t>(m_pointer.load().pointer & kTagMask);
    }

    // Replaces the stored pointer with `desired` and increments the version.
    //
    // The caller does not supply an expected value: the word is loaded and
    // then CAS'd against that snapshot. The call therefore fails only if
    // another thread modified the word between the load and the CAS.
    //
    // desired_tag: the new tag (only the low 3 bits are kept). When empty,
    //              the tag currently stored is preserved.
    // Returns true if the replacement happened.
    bool compare_exchange(T* desired, std::optional<uint8_t> desired_tag = std::nullopt)
    {
        Pointer expected = m_pointer.load();
        const uint8_t current_tag = static_cast<uint8_t>(expected.pointer & kTagMask);
        const Pointer next = make_pointer(
            desired,
            desired_tag.value_or(current_tag),
            // Wraps to 0 after 65535 by design.
            static_cast<uint16_t>(expected.counter + 1));
        return m_pointer.compare_exchange_strong(expected, next);
    }
};
// Smoke test for AtomicTaggedPointer<int>.
//
// Calls with side effects (compare_exchange, get(&tag)) are kept outside
// assert(), so they still run when the file is compiled with NDEBUG. The heap
// objects are freed before returning.
int
main()
{
    // A default-constructed tagged pointer is null.
    AtomicTaggedPointer<int> p;
    assert(p.get() == nullptr);

    int* const first = new int(42);
    AtomicTaggedPointer<int> p2(first);
    assert(*p2.get() == 42);

    uint8_t tag = 0xFF;  // sentinel, so the check below proves get() wrote it
    [[maybe_unused]] int* const loaded = p2.get(&tag);
    assert(loaded == p2.get());
    assert(tag == 0);

    int* const desired = new int(43);
    [[maybe_unused]] const bool swapped = p2.compare_exchange(desired, 1);
    assert(swapped);
    assert(*p2.get() == 43);
    assert(p2.get_version() == 1);

    [[maybe_unused]] int* const reloaded = p2.get(&tag);
    assert(reloaded == desired);
    assert(tag == 1);

    // AtomicTaggedPointer has no destructor and does not own its pointee,
    // so release both allocations explicitly.
    delete first;
    delete desired;
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment