Skip to content

Instantly share code, notes, and snippets.

@martinkunev
Last active November 28, 2017 15:32
Show Gist options
  • Save martinkunev/720386 to your computer and use it in GitHub Desktop.
Save martinkunev/720386 to your computer and use it in GitHub Desktop.
AVL tree (C implementation)
/*
* Conquest of Levidon
* Copyright (C) 2017 Martin Kunev <martinkunev@gmail.com>
*
* This file is part of Conquest of Levidon.
*
* Conquest of Levidon is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation version 3 of the License.
*
* Conquest of Levidon is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Conquest of Levidon. If not, see <http://www.gnu.org/licenses/>.
*/
// Type of the values stored in the tree. Define avl_type before this point to override the default (int).
#if !defined(avl_type)
# define avl_type int
#endif
// Compares the key (a, a_size) with the key stored in node b.
// Returns positive number if a is before b, 0 if a == b, negative number if a is after b.
// Keys are ordered by length first and bytewise (memcmp) among keys of equal length.
// Checking the length first keeps memcmp() from reading past the end of a shorter stored key
// and keeps a key that is a prefix of another key distinct from it.
// For keys of equal length the result is the same as a plain memcmp() of the two keys.
// WARNING: a_size and b are evaluated more than once; pass expressions without side effects.
//#define avl_compare(a, a_size, b) (*(int *)(b)->key_data - *(int *)(a)) /* TODO this could overflow/underflow */
#define avl_compare(a, a_size, b) \
	(((b)->key_size != (a_size)) \
		? (((b)->key_size > (a_size)) - ((b)->key_size < (a_size))) \
		: memcmp((b)->key_data, (a), (a_size)))
// Callback for updating aggregated values for a node. When called, both children of the passed node are in a consistent state.
#define avl_update(node) (void)0
// TODO make the factor calculation code readable
// TODO implement union, intersection, difference
// TODO test avl_update() with range searches
// AVL tree that maps byte-string keys to values of type avl_type.
struct avl
{
size_t count; // number of entries; updated by avl_insert() and avl_remove()
struct avl_node
{
struct avl_node *next[2]; // children: next[0] holds the keys ordered before this node's key, next[1] the keys after it
const size_t key_size; // length of key_data in bytes
avl_type value; // value stored for the key
signed char factor; // balance factor: height of the next[0] subtree minus height of the next[1] subtree
const unsigned char key_data[]; // key bytes, stored inline after the node (flexible array member)
} *root; // root node; null pointer when the tree is empty
};
// A newly created struct avl must be initialized with zeroes: struct avl avl = {0};
// Returns a pointer to the value stored for the key, or a null pointer if the key is not in the tree.
avl_type *avl_get(struct avl *avl, const unsigned char *restrict key_data, size_t key_size);
// Adds the key with the given value. If the key is already present, nothing is stored and a pointer to the existing value is returned.
// Returns a pointer to the value kept in the tree for the key, or a null pointer if the node could not be allocated.
avl_type *avl_insert(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type value);
// Removes the key if it is present. When the key is found and value_old is not null, the removed value is stored in *value_old.
void avl_remove(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type *value_old);
// Calls callback(node, argument) for the nodes in key order, starting with the first key that is not ordered before the given key.
// A null key_data visits every node. Iteration stops as soon as the callback returns non-zero.
void avl_iterate(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, int (*callback)(struct avl_node *, void *), void *argument);
// Frees all nodes of the tree.
void avl_term(struct avl *avl);
/* avl.c */
#include <assert.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
// Declares the same members, in the same order, as struct avl_node but without the const qualifiers on key_size and key_data.
// The implementation uses this type when it has to write those fields (e.g. while initializing a freshly allocated node).
struct avl_node_mutable
{
struct avl_node *next[2];
size_t key_size;
avl_type value;
signed char factor;
unsigned char key_data[];
};
// Descends from t looking for the key. Returns a pointer to the stored value, or a null pointer if the key is absent.
static avl_type *avl_get_internal(struct avl_node *t, const unsigned char *restrict key_data, size_t key_size)
{
	struct avl_node *current;

	for (current = t; current; )
	{
		int order = avl_compare(key_data, key_size, current);

		if (order == 0)
			return &current->value;

		// A negative result means the key is ordered after the current node.
		current = (order < 0) ? current->next[1] : current->next[0];
	}

	return 0;
}
// Looks the key up in the tree. Returns a pointer to its value, or a null pointer if the key is not present.
avl_type *avl_get(struct avl *avl, const unsigned char *restrict key_data, size_t key_size)
{
	struct avl_node *root = avl->root;

	return avl_get_internal(root, key_data, key_size);
}
// Rebalances the subtree rooted at *branch. The root must have a non-zero balance factor; the callers in this file
// invoke it once the factor has left the range [-1, 1].
// Rotates the taller child above the root. If that child leans the opposite way, its own child (the grandchild)
// is first rotated above it, which makes the operation a double rotation.
// On return *branch points to the new root of the subtree.
// Returns the resulting change of the subtree height: -1 if the subtree became shorter, 0 if its height is unchanged.
static signed char avl_balance(struct avl_node *restrict *restrict branch)
{
struct avl_node *node = *branch;
// Index of the taller child: 0 when the factor is positive, 1 when it is negative.
unsigned char index = (node->factor < 0);
struct avl_node *child = node->next[index];
// +1 when the next[0] side is the taller one, -1 when the next[1] side is.
signed char factor_sign = (int [2]){1, -1}[index];
assert(node->factor);
// Rotate the longer subtree if necessary.
if (factor_sign * child->factor < 0)
{
// The child leans away from the heavy side: lift the grandchild above the child first.
struct avl_node *grandchild = child->next[index ^ 1];
node->next[index] = grandchild;
child->next[index ^ 1] = grandchild->next[index];
grandchild->next[index] = child;
// Set the balance factor
// The child's new factor is derived from the grandchild's old factor, so it must be computed before the next line overwrites it.
child->factor = factor_sign * (factor_sign * grandchild->factor < 0); // TODO fix this code
// Intermediate value (its magnitude may be 2); it is consumed and replaced by the rotation below.
grandchild->factor = factor_sign * (factor_sign * grandchild->factor > 0) + factor_sign; // TODO fix this code
avl_update(child);
child = grandchild;
}
// Lift child above node.
*branch = child;
node->next[index] = child->next[index ^ 1];
child->next[index ^ 1] = node;
// Set the balance factor
// node's new factor depends on child's factor from before the rotation, so it is assigned first.
node->factor = factor_sign - child->factor; // TODO fix this code
// The new root becomes balanced (0) unless its factor was 0 before, in which case it now leans to the other side.
child->factor = -!child->factor * factor_sign; // TODO fix this code
avl_update(node);
avl_update(child);
// Return height change.
// -1 when the new root ended up balanced, 0 otherwise.
return -!child->factor; // TODO fix this code
}
// Recursively inserts the key into the subtree rooted at *branch.
// *height_change is set to 1 when the subtree grew taller and to 0 once the growth has been absorbed;
// *new is set to 1 when a node was created. Both must be 0 on the initial call.
// Returns a pointer to the value stored for the key (the existing one if the key was already present),
// or a null pointer if allocating the node failed.
static avl_type *avl_insert_key(struct avl_node *restrict *restrict branch, const unsigned char *restrict key_data, size_t key_size, avl_type value, int *restrict height_change, int *new)
{
	struct avl_node *t = *branch;
	unsigned char side;
	avl_type *stored;
	int order;

	if (!t) // the node is vacant
	{
		struct avl_node_mutable *fresh = malloc(offsetof(struct avl_node, key_data) + key_size);
		if (!fresh)
			return 0;

		// WARNING: Tree of complex data type should do something more here.
		fresh->next[0] = 0;
		fresh->next[1] = 0;
		fresh->key_size = key_size;
		fresh->value = value;
		fresh->factor = 0;
		memcpy(fresh->key_data, key_data, key_size);

		*branch = (struct avl_node *)fresh;
		*height_change = 1;
		*new = 1;
		return &fresh->value;
	}

	// The node is occupied.
	order = avl_compare(key_data, key_size, t);
	if (order == 0) // the key is already in the tree
	{
		// WARNING: Tree of complex data type should do something more here.
		return &t->value;
	}

	// Insert into the matching subtree.
	side = (order < 0);
	stored = avl_insert_key(t->next + side, key_data, key_size, value, height_change, new);

	if (!*height_change)
	{
		// The height below is unchanged; only aggregated data may need a refresh.
		if (*new)
			avl_update(t);
		return stored;
	}

	// The subtree grew: adjust the balance factor and rebalance if it left the range [-1, 1].
	t->factor += side ? -1 : 1;
	if ((t->factor >= -1) && (t->factor <= 1))
	{
		*height_change = (t->factor != 0);
		avl_update(t);
	}
	else
		*height_change = avl_balance(branch) + 1;

	return stored;
}
// Inserts the key with the given value and counts the entry if a node was created.
// Returns a pointer to the value kept for the key, or a null pointer if the allocation failed.
avl_type *avl_insert(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type value)
{
	int grew = 0;
	int created = 0;
	avl_type *slot;

	slot = avl_insert_key(&avl->root, key_data, key_size, value, &grew, &created);
	if (created)
		++avl->count;

	return slot;
}
// Calculate the new balance factor of *branch after the subtree at next[index] became shorter.
// Rebalance the tree if necessary. Returns the height change of the subtree at *branch (-1 or 0).
static signed char avl_remove_factor(struct avl_node **branch, size_t index)
{
	struct avl_node *node = *branch;

	// Losing height on the next[0] side lowers the factor, on the next[1] side raises it.
	if (index)
		node->factor += 1;
	else
		node->factor -= 1;

	if ((node->factor >= -1) && (node->factor <= 1))
	{
		avl_update(node);
		// A node that became balanced lost one level of height.
		return node->factor ? 0 : -1;
	}

	return avl_balance(branch);
}
// Unlinks from the subtree rooted at *branch the node whose key is closest to the given key and stores it in *closest.
// Every key in the subtree lies on the same side of the given key, so the search always descends in one direction.
// The unlinked node is NOT freed; the caller takes it over and must overwrite its next[] and factor.
// Returns the height change of the subtree at *branch (-1 or 0).
static int avl_move_closest(struct avl_node **branch, const unsigned char *restrict key_data, size_t key_size, struct avl_node **closest)
{
	struct avl_node *t = *branch;
	unsigned char index = (avl_compare(key_data, key_size, t) < 0);
	struct avl_node **child = t->next + index;

	if (*child) // the subtree is not empty
	{
		// Unlink the node from the subtree. Exit if the height hasn't changed.
		if (!avl_move_closest(child, key_data, key_size, closest))
		{
			avl_update(t);
			return 0;
		}
		return avl_remove_factor(branch, index);
	}

	// t has no child towards the key, so it is the closest node. Its only possible child takes its place.
	*closest = t;
	*branch = t->next[index ^ 1];
	return -1;
}

// Removes the key from the subtree rooted at *branch, which must not be empty.
// Sets *changed to 1 if the key was found; in that case the removed value is stored in *value_old when value_old is not null.
// A removed node that has two children is replaced by relinking the closest node of its taller subtree in its place.
// Relinking, rather than copying key and value into the removed node's memory, never writes a key into an
// allocation sized for a different key and keeps pointers to the values of all remaining entries valid.
// Returns the height change of the subtree at *branch (-1 or 0).
static signed char avl_remove_key(struct avl_node **branch, const unsigned char *restrict key_data, size_t key_size, avl_type *value_old, int *restrict changed)
{
	struct avl_node *t = *branch;
	int diff = avl_compare(key_data, key_size, t);

	// If this is the key to be removed
	if (!diff)
	{
		*changed = 1;
		if (value_old)
			*value_old = t->value;

		if (t->next[0] && t->next[1])
		{
			unsigned char index = (t->factor < 0);
			struct avl_node *closest = 0;

			// Take the closest by key node out of the taller subtree.
			// WARNING: Tree of complex data type may need to do something more here.
			int height_change = avl_move_closest(t->next + index, key_data, key_size, &closest);

			// Put that node where t was. t->next[index] already reflects the removal done above.
			closest->next[0] = t->next[0];
			closest->next[1] = t->next[1];
			closest->factor = t->factor;
			*branch = closest;

			// WARNING: Tree of complex data type should do something more here
			free(t);

			// Exit if the height hasn't changed.
			if (!height_change)
			{
				avl_update(closest);
				return 0;
			}
			return avl_remove_factor(branch, index);
		}

		// At most one child: it (or nothing) replaces the removed node.
		*branch = t->next[0] ? t->next[0] : t->next[1];

		// WARNING: Tree of complex data type should do something more here
		free(t);
		return -1;
	}

	// Find the node in the subtree where it should be
	{
		unsigned char index = (diff < 0);
		struct avl_node **child = t->next + index;

		// If such node doesn't exist, there is nothing to remove or balance.
		if (!*child) // the subtree is empty
			return 0;

		// Remove the node from the subtree. Exit if the height hasn't changed
		if (!avl_remove_key(child, key_data, key_size, value_old, changed))
		{
			if (*changed)
				avl_update(t);
			return 0;
		}
		return avl_remove_factor(branch, index);
	}
}
// Removes the key from the tree if it is present and keeps the entry count in sync.
// When the key is found and value_old is not null, the removed value is stored in *value_old.
void avl_remove(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, avl_type *value_old)
{
	int removed = 0;

	// Nothing to do for an empty tree.
	if (!avl->root)
		return;

	avl_remove_key(&avl->root, key_data, key_size, value_old, &removed);
	if (removed)
		--avl->count;
}
// In-order traversal of the subtree at node that invokes callback(node, argument) for each visited node.
// While key_data is not null the traversal is still looking for its starting point and skips nodes ordered before the key.
// Returns whether the iteration has been stopped (the callback returned non-zero).
static int iterate(struct avl_node *node, const unsigned char *restrict key_data, size_t key_size, int (*callback)(struct avl_node *, void *), void *argument)
{
	if (!node)
		return 0;

	if (key_data)
	{
		int order = avl_compare(key_data, key_size, node);

		// The key is after this node: iteration will start in the right subtree.
		if (order < 0)
			return iterate(node->next[1], key_data, key_size, callback, argument);

		// The key is before this node: iteration will start in the left subtree.
		if ((order > 0) && iterate(node->next[0], key_data, key_size, callback, argument))
			return 1;

		// This node and its whole right subtree are at or past the starting point.
		if ((*callback)(node, argument))
			return 1;
		return iterate(node->next[1], 0, 0, callback, argument);
	}

	// Iteration has started.
	if (iterate(node->next[0], key_data, key_size, callback, argument))
		return 1;
	if ((*callback)(node, argument))
		return 1;
	return iterate(node->next[1], key_data, key_size, callback, argument);
}
// Visits the nodes of the tree in key order, starting at the given key (or at the first node when key_data is null).
// The traversal ends early when the callback returns non-zero; that outcome is not reported to the caller.
void avl_iterate(struct avl *avl, const unsigned char *restrict key_data, size_t key_size, int (*callback)(struct avl_node *, void *), void *argument)
{
	struct avl_node *root = avl->root;

	(void)iterate(root, key_data, key_size, callback, argument);
}
// Frees every node of the subtree rooted at node. A null node is accepted.
static void avl_term_internal(struct avl_node *node)
{
	// Recurse into the left subtree and walk down the right one iteratively.
	while (node)
	{
		struct avl_node *right = node->next[1];

		avl_term_internal(node->next[0]);

		// WARNING: Tree of complex data type should do something more here
		free(node);

		node = right;
	}
}
// Frees all nodes of the tree and resets it to the empty state.
// Afterwards the struct holds no dangling pointer, so it can be reused or terminated again safely.
void avl_term(struct avl *avl)
{
	avl_term_internal(avl->root);
	avl->root = 0;
	avl->count = 0;
}
/* tests */
#include <stdarg.h>
#include <stddef.h>
#include <setjmp.h>
#include <cmocka.h>
// Number of insertions performed by test_avl_insert.
#define NODES_COUNT 1024
// Tree shared by all tests; the tests run in order and build on each other's state.
static struct avl tree;
// Values passed to avl_insert, in insertion order (duplicates are possible).
static avl_type values[NODES_COUNT];
// Number of used entries in values.
static size_t values_count;
/*void print(struct avl_node *t)
{
if (t)
{
if (t->next[0])
{
printf("[");
print(t->next[0]);
printf("]");
}
printf(" %d ", t->key, t->factor);
//printf("%d (%d)", t->key, t->factor);
if (t->next[1])
{
printf("[");
print(t->next[1]);
printf("]");
}
}
}*/
// Verifies for every node of the subtree that the stored balance factor equals the height of next[0] minus
// the height of next[1] and lies in [-1, 1]. Returns the height of the subtree (0 for an empty one).
static unsigned check_height(struct avl_node *t)
{
	unsigned left, right;

	if (!t)
		return 0;

	left = check_height(t->next[0]);
	right = check_height(t->next[1]);

	assert_int_equal(t->factor, (int)left - (int)right);
	assert_true(t->factor >= -1);
	assert_true(t->factor <= 1);

	return 1 + ((left < right) ? right : left);
}
// Verifies that each node's value lies between the values of its direct children (left < node < right).
// Only parent/child pairs are compared. An empty (null) subtree is trivially sorted.
static void check_sorted(struct avl_node *t)
{
	// Guard against an empty tree: the root is null when every entry has been removed.
	if (!t)
		return;

	if (t->next[0])
	{
		check_sorted(t->next[0]);
		assert_true(t->next[0]->value < t->value);
	}
	if (t->next[1])
	{
		assert_true(t->value < t->next[1]->value);
		check_sorted(t->next[1]);
	}
}
#include <netinet/in.h>
// Inserts NODES_COUNT pseudo-random values below 4096, keyed by the value in network byte order,
// and checks the balance invariants of the tree after every insertion.
static void test_avl_insert(void **state)
{
	size_t inserted = 0;

	while (inserted < NODES_COUNT)
	{
		avl_type value = random() % 4096;
		// Network byte order makes the bytewise key comparison agree with the numeric order.
		int key = htonl(value);
		avl_type *result;

		values[values_count++] = value;

		result = avl_insert(&tree, (const unsigned char *)&key, sizeof(key), value);
		assert_non_null(result);
		assert_int_equal(*result, value);

		check_height(tree.root);

		inserted += 1;
	}
}
// Checks that every inserted value can be looked up and that a key which was never inserted is not found.
static void test_avl_get(void **state)
{
	size_t i = 0;

	while (i < NODES_COUNT)
	{
		int key = htonl(values[i]);
		avl_type *value = avl_get(&tree, (const unsigned char *)&key, sizeof(key));

		assert_true(value);
		assert_int_equal(*value, values[i]);

		i += 1;
	}

	// 4096 is outside the range of the inserted values.
	int key = htonl(4096);
	assert_false(avl_get(&tree, (const unsigned char *)&key, sizeof(key)));
}
// State shared between test_avl_iterate and its callback.
struct it
{
avl_type min, max; // min: smallest value the callback accepts; max: the callback stops the iteration at the first value above it
avl_type last; // value of the most recently visited node
int done; // set once the callback has requested the iteration to stop
};
// Iteration callback: checks that values arrive in strictly increasing order, are not below it->min, and that
// no node is delivered after a stop was requested. Requests a stop at the first value above it->max.
// Returns non-zero to stop the iteration.
static int callback(struct avl_node *node, void *argument)
{
	struct it *it = argument;

	assert_true(node->value > it->last);
	assert_true(node->value >= it->min);
	assert_false(it->done);

	it->last = node->value;
	if (node->value > it->max)
		it->done = 1;
	else
		it->done = 0;

	return it->done;
}
// Iterates over the tree starting at key 248 and lets the callback verify order and range of the visited values.
static void test_avl_iterate(void **state)
{
	struct it it = {.min = 248, .max = 958, .last = -1};
	// The tree keys are stored in network byte order (see test_avl_insert), so the start key must be converted too.
	// Passing the host-order bytes made the iteration start past every key on little-endian hosts.
	int key = htonl(it.min);

	avl_iterate(&tree, (const unsigned char *)&key, sizeof(key), callback, &it);

	// Make sure the test cannot pass without the callback having been called at least once.
	assert_true(it.last >= it.min);
}
// Removes every key in [0, 2048), checking the balance invariants after each removal, then verifies that
// exactly the inserted values of 2048 and above are still present.
static void test_avl_remove(void **state)
{
	int i = 0;

	while (i < 2048)
	{
		int key = htonl(i);
		avl_type value;

		avl_remove(&tree, (const unsigned char *)&key, sizeof(key), &value);
		check_height(tree.root);
		// TODO some assertion about the value

		i += 1;
	}

	for (i = 0; i < NODES_COUNT; i += 1)
	{
		int key = htonl(values[i]);

		if (values[i] >= 2048)
			assert_true(avl_get(&tree, (const unsigned char *)&key, sizeof(key)));
		else
			assert_false(avl_get(&tree, (const unsigned char *)&key, sizeof(key)));
	}
}
// Checks the ordering of the values in whatever is left in the shared tree.
static void test_avl_sorted(void **state)
{
	struct avl_node *root = tree.root;

	check_sorted(root);
}
// Runs the test group on a freshly zeroed tree with a fixed random seed and releases the tree afterwards.
// The tests share the global tree and depend on running in the listed order.
// Returns the status reported by cmocka_run_group_tests.
int main(int argc, char *argv[])
{
	const struct CMUnitTest tests[] =
	{
		cmocka_unit_test(test_avl_insert),
		cmocka_unit_test(test_avl_get),
		cmocka_unit_test(test_avl_iterate),
		cmocka_unit_test(test_avl_remove),
		cmocka_unit_test(test_avl_sorted),
	};

	// Fixed seed keeps the generated values reproducible between runs.
	srandom(1);

	// Start from an empty tree.
	tree.count = 0;
	tree.root = 0;

	int status = cmocka_run_group_tests(tests, 0, 0);

	avl_term(&tree);
	return status;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment