Skip to content

Instantly share code, notes, and snippets.

@jweinst1
Last active March 12, 2024 07:44
Show Gist options
  • Save jweinst1/3676d75f3b48474641f357bf46315c5d to your computer and use it in GitHub Desktop.
Save jweinst1/3676d75f3b48474641f357bf46315c5d to your computer and use it in GitHub Desktop.
A hazard pointer and reference count combination.
#include <atomic>
#include <thread>
#include <chrono>
#include <cstdint>
#include <limits>
#include <cassert>
#include <cstdio>
// One slot of the hazard list: a guarded pointer plus a reference count.
// When a decRef() takes the count from 1 to 0, the slot's pointer is detached
// (swapped to nullptr) and handed to that caller for reclamation; the slot
// can then be reused for another pointer.
struct HazardNode {
    // Pointer currently guarded by this slot; nullptr when the slot is free.
    std::atomic<void*> ptr = nullptr;
    // Number of live references; 0 means the slot has been released.
    std::atomic<unsigned> refcnt = 1;
    // Next slot in the intrusive singly linked list. Nodes are never deleted
    // while the list is in use, so a traversal never touches freed nodes.
    struct HazardNode* next = nullptr;

    // Unconditionally adds a reference.
    // Only safe to use from an existing reference (count known to be > 0).
    void incRef() {
        refcnt.fetch_add(1);
    }

    // Adds a reference unless the count is already 0 (slot released).
    // Meant for lookups by threads that do not yet hold a reference.
    // Returns true if a reference was acquired.
    bool incRefChecked() {
        unsigned seen = refcnt.load();
        do {
            if (seen == 0)
                return false;
        } while (!refcnt.compare_exchange_weak(seen, seen + 1));
        return true;
    }

    // Drops one reference. If this call released the last one, detaches and
    // returns the guarded pointer so the caller can reclaim it. Otherwise,
    // including when the count was already 0, returns nullptr.
    void* decRef() {
        unsigned seen = refcnt.load();
        do {
            if (seen == 0)
                return nullptr;
        } while (!refcnt.compare_exchange_weak(seen, seen - 1));
        // On success `seen` still holds the replaced value: 1 means we hit 0.
        return seen == 1 ? ptr.exchange(nullptr) : nullptr;
    }
};
struct HazardList {
std::atomic<HazardNode*> hlist = nullptr;
HazardNode* findEmptySlot(void* ptr, size_t times) {
for (size_t i = 0; i < times; ++i) {
HazardNode* iter = hlist.load();
while (iter != nullptr) {
void* slotcheck = iter->ptr.load();
if (slotcheck != nullptr) {
continue;
}
// placement logic.
if (iter->ptr.compare_exchange_strong(slotcheck, ptr)) {
// take away at later point
unsigned exres = iter->refcnt.exchange(1);
if (exres != 0) {
printf("got %u in exres\n", exres);
assert(0);
}
return iter;
}
iter = iter->next;
}
}
return nullptr;
}
HazardNode* addPointer(void* ptr, bool tryToReuse = true) {
HazardNode* tryToUseExisting = tryToReuse ? findEmptySlot(ptr, 2) : nullptr;
if (tryToUseExisting != nullptr) {
return tryToUseExisting;
}
HazardNode* newnode = new HazardNode();
newnode->ptr.store(ptr);
HazardNode* got = hlist.load();
newnode->next = got;
while(!hlist.compare_exchange_weak(got, newnode)) {
newnode->next = got;
}
return newnode;
}
unsigned getRefCount(void* ptr) {
HazardNode* iter = hlist.load();
while (iter != nullptr) {
if (iter->ptr.load() == ptr)
return iter->refcnt.load();
iter = iter->next;
}
return 0;
}
HazardNode* createPointerRef(void* ptr) {
HazardNode* iter = hlist.load();
while (iter != nullptr) {
void* slotcheck = iter->ptr.load();
if (slotcheck == ptr) {
if (iter->incRefChecked()) {
return iter;
} else {
printf("Failed to increment!!!\n");
return nullptr;
}
}
iter = iter->next;
}
return nullptr;
}
};
#define TPOOL_SIZE 8
static void thread_test() {
using namespace std::chrono_literals;
std::thread tpool[TPOOL_SIZE];
std::atomic<void*> ptrpool[TPOOL_SIZE];
HazardList hl;
std::atomic<bool> keepGoing = true;
for (int i = 0; i < TPOOL_SIZE; ++i)
{
tpool[i] = std::thread([&, i]{
unsigned total = 0;
printf("Using ptr slot %d\n", i);
HazardNode* myPtr = nullptr;
while (keepGoing.load()) {
void* loader = ptrpool[i].load();
if (loader == nullptr) {
int* put_int = new int(i);
myPtr = hl.addPointer(put_int);
if(!ptrpool[i].compare_exchange_strong(loader, put_int)) {
myPtr->decRef();
delete put_int;
}
}
void* hptr = ptrpool[i].load();
myPtr = myPtr == nullptr ? hl.createPointerRef(hptr) : myPtr;
if (myPtr == nullptr) {
myPtr = hl.addPointer(hptr);
}
void * result = myPtr->decRef();
if (result != nullptr) {
//printf("thread %d, got pointer %p\n for free\n", i, result);
void* seen2 = ptrpool[i].load();
if (seen2 == result) {
void* ptrres = ptrpool[i].exchange(nullptr);
if (ptrres != result) {
printf("ERR, thread %d got %p when ptr swap\n", i, ptrres);
}
delete (int*)result;
}
}
for (int k = 0; k < TPOOL_SIZE; ++k)
{
void* kptr = ptrpool[k].load();
if (kptr != nullptr) {
auto nPtr = hl.createPointerRef(kptr);
void* result2 = nullptr;
if (nPtr != nullptr)
result2 = nPtr->decRef();
if (result2 != nullptr) {
//printf("thread %d, got pointer %p\n for free\n", i, result2);
void* seen2 = ptrpool[k].load();
if (seen2 == result2) {
void* ptrres = ptrpool[k].exchange(nullptr);
if (ptrres != result2) {
printf("ERR, thread %d got %p when ptr swap\n", i, ptrres);
}
delete (int*)result2;
}
}
}
++total;
}
}
printf("Got total ops %u\n", total);
});
}
std::this_thread::sleep_for(1000ms);
keepGoing.store(false);
printf("now joining\n");
for (int j = 0; j < TPOOL_SIZE; ++j)
{
tpool[j].join();
}
}
// Entry point: runs the multi-threaded hazard-list stress test once.
// Command-line arguments are ignored.
int main(int /*argc*/, char const* /*argv*/[]) {
    thread_test();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment