Skip to content

Instantly share code, notes, and snippets.

@iglosiggio
Last active January 9, 2022 20:15
Show Gist options
  • Save iglosiggio/43e2488bba20793ef4692ca950004fe2 to your computer and use it in GitHub Desktop.
Save iglosiggio/43e2488bba20793ef4692ca950004fe2 to your computer and use it in GitHub Desktop.
A simple trie implementation that maps integer keys to values and supports summing the values over an arbitrary key range.
#include <stdint.h>
#include <limits.h>
#include <stdlib.h>
#include <assert.h>
#include <stdio.h>
// Key and value types stored in the trie.
typedef uint32_t trie_key_t;
typedef float trie_value_t;

// Every trie level consumes BRANCHING_FACTOR_LOG2 bits of the key, so a
// regular node has BRANCHING_FACTOR child slots.
#define BRANCHING_FACTOR_LOG2 (CHAR_BIT)
#define BRANCHING_FACTOR (1ll << BRANCHING_FACTOR_LOG2)
// Number of levels needed to consume a whole key (rounded up).
#define KEY_SIZE ((sizeof(trie_key_t) * CHAR_BIT + BRANCHING_FACTOR_LOG2 - 1) / BRANCHING_FACTOR_LOG2)
// Width of a key in bits.
#define KEY_BITS (sizeof(trie_key_t) * CHAR_BIT)
// Largest representable key. The whole expansion is parenthesized so the macro
// behaves as a single value inside larger expressions (e.g. KEY_MAX * 2).
#define KEY_MAX ((1ll << KEY_BITS) - 1)
// Right shift that moves the most significant key digit into the low bits.
#define KEY_BITSHIFT (KEY_BITS - BRANCHING_FACTOR_LOG2)
// One trie node. Regular (inner) nodes are allocated with BRANCHING_FACTOR
// child slots; leaf nodes are allocated without any child storage.
struct node {
// Leaf: the value stored for the key. Regular node: running total that
// trie_put maintains by adding each change to every node on the path.
trie_value_t value;
// Discriminates the two node shapes; children[] is only backed by memory
// when this is REGULAR_NODE.
enum { LEAF_NODE, REGULAR_NODE } type;
// Flexible array member: child pointers indexed by one key digit, NULL when absent.
struct node* children[];
};
// Allocate an empty regular (inner) node: value 0 and BRANCHING_FACTOR child
// slots, all NULL. This is also how the root of a new trie is created.
//
// Returns the new node, or NULL if the allocation fails. The caller owns the
// node and releases it (together with its descendants) with trie_free().
struct node* trie_create(void) {
    struct node* result = malloc(sizeof *result + BRANCHING_FACTOR * sizeof result->children[0]);
    if (result == NULL) {
        return NULL;
    }
    result->value = 0;
    result->type = REGULAR_NODE;
    for (size_t i = 0; i < BRANCHING_FACTOR; i++) {
        result->children[i] = NULL;
    }
    // Pointers must be printed with %p (and passed as void*); %x expects an
    // unsigned int and is undefined behaviour for a pointer argument.
    fprintf(stderr, "trie_create: new regular node, trie(%p)\n", (void*)result);
    return result;
}
// Allocate a leaf node with value 0. Only the fixed part of struct node is
// allocated, so the children[] array of a leaf must never be accessed.
//
// Returns the new node, or NULL if the allocation fails. The caller owns the
// node and releases it with trie_free().
struct node* trie_create_leaf(void) {
    struct node* result = malloc(sizeof *result);
    if (result == NULL) {
        return NULL;
    }
    result->value = 0;
    result->type = LEAF_NODE;
    // Pointers must be printed with %p (and passed as void*); %x expects an
    // unsigned int and is undefined behaviour for a pointer argument.
    fprintf(stderr, "trie_create: new leaf node, trie(%p)\n", (void*)result);
    return result;
}
// Release `node` and, for regular nodes, every node reachable through its
// child slots. Passing NULL is allowed and does nothing.
void trie_free(struct node* node) {
    if (node != NULL) {
        if (node->type == REGULAR_NODE) {
            // Leaves have no child storage, so only regular nodes are walked.
            size_t slot = 0;
            while (slot < BRANCHING_FACTOR) {
                trie_free(node->children[slot]);
                slot++;
            }
        }
        free(node);
    }
}
// Store `value` under key `idx`, creating any missing nodes on the way down,
// and keep the memoized totals of all ancestors in sync.
//
// node:  root of the trie (a regular node).
// idx:   key to write; consumed one BRANCHING_FACTOR_LOG2-bit digit per level,
//        most significant digit first.
// value: new value for the key (replaces any previous value).
//
// Returns a pointer to the leaf's stored value, or NULL if `node` is NULL or a
// node could not be allocated. On allocation failure the totals are left
// untouched; inner nodes already created stay in the trie with a total of 0.
trie_value_t* trie_put(struct node* node, trie_key_t idx, trie_value_t value) {
    if (node == NULL) {
        return NULL;
    }
    fprintf(stderr, "trie(%p): Inserting %f at index %u\n", (void*)node, value, (unsigned)idx);

    // path[0] is the root, path[KEY_SIZE] the leaf.
    struct node* path[KEY_SIZE + 1] = { node };
    for (size_t i = 0; i < KEY_SIZE; i++) {
        assert("Traversal only goes through regular nodes" && path[i]->type == REGULAR_NODE);
        // The current digit sits in the top bits of idx; mask it to the
        // number of child slots instead of a hard-coded 0xFF.
        size_t children_idx = (idx >> KEY_BITSHIFT) & (BRANCHING_FACTOR - 1);
        idx <<= BRANCHING_FACTOR_LOG2;
        struct node** node_ptr = &path[i]->children[children_idx];
        if (*node_ptr == NULL) {
            // The last level holds leaves, every other level regular nodes.
            *node_ptr = i < (KEY_SIZE - 1) ? trie_create() : trie_create_leaf();
            if (*node_ptr == NULL) {
                return NULL;
            }
        }
        path[i + 1] = *node_ptr;
    }
    struct node* leaf = path[KEY_SIZE];
    assert("Last node is always a leaf" && leaf->type == LEAF_NODE);

    // Update the memoization: every ancestor's total moves by the difference
    // between the new and the old leaf value.
    trie_value_t change = value - leaf->value;
    for (size_t i = 0; i < KEY_SIZE; i++) {
        path[i]->value += change;
    }
    // Assign the leaf directly: old + (value - old) can round to something
    // other than `value` in floating point.
    leaf->value = value;
    return &leaf->value;
}
// Look up key `idx` without modifying the trie.
//
// node: root of the trie (a regular node).
// idx:  key to read; consumed one BRANCHING_FACTOR_LOG2-bit digit per level,
//       most significant digit first.
//
// Returns a pointer to the leaf's stored value, or NULL if `node` is NULL or
// the key is not present. Writing through the returned pointer does not update
// the totals kept in the ancestors; use trie_put() to change a value.
trie_value_t* trie_get(struct node* node, trie_key_t idx) {
    if (node == NULL) {
        return NULL;
    }
    for (size_t i = 0; i < KEY_SIZE; i++) {
        assert("Traversal only goes through regular nodes" && node->type == REGULAR_NODE);
        // Same digit extraction as trie_put: mask to the number of child slots
        // (the previous all-ones mask did not restrict the index at all).
        size_t children_idx = (idx >> KEY_BITSHIFT) & (BRANCHING_FACTOR - 1);
        idx <<= BRANCHING_FACTOR_LOG2;
        node = node->children[children_idx];
        if (node == NULL) {
            return NULL;
        }
    }
    assert("Last node is always a leaf" && node->type == LEAF_NODE);
    return &node->value;
}
// Sum the values of all keys in [idx_from, idx_to] stored below `node`.
//
// node:               subtree to inspect (may be NULL for an empty slot).
// node_size:          number of keys this subtree covers.
// node_from, node_to: inclusive key range this subtree covers.
// idx_from, idx_to:   inclusive key range being summed.
//
// Subtrees entirely inside the query contribute their memoized total, and
// subtrees entirely outside are skipped, so only nodes that straddle a query
// boundary are descended into.
static
trie_value_t trie_add_helper(
    struct node* node,
    size_t node_size, trie_key_t node_from, trie_key_t node_to,
    trie_key_t idx_from, trie_key_t idx_to
) {
    // If we don't have a node just return zero
    if (node == NULL) {
        return 0;
    }
    // If this node is completely outside the query return zero. The two ranges
    // are disjoint when the node ends before the query starts or starts after
    // the query ends.
    if (node_to < idx_from || node_from > idx_to) {
        return 0;
    }
    // If this node is completely inside the query use its memoized total
    if (node_from >= idx_from && node_to <= idx_to) {
        return node->value;
    }
    // A single-key node is always either inside or outside, so this is only a
    // safeguard against touching the children of a leaf.
    if (node_size == 1) {
        return 0;
    }
    // Else go through each child
    fprintf(stderr, "traversing trie(%p) with size %zu from %u to %u\n",
            (void*)node, node_size, (unsigned)node_from, (unsigned)node_to);
    trie_value_t result = 0;
    size_t children_node_size = node_size / BRANCHING_FACTOR;
    for (size_t i = 0; i < BRANCHING_FACTOR; i++) {
        result += trie_add_helper(
            node->children[i],
            children_node_size,
            (trie_key_t)(node_from + i * children_node_size),
            (trie_key_t)(node_from + (i + 1) * children_node_size - 1),
            idx_from, idx_to
        );
    }
    return result;
}
// Sum the values of all keys in the inclusive range [idx_from, idx_to].
//
// node: root of the trie; a NULL root yields 0.
//
// Returns the sum; an empty range (idx_from > idx_to) yields 0.
trie_value_t trie_add(struct node* node, trie_key_t idx_from, trie_key_t idx_to) {
    // %p with a void* argument is the only valid way to print a pointer.
    fprintf(stderr, "trie(%p) adding from %u to %u\n", (void*)node, (unsigned)idx_from, (unsigned)idx_to);
    // The root covers every key: KEY_MAX + 1 keys spanning [0, KEY_MAX].
    // KEY_MAX is wrapped in parentheses so the arithmetic does not depend on
    // how the macro itself is written.
    return trie_add_helper(node, (KEY_MAX) + 1, 0, (KEY_MAX), idx_from, idx_to);
}
// Write the Graphviz statements for `node` and its whole subtree to stderr.
//
// node:   subtree to dump; NULL is ignored.
// prefix: key digits consumed on the way from the root to `node`.
//
// Each node is identified by its address; regular nodes are drawn as
// rectangles and leaves as ellipses, labelled with prefix and value.
static
void trie_print_graphviz_helper(struct node* node, trie_key_t prefix) {
    if (node == NULL) {
        return;
    }
    // Format arguments must match their conversions exactly: pointers go
    // through %p as void*, and the key is widened to match %llx.
    fprintf(
        stderr, " \"%p\" [shape=%s, label=\"prefix=%llx\\nvalue=%f\"];\n",
        (void*)node, node->type == REGULAR_NODE ? "rectangle" : "ellipse",
        (unsigned long long)prefix, node->value
    );
    // Leaves have no child storage to walk.
    if (node->type != REGULAR_NODE) {
        return;
    }
    for (size_t i = 0; i < BRANCHING_FACTOR; i++) {
        struct node* child = node->children[i];
        if (child == NULL) {
            continue;
        }
        fprintf(stderr, " \"%p\" -> \"%p\"\n", (void*)node, (void*)child);
        // The child's prefix is this node's prefix with digit i appended.
        trie_print_graphviz_helper(child, (trie_key_t)((prefix << BRANCHING_FACTOR_LOG2) + i));
    }
}
// Dump the trie rooted at `node` to stderr as a Graphviz `digraph`: an
// opening line, the statements for every node and edge, and a closing brace.
void trie_print_graphviz(struct node* node) {
    static const char header[] = "digraph {\n";
    static const char footer[] = "}\n";

    fputs(header, stderr);
    trie_print_graphviz_helper(node, 0);
    fputs(footer, stderr);
}
int main() {
struct node* trie = trie_create();
float* zero_point_five = trie_put(trie, 0xFFFF00, 0.5);
float* one_point_five = trie_put(trie, 0xFFFF01, 1.5);
float* one_two_three_four_point_five = trie_put(trie, 99, 1234.5);
assert(zero_point_five != NULL && *zero_point_five == 0.5);
assert(one_point_five != NULL && *one_point_five == 1.5);
assert(one_two_three_four_point_five != NULL && *one_two_three_four_point_five == 1234.5);
assert(trie_get(trie, 0xFFFF00) != NULL);
assert(trie_get(trie, 0xFFFF01) != NULL);
assert(trie_get(trie, 99) != NULL);
assert(*trie_get(trie, 0xFFFF00) == 0.5);
assert(*trie_get(trie, 0xFFFF01) == 1.5);
assert(*trie_get(trie, 99) == 1234.5);
assert(trie_get(trie, 100) == NULL);
float* seven = trie_put(trie, 99, 7);
assert(seven != NULL && *seven == 7);
assert(trie_get(trie, 99) != NULL);
assert(*trie_get(trie, 99) == 7);
fprintf(stderr, "%f\n", trie_add(trie, 0, 0xFFFF01));
assert(trie_add(trie, 98, 0xFFFF01) == 0.5 + 1.5 + 7);
trie_print_graphviz(trie);
trie_free(trie);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment