Last active
January 9, 2022 20:15
-
-
Save iglosiggio/43e2488bba20793ef4692ca950004fe2 to your computer and use it in GitHub Desktop.
A dumb trie implementation that supports summing the stored values over an arbitrary range of keys.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdint.h> | |
#include <limits.h> | |
#include <stdlib.h> | |
#include <assert.h> | |
#include <stdio.h> | |
// Key and value types stored in the trie.
typedef uint32_t trie_key_t;
typedef float trie_value_t;

// Each trie level consumes BRANCHING_FACTOR_LOG2 bits of the key, so every
// regular node has BRANCHING_FACTOR child slots.
#define BRANCHING_FACTOR_LOG2 (CHAR_BIT)
#define BRANCHING_FACTOR (1ll << BRANCHING_FACTOR_LOG2)

// Number of BRANCHING_FACTOR_LOG2-bit digits in a key (rounded up), i.e. the
// number of edges between the root and a leaf.
#define KEY_SIZE ((sizeof(trie_key_t) * CHAR_BIT + BRANCHING_FACTOR_LOG2 - 1) / BRANCHING_FACTOR_LOG2)

// Width of a key in bits.
#define KEY_BITS (sizeof(trie_key_t) * CHAR_BIT)

// Largest representable key. The whole expansion is parenthesized so that
// the macro behaves as a single operand (e.g. `2 * KEY_MAX`, `(T)KEY_MAX`).
#define KEY_MAX ((1ll << KEY_BITS) - 1)

// Right shift that moves the most significant digit of a key into the low bits.
#define KEY_BITSHIFT (KEY_BITS - BRANCHING_FACTOR_LOG2)
struct node { | |
trie_value_t value; | |
enum { LEAF_NODE, REGULAR_NODE } type; | |
struct node* children[]; | |
}; | |
struct node* trie_create() { | |
struct node* result = malloc(sizeof *result + sizeof(struct node*[BRANCHING_FACTOR])); | |
result->value = 0; | |
result->type = REGULAR_NODE; | |
for (size_t i = 0; i < BRANCHING_FACTOR; i++) | |
result->children[i] = NULL; | |
fprintf(stderr, "trie_create: new regular node, trie(%x)\n", result); | |
return result; | |
} | |
struct node* trie_create_leaf() { | |
struct node* result = malloc(sizeof *result); | |
result->value = 0; | |
result->type = LEAF_NODE; | |
fprintf(stderr, "trie_create: new leaf node, trie(%x)\n", result); | |
return result; | |
} | |
void trie_free(struct node* node) { | |
if (node == NULL) return; | |
if (node->type == REGULAR_NODE) | |
for (size_t i = 0; i < BRANCHING_FACTOR; i++) | |
trie_free(node->children[i]); | |
free(node); | |
} | |
// Stores `value` under key `idx`, creating any missing nodes on the way, and
// keeps the running totals of every node on the path consistent.
//
// node:  root of the trie (a regular node).
// idx:   key; consumed most significant digit first.
// value: value to store, replacing any previous value for `idx`.
//
// Returns a pointer to the value held by the leaf, or NULL if `node` is NULL
// or a node could not be allocated. On allocation failure the nodes created
// so far stay in the trie with a total of zero, so sums are unaffected.
trie_value_t* trie_put(struct node* node, trie_key_t idx, trie_value_t value) {
    fprintf(stderr, "trie(%p): Inserting %f at index %llu\n",
            (void*)node, (double)value, (unsigned long long)idx);
    if (node == NULL)
        return NULL;

    // path[0] is the root, path[KEY_SIZE] the leaf.
    struct node* path[KEY_SIZE + 1] = { node };
    for (size_t i = 0; i < KEY_SIZE; i++) {
        assert("Traversal only goes through regular nodes" && path[i]->type == REGULAR_NODE);
        // Take the current most significant digit, then shift the next one
        // into its place.
        size_t children_idx = (size_t)(idx >> KEY_BITSHIFT) & (size_t)(BRANCHING_FACTOR - 1);
        idx <<= BRANCHING_FACTOR_LOG2;
        struct node** node_ptr = &path[i]->children[children_idx];
        if (*node_ptr == NULL) {
            // The last step creates the leaf, every earlier one a regular node.
            *node_ptr = i < (KEY_SIZE - 1) ? trie_create() : trie_create_leaf();
            if (*node_ptr == NULL)
                return NULL;
        }
        path[i+1] = *node_ptr;
    }
    assert("Last node is always a leaf" && path[KEY_SIZE]->type == LEAF_NODE);

    struct node* leaf = path[KEY_SIZE];
    trie_value_t change = value - leaf->value;
    // Update the memoized totals of the ancestors by the difference...
    for (size_t i = 0; i < KEY_SIZE; i++)
        path[i]->value += change;
    // ...but assign the leaf directly: old + (value - old) is not guaranteed
    // to equal `value` in floating point arithmetic.
    leaf->value = value;
    return &leaf->value;
}
// Looks up key `idx` in the trie rooted at `node`.
//
// Returns a pointer to the value stored in the matching leaf, or NULL if
// `node` is NULL or no node exists for some digit of the key. The trie is
// not modified.
trie_value_t* trie_get(struct node* node, trie_key_t idx) {
    if (node == NULL)
        return NULL;
    for (size_t i = 0; i < KEY_SIZE; i++) {
        assert("Traversal only goes through regular nodes" && node->type == REGULAR_NODE);
        // Same digit extraction as trie_put(): mask to one digit so the index
        // can never leave the child array.
        size_t children_idx = (size_t)(idx >> KEY_BITSHIFT) & (size_t)(BRANCHING_FACTOR - 1);
        idx <<= BRANCHING_FACTOR_LOG2;
        node = node->children[children_idx];
        if (node == NULL)
            return NULL;
    }
    assert("Last node is always a leaf" && node->type == LEAF_NODE);
    return &node->value;
}
static | |
trie_value_t trie_add_helper( | |
struct node* node, | |
size_t node_size, trie_key_t node_from, trie_key_t node_to, | |
trie_key_t idx_from, trie_key_t idx_to | |
) { | |
// If we don't have a node just return zero | |
if (node == NULL) return 0; | |
// If this node is inside we will just use that | |
if (node_from >= idx_from && node_to <= idx_to) return node->value; | |
// If this node is a leaf then just return zero | |
if (node_size == 1) return 0; | |
// If this node is completely outside return zero | |
if (node_to < idx_from && node_from > idx_from) return 0; | |
// Else go through each child | |
fprintf(stderr, "traversing trie(%x) with size %llu from %u to %u\n", node, node_size, node_from, node_to); | |
trie_value_t result = 0; | |
size_t children_node_size = node_size / BRANCHING_FACTOR; | |
for (size_t i = 0; i < BRANCHING_FACTOR; i++) | |
result += trie_add_helper( | |
node->children[i], | |
children_node_size, node_from + i * children_node_size, node_from + (i + 1) * children_node_size - 1, | |
idx_from, idx_to | |
); | |
return result; | |
} | |
// Returns the sum of the values stored under every key in the inclusive
// range [idx_from, idx_to] of the trie rooted at `node`. An empty trie (NULL)
// or an empty range (idx_from > idx_to) yields zero.
trie_value_t trie_add(struct node* node, trie_key_t idx_from, trie_key_t idx_to) {
    fprintf(stderr, "trie(%p) adding from %llu to %llu\n",
            (void*)node, (unsigned long long)idx_from, (unsigned long long)idx_to);
    if (idx_from > idx_to)
        return 0;
    // The root covers the whole key space: KEY_MAX + 1 keys, 0..KEY_MAX.
    // NOTE(review): this count only fits in size_t when size_t is wider than
    // trie_key_t.
    return trie_add_helper(node, (size_t)(KEY_MAX + 1), 0, (trie_key_t)(KEY_MAX), idx_from, idx_to);
}
static | |
void trie_print_graphviz_helper(struct node* node, trie_key_t prefix) { | |
fprintf( | |
stderr, " \"%x\" [shape=%s, label=\"prefix=%llx\\nvalue=%f\"];\n", | |
node, node->type == REGULAR_NODE ? "rectangle" : "ellipse", prefix, node->value, node | |
); | |
if (node->type == REGULAR_NODE) | |
for (size_t i = 0; i < BRANCHING_FACTOR; i++) { | |
struct node* child = node->children[i]; | |
if (child == NULL) continue; | |
fprintf(stderr, " \"%x\" -> \"%x\"\n", node, child); | |
trie_print_graphviz_helper(child, (prefix << BRANCHING_FACTOR_LOG2) + i); | |
} | |
} | |
// Dumps the trie rooted at `node` to stderr as a Graphviz `digraph`.
void trie_print_graphviz(struct node* node) {
    fputs("digraph {\n", stderr);
    trie_print_graphviz_helper(node, 0);
    fputs("}\n", stderr);
}
int main() { | |
struct node* trie = trie_create(); | |
float* zero_point_five = trie_put(trie, 0xFFFF00, 0.5); | |
float* one_point_five = trie_put(trie, 0xFFFF01, 1.5); | |
float* one_two_three_four_point_five = trie_put(trie, 99, 1234.5); | |
assert(zero_point_five != NULL && *zero_point_five == 0.5); | |
assert(one_point_five != NULL && *one_point_five == 1.5); | |
assert(one_two_three_four_point_five != NULL && *one_two_three_four_point_five == 1234.5); | |
assert(trie_get(trie, 0xFFFF00) != NULL); | |
assert(trie_get(trie, 0xFFFF01) != NULL); | |
assert(trie_get(trie, 99) != NULL); | |
assert(*trie_get(trie, 0xFFFF00) == 0.5); | |
assert(*trie_get(trie, 0xFFFF01) == 1.5); | |
assert(*trie_get(trie, 99) == 1234.5); | |
assert(trie_get(trie, 100) == NULL); | |
float* seven = trie_put(trie, 99, 7); | |
assert(seven != NULL && *seven == 7); | |
assert(trie_get(trie, 99) != NULL); | |
assert(*trie_get(trie, 99) == 7); | |
fprintf(stderr, "%f\n", trie_add(trie, 0, 0xFFFF01)); | |
assert(trie_add(trie, 98, 0xFFFF01) == 0.5 + 1.5 + 7); | |
trie_print_graphviz(trie); | |
trie_free(trie); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment