Skip to content

Instantly share code, notes, and snippets.

@jserv
Created August 20, 2021 12:05
Show Gist options
  • Save jserv/c3823ea893e08607b432827a11ec4b69 to your computer and use it in GitHub Desktop.
Save jserv/c3823ea893e08607b432827a11ec4b69 to your computer and use it in GitHub Desktop.
lock-free hashmap
# Select the printf command by host OS (as reported by `uname -s`): plain
# `printf` on Darwin, `env printf` everywhere else.
# NOTE(review): `env printf` presumably forces the external binary rather than
# a shell builtin -- confirm.
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Darwin)
PRINTF = printf
else
PRINTF = env printf
endif
# Control the build verbosity
# VERBOSE=1: Q is empty, so commands prefixed with $(Q) are echoed by make, and
#            VECHO expands to `@true`, which discards the progress message.
# otherwise: Q is `@`, hiding the commands, and VECHO prints the progress
#            message silently through $(PRINTF).
ifeq ("$(VERBOSE)","1")
Q :=
VECHO = @true
else
Q := @
VECHO = @$(PRINTF)
endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
/* One link of a singly linked list that carries a caller-supplied pointer. */
typedef struct list_node {
/* following link; list_add() stores NULL here for the first node added */
struct list_node *next;
/* payload pointer handed to list_add(); the list never dereferences it */
void *val;
} list_node_t;
/* List handle: newest node first, plus a running element count. */
typedef struct {
/* most recently added node; list_new() initialises it to the `empty` sentinel */
list_node_t *head;
/* number of nodes; list_add() bumps it atomically after a successful insert */
uint32_t length;
} list_t;
/* Counters of failed head compare-and-swap attempts in list_add(): the first
 * for the empty-list path, the second for the populated-list path.
 */
static volatile uint32_t list_retries_empty = 0, list_retries_populated = 0;
/* Sentinel that list_new() stores in `head` and that list_add() compares the
 * head against to detect an empty list; initialised to NULL.
 */
static const list_node_t *empty = NULL;
/* Allocate an empty list.
 *
 * @return a heap-allocated list whose head is the `empty` sentinel and whose
 *         length is 0, or NULL if the allocation fails. The caller owns the
 *         returned object.
 */
static list_t *list_new(void)
{
    /* size the request from the object itself so it always matches list_t */
    list_t *l = calloc(1, sizeof *l);
    if (!l)
        return NULL;

    l->head = (list_node_t *) empty;
    l->length = 0;
    return l;
}
/* Push `val` onto the front of list `l` without taking a lock.
 *
 * A node is allocated for the value and installed as the new head with a
 * compare-and-swap; the attempt is repeated until it succeeds. The list length
 * is incremented atomically after the node has been linked.
 *
 * @param l   list to extend (must not be NULL)
 * @param val pointer to store; not interpreted by the list
 *
 * If the node cannot be allocated the value is not recorded and the function
 * returns without touching the list.
 */
static void list_add(list_t *l, void *val)
{
    /* wrap the value as a node in the linked list */
    list_node_t *v = calloc(1, sizeof *v);
    if (!v)
        return;
    v->val = val;

    /* try adding to the front of the list */
    while (true) {
        list_node_t *n = __atomic_load_n(&l->head, __ATOMIC_SEQ_CST);
        bool was_empty = (n == empty);

        /* The compare-exchange builtin writes the value it observed back into
         * its "expected" operand on failure. Give it a private copy so the
         * shared `empty` sentinel is never modified.
         */
        list_node_t *expected = n;

        /* first link terminates the chain; later links point at the old head */
        v->next = was_empty ? NULL : n;

        if (__atomic_compare_exchange_n(&l->head, &expected, v, false,
                                        __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)) {
            __atomic_fetch_add(&l->length, 1, __ATOMIC_SEQ_CST);
            return;
        }

        /* another thread changed the head first: count the retry and loop */
        if (was_empty)
            __atomic_fetch_add(&list_retries_empty, 1, __ATOMIC_RELAXED);
        else
            __atomic_fetch_add(&list_retries_populated, 1, __ATOMIC_RELAXED);
    }
}
#include "free_later.h"
/* CAS(a, b, c): sequentially consistent compare-and-swap.
 *
 * Atomically stores `c` into `*a` if `*a` equals `b`. The expression evaluates
 * to the value that was observed in `*a`: equal to `b` when the swap happened,
 * the conflicting current value otherwise.
 *
 * Arguments are parenthesised so that expressions such as `p + 1` or
 * `x ? p : q` can be passed, and `__typeof__` is used so the macro also works
 * when the compiler is in a strict ISO mode.
 */
#define CAS(a, b, c)                                                        \
    __extension__({                                                         \
        __typeof__(*(a)) cas_expected_ = (b), cas_desired_ = (c);           \
        __atomic_compare_exchange((a), &cas_expected_, &cas_desired_, 0,    \
                                  __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);      \
        cas_expected_;                                                      \
    })
/* Spin until the flag at `lock` has been flipped from false to true by this
 * caller. CAS() yields the value it observed, so a result of false means the
 * swap took place and the lock is now held.
 */
static inline void acquire_lock(volatile bool *lock)
{
    for (;;) {
        if (!CAS(lock, false, true))
            return;
    }
}
/* Release a spin lock taken with acquire_lock().
 *
 * The flag itself is cleared with a sequentially consistent atomic store.
 * (Operating on a local copy of the flag would leave the real lock set, so
 * the store must target `*lock` directly.)
 *
 * @param lock flag to clear; must not be NULL
 */
static inline void release_lock(volatile bool *lock)
{
    __atomic_store_n(lock, false, __ATOMIC_SEQ_CST);
}
/* One deferred-release request: a pointer and the function to call on it. */
typedef struct {
/* pointer that was handed to free_later() */
void *var;
/* release callback that free_later_run() invokes with `var` */
void (*free)(void *var);
} free_later_t;
/* track expired variables to cleanup later
 *
 * `buffer` is the list that free_later() appends to; `buffer_prev` is the
 * staged list that free_later_run() walks and releases.
 */
static list_t *buffer = NULL, *buffer_prev = NULL;
/* Create the list that collects free_later() registrations.
 *
 * Must be called before free_later().
 *
 * @return 0 on success, -1 if the list could not be created.
 */
int free_later_init(void)
{
    buffer = list_new();
    return buffer ? 0 : -1;
}
/* register a var for cleanup
 *
 * Records `var` together with `release` in the current buffer so that
 * free_later_run() can call `release(var)` once the buffer has been staged.
 *
 * @param var     pointer to release later
 * @param release function to call on `var`
 *
 * If the bookkeeping record cannot be allocated the request is dropped, which
 * means `var` is never released.
 */
void free_later(void *var, void release(void *var))
{
    free_later_t *cv = malloc(sizeof *cv);
    if (!cv)
        return;

    cv->var = var;
    cv->free = release;
    list_add(buffer, cv);
}
/* signal that worker threads are done with old references */
void free_later_stage(void)
{
/* lock to ensure that multiple threads do not clean up simultaneously */
static bool lock = false;
/* CAS-based lock in case multiple threads are calling this */
acquire_lock(&lock);
if (!buffer_prev || buffer_prev->length == 0) {
release_lock(&lock);
return;
}
/* swap the buffers */
buffer_prev = buffer;
buffer = list_new();
release_lock(&lock);
}
void free_later_run()
{
/* lock to ensure that multiple threads do not clean up simultaneously */
static bool lock = false;
/* skip if there is nothing to return */
if (!buffer_prev)
return;
/* CAS-based lock in case multiple threads are calling this */
acquire_lock(&lock);
/* At this point, all workers have processed one or more new flow since the
* free_later buffer was filled. No threads are using the old, deleted data.
*/
for (list_node_t *n = buffer_prev->head; n; n = n->next) {
free_later_t *v = n->val;
v->free(v->var);
free(n);
}
free(buffer_prev);
buffer_prev = NULL;
release_lock(&lock);
}
/* Tear down the deferred-free machinery.
 *
 * Runs one release pass for anything already staged, stages and releases once
 * more for anything registered since, then frees the current buffer.
 *
 * @return always 0
 */
int free_later_exit(void)
{
    /* purge anything that is buffered */
    free_later_run();

    /* stage and purge anything that was unbuffered */
    free_later_stage();
    free_later_run();

    /* release memory for the buffer */
    free(buffer);
    buffer = NULL;
    return 0;
}
/* Memory cleanup for lock-free deletes
*
* Several data structures such as `hashmap_del` will remove data; however, it
* may not be possible to free data until later. For example, if many threads
* are using the same hashmap, more than one may be using a reference when
* `hashmap_del` is called. The solution here is to have `hashmap_del` register
* data that can be deleted later and let the application notify when worker
* threads are done with old references.
*
* `free_later(void *var, void release(void *))` will register a pointer to have
* the `release` method called on it later, when it is safe to free memory.
*
* `free_later_init()` must be called before using `free_later`, and
* `free_later_exit()` should be called before application termination. It'll
* ensure all registered vars have their `release()` callback invoked.
*
* `free_later_stage()` should be called before a round of work starts. It'll
* stage all buffered values to a list that can't be updated, and make a new
* list to register any new `free_later()` invocations. After all worker threads
* have progressed with work, call `free_later_run()` to have every value in the
* staged buffer released.
*/
#ifndef _FREE_LATER_H_
#define _FREE_LATER_H_
/* _init() must be called before use and _exit() once at the end */
int free_later_init(void);
int free_later_exit(void);
/* add a var to the cleanup later list */
void free_later(void *var, void release(void *var));
/* stage the vars registered so far; call before a round of work starts */
void free_later_stage(void);
/* release every staged var; call once all workers have progressed */
void free_later_run(void);
#endif
#include "hashmap.h"
#include "free_later.h"
/* TODO: make these variables conditionally built for benchmarking */
/* used for testing CAS-retries in tests
 *
 * Each counter is bumped with a plain `+= 1` when the corresponding
 * compare-and-swap in hashmap_put() / hashmap_del() fails. They have external
 * linkage so another translation unit can read or reset them.
 */
volatile uint32_t hashmap_put_retries = 0, hashmap_put_replace_fail = 0;
volatile uint32_t hashmap_put_head_fail = 0;
volatile uint32_t hashmap_del_fail = 0, hashmap_del_fail_new_head = 0;
static hashmap_kv_t *create_node_with_malloc(void *opaque,
const void *key,
void *value)
{
hashmap_kv_t *next = malloc(sizeof *(next));
next->key = key;
next->value = value;
return next;
}
/* Default node destructor: defer releasing a node that was unlinked.
 *
 * The key and the node are registered for free(). The value is registered
 * with `opaque` as its release callback, but only when `opaque` is non-NULL:
 * registering a NULL callback would make the later release pass call through
 * a null function pointer.
 *
 * @param opaque release function for the value, or NULL to leave the value
 *               alone
 * @param node   node that has been removed from its bucket
 */
static void destroy_node_later(void *opaque, hashmap_kv_t *node)
{
    /* free all of these later in case other threads are using them */
    free_later((void *) node->key, free);
    if (opaque)
        free_later(node->value, (void (*)(void *)) opaque);
    free_later(node, free);
}
/* Create a hashmap with a fixed number of buckets.
 *
 * @param n_buckets number of buckets; must be non-zero because bucket indices
 *                  are computed as `hash % n_buckets`
 * @param cmp       key comparison; a result of 0 means "equal"
 * @param hash      key hash function
 * @return the new map, or NULL if an argument is invalid or memory is
 *         exhausted
 *
 * The map starts with the malloc-based node constructor, the deferred-free
 * node destructor and a NULL `opaque` pointer.
 */
void *hashmap_new(uint32_t n_buckets,
                  uint8_t cmp(const void *x, const void *y),
                  uint64_t hash(const void *key))
{
    if (n_buckets == 0 || !cmp || !hash)
        return NULL;

    hashmap_t *map = calloc(1, sizeof *map);
    if (!map)
        return NULL;

    map->buckets = calloc(n_buckets, sizeof *map->buckets);
    if (!map->buckets) {
        free(map);
        return NULL;
    }
    map->n_buckets = n_buckets;

    /* keep local reference of the two utility functions */
    map->hash = hash;
    map->cmp = cmp;

    /* custom memory management hook */
    map->opaque = NULL;
    map->create_node = create_node_with_malloc;
    map->destroy_node = destroy_node_later;
    return map;
}
/* Look up `key`.
 *
 * @param map map to search; NULL is treated as an empty map
 * @param key key to find, compared with the map's `cmp` function
 * @return the value stored for the first matching key in the bucket, or NULL
 *         when no entry matches
 */
void *hashmap_get(hashmap_t *map, const void *key)
{
    if (!map)
        return NULL;

    /* hash to convert key to a bucket index where value would be stored */
    uint32_t index = map->hash(key) % map->n_buckets;

    /* walk through the linked list nodes to find any matches; links are read
     * atomically because put/del update them with compare-and-swap
     */
    hashmap_kv_t *n = __atomic_load_n(&map->buckets[index], __ATOMIC_SEQ_CST);
    while (n) {
        if (map->cmp(n->key, key) == 0)
            return n->value;
        n = __atomic_load_n(&n->next, __ATOMIC_SEQ_CST);
    }
    return NULL; /* no matches found */
}
/* Insert or replace the entry for `key`.
 *
 * The bucket is scanned for a matching key. If one exists, a new node takes
 * its place in the chain (compare-and-swap on the link that pointed at the
 * old node) and the old node is handed to the map's destroy hook. Otherwise
 * the new node is prepended to the bucket and the map length is incremented.
 * Any failed compare-and-swap restarts the scan from the bucket head.
 *
 * @param map   target map; NULL yields false
 * @param key   key pointer stored in the new node
 * @param value value pointer stored in the new node
 * @return true if an existing matching key was replaced; false if a new entry
 *         was added, `map` is NULL, or the node could not be created
 */
bool hashmap_put(hashmap_t *map, const void *key, void *value)
{
    if (!map)
        return false;

    /* hash to convert key to a bucket index where value would be stored */
    uint32_t bucket_index = map->hash(key) % map->n_buckets;
    hashmap_kv_t **bucket = &map->buckets[bucket_index];

    /* node to install; created once and reused across retries */
    hashmap_kv_t *next = NULL;

    while (true) {
        /* copy the head of the list before checking entries for equality */
        hashmap_kv_t *head = __atomic_load_n(bucket, __ATOMIC_SEQ_CST);

        /* find any existing match to this key; `kv` is recomputed on every
         * pass so that a retry never acts on a node from an earlier scan
         */
        hashmap_kv_t *prev = NULL, *kv;
        for (kv = head; kv; kv = kv->next) {
            if (map->cmp(key, kv->key) == 0)
                break;
            prev = kv;
        }

        if (!next) { /* lazily make the key-value pair to install */
            next = map->create_node(map->opaque, key, value);
            if (!next)
                return false;
        }

        if (kv) { /* the key exists: swap the old node for the new one */
            /* ensure the linked list's existing node chain persists */
            next->next = kv->next;

            /* the link that currently points at the old node: the previous
             * node's `next`, or the bucket head when there is no previous
             */
            hashmap_kv_t **link = prev ? &prev->next : bucket;

            /* private copy, because a failed CAS overwrites `expected` */
            hashmap_kv_t *expected = kv;
            if (__atomic_compare_exchange_n(link, &expected, next, false,
                                            __ATOMIC_SEQ_CST,
                                            __ATOMIC_SEQ_CST)) {
                /* this node, key and value are never again used by this */
                map->destroy_node(map->opaque, kv);
                return true;
            }

            /* the link changed under us: retry the whole match process */
            if (prev)
                hashmap_put_replace_fail += 1;
            else
                hashmap_put_head_fail += 1;
        } else { /* the key does not exist: prepend it to the bucket */
            /* keep the reference to existing nodes (NULL for empty bucket) */
            next->next = head;

            hashmap_kv_t *expected = head;
            if (__atomic_compare_exchange_n(bucket, &expected, next, false,
                                            __ATOMIC_SEQ_CST,
                                            __ATOMIC_SEQ_CST)) {
                __atomic_fetch_add(&map->length, 1, __ATOMIC_SEQ_CST);
                return false;
            }

            /* failure means another thread updated head before this.
             * track the CAS failure for tests -- non-atomic to minimize
             * thread contention
             */
            hashmap_put_retries += 1;
        }
    }
}
/* Remove the entry for `key`.
 *
 * The bucket is scanned for a matching key. The link that points at the match
 * (the previous node's `next`, or the bucket head) is then redirected to the
 * match's successor with a compare-and-swap. On success the map length is
 * decremented and the node is handed to the map's destroy hook; on failure
 * the scan is repeated.
 *
 * @param map target map; NULL yields false
 * @param key key to remove
 * @return true if an entry was unlinked by this call, false if no entry
 *         matched
 */
bool hashmap_del(hashmap_t *map, const void *key)
{
    if (!map)
        return false;

    uint32_t bucket_index = map->hash(key) % map->n_buckets;
    hashmap_kv_t **bucket = &map->buckets[bucket_index];

    /* try to find a match, loop in case a delete attempt fails */
    while (true) {
        hashmap_kv_t *prev = NULL;
        hashmap_kv_t *match = __atomic_load_n(bucket, __ATOMIC_SEQ_CST);
        for (; match; match = match->next) {
            if ((*map->cmp)(key, match->key) == 0)
                break;
            prev = match;
        }

        /* exit if no match was found */
        if (!match)
            return false;

        /* previous means this is not the head but a link in the list */
        hashmap_kv_t **link = prev ? &prev->next : bucket;

        /* private copy, because a failed CAS overwrites `expected` */
        hashmap_kv_t *expected = match;
        hashmap_kv_t *successor = match->next; /* may be NULL */

        if (__atomic_compare_exchange_n(link, &expected, successor, false,
                                        __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)) {
            __atomic_fetch_sub(&map->length, 1, __ATOMIC_SEQ_CST);
            map->destroy_node(map->opaque, match);
            return true;
        }

        /* failure means whole match/del process needs another attempt */
        if (prev)
            hashmap_del_fail += 1;
        else
            hashmap_del_fail_new_head += 1;
    }
}
/* Lock-Free Hashmap
*
* This implementation is thread-safe and lock-free. It will perform well as
* long as the initial bucket size is large enough.
*/
#ifndef _HASHMAP_H_
#define _HASHMAP_H_
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
/* links in the linked lists that each bucket uses */
/* links in the linked lists that each bucket uses */
typedef struct hashmap_keyval {
/* following link in the same bucket, or NULL at the end of the chain */
struct hashmap_keyval *next;
/* key pointer as passed to hashmap_put(); compared via the map's cmp() */
const void *key;
/* value pointer as passed to hashmap_put(); returned by hashmap_get() */
void *value;
} hashmap_kv_t;
/* main hashmap struct with buckets of linked lists */
/* main hashmap struct with buckets of linked lists */
typedef struct {
/* array of n_buckets chain heads; a NULL slot is an empty bucket */
hashmap_kv_t **buckets;
/* number of slots in `buckets`; bucket index is hash(key) % n_buckets */
uint32_t n_buckets;
uint32_t length; /* total count of entries */
/* pointer to the hash and comparison functions */
uint64_t (*hash)(const void *key);
/* cmp() returning 0 is treated as "keys are equal" */
uint8_t (*cmp)(const void *x, const void *y);
/* custom memory management of internal linked lists */
/* passed as first argument to create_node() and destroy_node() */
void *opaque;
hashmap_kv_t *(*create_node)(void *opaque, const void *key, void *data);
void (*destroy_node)(void *opaque, hashmap_kv_t *node);
} hashmap_t;
/* Create and initialize a new hashmap.
 * `hint` is used as the number of buckets (the definition names this
 * parameter `n_buckets`); `cmp` must return 0 for equal keys.
 */
void *hashmap_new(uint32_t hint,
uint8_t cmp(const void *x, const void *y),
uint64_t hash(const void *key));
/* Return the value mapped to key, or NULL if no entry exists for the given
 * key.
 */
void *hashmap_get(hashmap_t *map, const void *key);
/* Put the given key-value pair in the map.
 * @return true if an existing matching key was replaced.
 */
bool hashmap_put(hashmap_t *map, const void *key, void *value);
/* Remove the given key-value pair in the map.
 * @return true if a key was found.
 * This operation is guaranteed to return true just once, if multiple threads
 * are attempting to delete the same key.
 */
bool hashmap_del(hashmap_t *map, const void *key);
#endif
# Build and run the lock-free hashmap test program.
#
#   make          build $(TARGET)
#   make check    build and run the test
#   make clean    remove the binary, objects and dependency files
#
# `check` does not name a file either, so it is listed as phony as well.
.PHONY: all check clean
TARGET = test-hashmap
all: $(TARGET)

# provides PRINTF, Q and VECHO (build verbosity helpers)
include common.mk

CFLAGS = -I.
CFLAGS += -O2 -g
CFLAGS += -std=gnu11 -Wall
LDFLAGS = -lpthread

# standard build rules
# (recipe lines must begin with a tab character)
.SUFFIXES: .o .c
.c.o:
	$(VECHO) " CC\t$@\n"
	$(Q)$(CC) -o $@ $(CFLAGS) -c -MMD -MF $@.d $<

OBJS = \
free_later.o \
hashmap.o \
test-hashmap.o
# one compiler-generated dependency file per object (see -MMD -MF above)
deps += $(OBJS:%.o=%.o.d)

$(TARGET): $(OBJS)
	$(VECHO) " LD\t$@\n"
	$(Q)$(CC) -o $@ $^ $(LDFLAGS)

check: $(TARGET)
	./$^

clean:
	$(VECHO) " Cleaning...\n"
	$(Q)$(RM) $(TARGET) $(OBJS) $(deps)

# pull in header dependencies when they exist; ignore them when they do not
-include $(deps)
#include <pthread.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "free_later.h"
#include "hashmap.h"
/* global hash map */
static hashmap_t *map = NULL;
/* how many threads should run in parallel */
#define N_THREADS 32
/* how many times the work loop should repeat */
#define N_LOOPS 100
/* state for the threads */
static pthread_t threads[N_THREADS];
/* state for the threads that test deletes: the first N_THREADS slots hold the
 * add_val workers, the second N_THREADS slots hold the del_val workers
 */
static pthread_t threads_del[N_THREADS * 2];
/* key used by add_val/del_val; it lies above every key produced by add_vals,
 * whose largest key is N_THREADS * N_LOOPS - 1
 */
static uint32_t MAX_VAL_PLUS_ONE = N_THREADS * N_LOOPS + 1;
/* CAS-failure counters defined in the hashmap implementation */
extern volatile uint32_t hashmap_del_fail, hashmap_del_fail_new_head;
extern volatile uint32_t hashmap_put_retries, hashmap_put_replace_fail;
extern volatile uint32_t hashmap_put_head_fail;
/* Compare two keys that point at uint32_t values.
 *
 * @return 0 when the values are equal, 1 when *x is smaller than *y, and
 *         (uint8_t) -1, i.e. 255, when *x is larger than *y
 */
static uint8_t cmp_uint32(const void *x, const void *y)
{
    const uint32_t lhs = *(const uint32_t *) x;
    const uint32_t rhs = *(const uint32_t *) y;

    if (lhs == rhs)
        return 0;
    return (lhs > rhs) ? (uint8_t) -1 : 1;
}
/* Identity hash: the uint32_t that `key` points at, widened to 64 bits. */
static uint64_t hash_uint32(const void *key)
{
    const uint32_t *k = key;
    return (uint64_t) *k;
}
/* Simulates work that is quick and uses the hashtable once per loop
 *
 * Thread entry point. `args` points at a heap-allocated int holding this
 * worker's offset; the worker takes ownership of it and frees it. N_LOOPS
 * consecutive keys starting at offset * N_LOOPS are then put into the global
 * map, each in its own allocation that serves as both key and value.
 *
 * @return always NULL
 */
static void *add_vals(void *args)
{
    int *offset = args;
    const int base = *offset * N_LOOPS;
    free(offset); /* allocated by the thread's creator, not needed any more */

    for (int j = 0; j < N_LOOPS; j++) {
        int *val = malloc(sizeof *val);
        if (!val) {
            printf("Failed to allocate value %d\n", base + j);
            exit(1);
        }
        *val = base + j;
        hashmap_put(map, val, val);
    }
    return NULL;
}
/* Start N_THREADS add_vals workers and wait for all of them.
 *
 * Every worker receives its index in a separately allocated int that is
 * handed over to the thread. Any allocation, creation or join failure prints
 * a message and terminates the process with status 1.
 *
 * @return always true
 */
bool mt_add_vals(void)
{
    for (int i = 0; i < N_THREADS; i++) {
        int *offset = malloc(sizeof *offset);
        if (!offset) {
            printf("Failed to allocate offset for thread %d\n", i);
            exit(1);
        }
        *offset = i;
        if (pthread_create(&threads[i], NULL, add_vals, offset) != 0) {
            printf("Failed to create thread %d\n", i);
            exit(1);
        }
    }

    /* wait for work to finish */
    for (int i = 0; i < N_THREADS; i++) {
        if (pthread_join(threads[i], NULL) != 0) {
            printf("Failed to join thread %d\n", i);
            exit(1);
        }
    }
    return true;
}
/* add a value over and over to test the del functionality
 *
 * Thread entry point; `args` is unused. The key MAX_VAL_PLUS_ONE is put into
 * the global map N_LOOPS times. Each put gets its own heap copy of the key
 * (also used as the value): the map's default destroy hook registers keys for
 * free(), so the address of a static object must not be stored in the map.
 *
 * @return always NULL
 */
void *add_val(void *args)
{
    (void) args;

    for (int j = 0; j < N_LOOPS; j++) {
        uint32_t *val = malloc(sizeof *val);
        if (!val) {
            printf("Failed to allocate value %u\n", MAX_VAL_PLUS_ONE);
            exit(1);
        }
        *val = MAX_VAL_PLUS_ONE;
        hashmap_put(map, val, val);
    }
    return NULL;
}
/* Thread entry point: try to delete the MAX_VAL_PLUS_ONE key from the global
 * map N_LOOPS times. `args` is not used; the result is always NULL.
 */
static void *del_val(void *args)
{
    int remaining = N_LOOPS;

    while (remaining-- > 0)
        hashmap_del(map, &MAX_VAL_PLUS_ONE);
    return NULL;
}
bool mt_del_vals(void)
{
for (int i = 0; i < N_THREADS; i++) {
if (pthread_create(&threads_del[i], NULL, add_val, NULL) != 0) {
printf("Failed to create thread %d\n", i);
exit(1);
}
if (pthread_create(&threads_del[N_THREADS + i], NULL, del_val, NULL)) {
printf("Failed to create thread %d\n", i);
exit(1);
}
}
// also add normal numbers to ensure they aren't clobbered
mt_add_vals();
// wait for work to finish
for (int i = 0; i < N_THREADS * 2; i++) {
if (pthread_join(threads_del[i], NULL) != 0) {
printf("Failed to join thread %d\n", i);
exit(1);
}
}
return true;
}
/* Multi-threaded insertion test.
 *
 * Creates the global map with 10 buckets and repeats the multi-threaded add
 * pass until hashmap_put has recorded at least one prepend retry. After every
 * pass all N_THREADS * N_LOOPS keys are looked up.
 *
 * @return true when every pass found all keys; false as soon as a pass fails
 *         or a key is missing
 */
bool test_add(void)
{
    map = hashmap_new(10, cmp_uint32, hash_uint32);
    if (!map) {
        printf("Error. Failed to create the hashmap!\n");
        return false;
    }

    int loops = 0;
    while (hashmap_put_retries == 0) {
        loops += 1;
        if (!mt_add_vals()) {
            printf("Error. Failed to add values!\n");
            return false;
        }

        /* check all the list entries */
        const uint32_t TOTAL = N_THREADS * N_LOOPS;
        uint32_t found = 0;
        for (uint32_t i = 0; i < TOTAL; i++) {
            uint32_t *v = hashmap_get(map, &i);
            if (v && *v == i)
                found++;
            else
                printf("Could not find %u in the map\n", i);
        }

        if (found != TOTAL) {
            /* a lost insert is a test failure, not something to loop past */
            printf("Found %u of %u values. Where are the missing?\n", found,
                   TOTAL);
            return false;
        }
        printf(
            "Loop %d. All values found. hashmap_put_retries=%u, "
            "hashmap_put_head_fail=%u, hashmap_put_replace_fail=%u\n",
            loops, hashmap_put_retries, hashmap_put_head_fail,
            hashmap_put_replace_fail);
    }
    printf("Done. Loops=%d\n", loops);
    return true;
}
/* Multi-threaded deletion test.
 *
 * Repeats until both delete-side CAS failure counters are non-zero. Each
 * round creates a fresh 10-bucket map, runs adders and deleters of one extra
 * key alongside a full add pass, and then verifies that all
 * N_THREADS * N_LOOPS regular keys are present.
 * NOTE(review): the map from the previous round is not released; no API for
 * destroying a map is visible here.
 *
 * @return true once both counters were triggered with no key lost; false on
 *         the first failure
 */
bool test_del(void)
{
    /* keep looping until a CAS retry was needed by hashmap_del */
    uint32_t loops = 0;

    /* make sure test counters are zeroed */
    hashmap_del_fail = hashmap_del_fail_new_head = 0;

    while (hashmap_del_fail == 0 || hashmap_del_fail_new_head == 0) {
        map = hashmap_new(10, cmp_uint32, hash_uint32);
        if (!map) {
            printf("test_del() is failing. Can't create the hashmap\n");
            return false;
        }

        /* multi-threaded add values */
        if (!mt_del_vals()) {
            printf("test_del() is failing. Can't complete mt_del_vals()\n");
            return false;
        }
        loops++;

        /* check all the list entries */
        const uint32_t TOTAL = N_THREADS * N_LOOPS;
        uint32_t found = 0;
        for (uint32_t i = 0; i < TOTAL; i++) {
            uint32_t *v = hashmap_get(map, &i);
            if (v && *v == i)
                found++;
            else
                printf("Could not find %u in the hashmap\n", i);
        }
        if (found != TOTAL) {
            printf("test_del() is failing. Not all values found!?\n");
            return false;
        }
    }
    printf("Done. Needed %u loops\n", loops);
    return true;
}
/* Test driver: run the addition test, then the deletion test.
 *
 * @return 0 on success, 1 if the addition test failed, 2 if the deletion test
 *         failed, 3 if the deferred-free buffer could not be set up
 */
int main(void)
{
    if (free_later_init() != 0) {
        printf("Failed to initialise the deferred-free buffer.\n");
        return 3;
    }

    int rc = 0;
    if (!test_add()) {
        printf("Failed to run multi-threaded addition test.\n");
        rc = 1;
    } else if (!test_del()) {
        printf("Failed to run multi-threaded deletion test.\n");
        rc = 2;
    }

    /* tear down on every path, not only after a fully successful run */
    free_later_exit();
    return rc;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment