Last active
March 12, 2024 07:44
-
-
Save jweinst1/3676d75f3b48474641f357bf46315c5d to your computer and use it in GitHub Desktop.
A combined hazard-pointer and reference-count scheme.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <thread>
/// One slot of the hazard list: a guarded pointer together with the number of
/// outstanding references to it.
///
/// A freshly constructed node holds no pointer and a reference count of 1.
/// When the count is decremented to 0 the stored pointer is swapped out for
/// nullptr and handed to the caller of decRef().
struct HazardNode {
    /// Pointer tracked by this slot; nullptr after the last reference is dropped.
    std::atomic<void*> ptr{nullptr};
    /// Number of live references; 0 means the slot has been released.
    std::atomic<unsigned> refcnt{1u};
    /// Next node in the singly linked list. These are never deleted.
    HazardNode* next = nullptr;

    /// Unconditionally adds one reference.
    /// Only safe to use from an existing reference.
    void incRef() {
        refcnt.fetch_add(1);
    }

    /// Adds one reference unless the count has already reached 0.
    /// Meant for lookups.
    /// @return true if the count was incremented, false if it was observed at 0.
    bool incRefChecked() {
        unsigned seen = refcnt.load();
        // On CAS failure `seen` is refreshed with the current count, so the
        // zero test is repeated against up-to-date data on every retry.
        while (seen != 0) {
            if (refcnt.compare_exchange_weak(seen, seen + 1)) {
                return true;
            }
        }
        return false;
    }

    /// Drops one reference.
    /// @return the stored pointer (now swapped out for nullptr) when this call
    ///         took the count from 1 to 0; nullptr otherwise, including when
    ///         the count was already 0.
    void* decRef() {
        unsigned seen = refcnt.load();
        do {
            if (seen == 0) {
                return nullptr;
            }
        } while (!refcnt.compare_exchange_weak(seen, seen - 1));

        // `seen` holds the value that was replaced, so 1 means we hit zero.
        if (seen == 1) {
            return ptr.exchange(nullptr);
        }
        return nullptr;
    }
};
struct HazardList { | |
std::atomic<HazardNode*> hlist = nullptr; | |
HazardNode* findEmptySlot(void* ptr, size_t times) { | |
for (size_t i = 0; i < times; ++i) { | |
HazardNode* iter = hlist.load(); | |
while (iter != nullptr) { | |
void* slotcheck = iter->ptr.load(); | |
if (slotcheck != nullptr) { | |
continue; | |
} | |
// placement logic. | |
if (iter->ptr.compare_exchange_strong(slotcheck, ptr)) { | |
// take away at later point | |
unsigned exres = iter->refcnt.exchange(1); | |
if (exres != 0) { | |
printf("got %u in exres\n", exres); | |
assert(0); | |
} | |
return iter; | |
} | |
iter = iter->next; | |
} | |
} | |
return nullptr; | |
} | |
HazardNode* addPointer(void* ptr, bool tryToReuse = true) { | |
HazardNode* tryToUseExisting = tryToReuse ? findEmptySlot(ptr, 2) : nullptr; | |
if (tryToUseExisting != nullptr) { | |
return tryToUseExisting; | |
} | |
HazardNode* newnode = new HazardNode(); | |
newnode->ptr.store(ptr); | |
HazardNode* got = hlist.load(); | |
newnode->next = got; | |
while(!hlist.compare_exchange_weak(got, newnode)) { | |
newnode->next = got; | |
} | |
return newnode; | |
} | |
unsigned getRefCount(void* ptr) { | |
HazardNode* iter = hlist.load(); | |
while (iter != nullptr) { | |
if (iter->ptr.load() == ptr) | |
return iter->refcnt.load(); | |
iter = iter->next; | |
} | |
return 0; | |
} | |
HazardNode* createPointerRef(void* ptr) { | |
HazardNode* iter = hlist.load(); | |
while (iter != nullptr) { | |
void* slotcheck = iter->ptr.load(); | |
if (slotcheck == ptr) { | |
if (iter->incRefChecked()) { | |
return iter; | |
} else { | |
printf("Failed to increment!!!\n"); | |
return nullptr; | |
} | |
} | |
iter = iter->next; | |
} | |
return nullptr; | |
} | |
}; | |
#define TPOOL_SIZE 8 | |
static void thread_test() { | |
using namespace std::chrono_literals; | |
std::thread tpool[TPOOL_SIZE]; | |
std::atomic<void*> ptrpool[TPOOL_SIZE]; | |
HazardList hl; | |
std::atomic<bool> keepGoing = true; | |
for (int i = 0; i < TPOOL_SIZE; ++i) | |
{ | |
tpool[i] = std::thread([&, i]{ | |
unsigned total = 0; | |
printf("Using ptr slot %d\n", i); | |
HazardNode* myPtr = nullptr; | |
while (keepGoing.load()) { | |
void* loader = ptrpool[i].load(); | |
if (loader == nullptr) { | |
int* put_int = new int(i); | |
myPtr = hl.addPointer(put_int); | |
if(!ptrpool[i].compare_exchange_strong(loader, put_int)) { | |
myPtr->decRef(); | |
delete put_int; | |
} | |
} | |
void* hptr = ptrpool[i].load(); | |
myPtr = myPtr == nullptr ? hl.createPointerRef(hptr) : myPtr; | |
if (myPtr == nullptr) { | |
myPtr = hl.addPointer(hptr); | |
} | |
void * result = myPtr->decRef(); | |
if (result != nullptr) { | |
//printf("thread %d, got pointer %p\n for free\n", i, result); | |
void* seen2 = ptrpool[i].load(); | |
if (seen2 == result) { | |
void* ptrres = ptrpool[i].exchange(nullptr); | |
if (ptrres != result) { | |
printf("ERR, thread %d got %p when ptr swap\n", i, ptrres); | |
} | |
delete (int*)result; | |
} | |
} | |
for (int k = 0; k < TPOOL_SIZE; ++k) | |
{ | |
void* kptr = ptrpool[k].load(); | |
if (kptr != nullptr) { | |
auto nPtr = hl.createPointerRef(kptr); | |
void* result2 = nullptr; | |
if (nPtr != nullptr) | |
result2 = nPtr->decRef(); | |
if (result2 != nullptr) { | |
//printf("thread %d, got pointer %p\n for free\n", i, result2); | |
void* seen2 = ptrpool[k].load(); | |
if (seen2 == result2) { | |
void* ptrres = ptrpool[k].exchange(nullptr); | |
if (ptrres != result2) { | |
printf("ERR, thread %d got %p when ptr swap\n", i, ptrres); | |
} | |
delete (int*)result2; | |
} | |
} | |
} | |
++total; | |
} | |
} | |
printf("Got total ops %u\n", total); | |
}); | |
} | |
std::this_thread::sleep_for(1000ms); | |
keepGoing.store(false); | |
printf("now joining\n"); | |
for (int j = 0; j < TPOOL_SIZE; ++j) | |
{ | |
tpool[j].join(); | |
} | |
} | |
int main(int argc, char const *argv[]) | |
{ | |
thread_test(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment