Skip to content

Instantly share code, notes, and snippets.

@Aszarsha
Last active April 16, 2020 16:05
Show Gist options
  • Save Aszarsha/10324975 to your computer and use it in GitHub Desktop.
Save Aszarsha/10324975 to your computer and use it in GitHub Desktop.
AVL tree (solution for Algo3 students)
#include "avl.h"
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
// Larger of two ints; when they are equal that common value is returned.
static inline int max( int a, int b ) {
    if ( a > b ) {
        return a;
    }
    return b;
}
//===== AVL Node =====//
// One node of the tree. `height` is the number of edges on the longest
// downward path to a leaf: avl_node_new creates leaves with height 0 and
// avl_node_height reports -1 for an empty (NULL) subtree.
// `value` is the caller-supplied object; its key is obtained through the
// tree's avl_key_func, never stored separately.
typedef struct avl_node_t {
struct avl_node_t * left;   // objects whose keys compare less than this node's key
struct avl_node_t * right;  // objects whose keys compare greater than this node's key
int height;
avl_object value;
} * avl_node;
// Allocate a leaf node holding `value`.
// A leaf has height 0 (an empty subtree counts as -1, see avl_node_height).
// Returns NULL if the allocation fails; avl_node_insert propagates that NULL
// upward unchanged, so the tree is left untouched and the insertion reports
// failure instead of dereferencing a null pointer.
static inline avl_node avl_node_new( avl_object value ) {
    avl_node n = malloc( sizeof *n );
    if ( n == NULL ) {
        return NULL;
    }
    n->left = NULL;
    n->right = NULL;
    n->height = 0;
    n->value = value;
    return n;
}
// Free every node of the subtree rooted at `node` (NULL is an empty subtree).
// Values are visited in-order (ascending key order) and each is passed to
// `delf`; if `delf` is NULL only the nodes are freed and the values are left
// to the caller (previously a NULL `delf` was dereferenced).
// Recursion only descends left; the right spine is walked iteratively.
static void avl_node_delete( avl_node node, avl_del_func delf ) {
    while ( node != NULL ) {
        avl_node_delete( node->left, delf );
        if ( delf != NULL ) {
            delf( node->value );
        }
        avl_node right = node->right;  // read before the node is released
        free( node );
        node = right;
    }
}
// Locate the node whose key compares equal to `key`, or NULL if there is none.
// Starting at `node`, the search steps right while `key` is greater than the
// current node's key and left while it is smaller.
static avl_node avl_node_find( avl_node node
                             , avl_key_func keyf
                             , avl_comp_func compf
                             , avl_key key
                             ) {
    avl_node cur = node;
    while ( cur != NULL ) {
        int order = compf( key, keyf( cur->value ) );
        if ( order == 0 ) {
            break;
        }
        cur = order > 0 ? cur->right : cur->left;
    }
    return cur;
}
// Height of a possibly empty subtree: -1 for NULL, so that a leaf is 0.
static inline int avl_node_height( avl_node node ) {
    if ( node == NULL ) {
        return -1;
    }
    return node->height;
}
// Left rotation around `node`; its right child becomes the subtree root:
//
//      node                 right
//     /    \               /     \
//    A    right    ==>   node     C
//         /   \         /    \
//        B     C       A      B
//
// Precondition: node->right != NULL and the heights stored in A, B and C are
// correct. Both moved nodes get their heights recomputed from their children.
// The previous `-= 1` / `+= 1` adjustment was only right when node->height
// was stale (the insertion path) and produced wrong heights on the deletion
// path, where node->height has just been recomputed.
static inline avl_node avl_node_left_rotate( avl_node node ) {
    avl_node right = node->right;
    node->right = right->left;
    right->left = node;
    // `node` is now below `right`, so its height has to be fixed first.
    node->height = 1 + max( avl_node_height( node->left )
                          , avl_node_height( node->right )
                          );
    right->height = 1 + max( avl_node_height( right->left )
                           , avl_node_height( right->right )
                           );
    return right;
}
// Right rotation around `node`; its left child becomes the subtree root:
//
//         node            left
//        /    \          /    \
//      left    C  ==>   A     node
//     /    \                 /    \
//    A      B               B      C
//
// Precondition: node->left != NULL and the heights stored in A, B and C are
// correct. Both moved nodes get their heights recomputed from their children.
// The previous `-= 1` / `+= 1` adjustment was only right when node->height
// was stale (the insertion path) and produced wrong heights on the deletion
// path, where node->height has just been recomputed.
static inline avl_node avl_node_right_rotate( avl_node node ) {
    avl_node left = node->left;
    node->left = left->right;
    left->right = node;
    // `node` is now below `left`, so its height has to be fixed first.
    node->height = 1 + max( avl_node_height( node->left )
                          , avl_node_height( node->right )
                          );
    left->height = 1 + max( avl_node_height( left->left )
                          , avl_node_height( left->right )
                          );
    return left;
}
// Balance factor of a non-NULL node: left height minus right height.
// Positive means left-heavy, negative means right-heavy.
static inline int avl_node_balance( avl_node node ) {
    return avl_node_height( node->left ) - avl_node_height( node->right );
}
// Restore the AVL property at `node`, whose children are assumed to be valid
// AVL subtrees. Performs a single or double rotation when the balance factor
// reaches +/-2, then refreshes the height of the resulting subtree root.
// Returns the (possibly different) root of the subtree.
static avl_node avl_node_rebalance( avl_node node ) {
    const int bal = avl_node_balance( node );
    avl_node top;
    if ( bal > 1 ) {
        // Left-heavy. A right-leaning left child needs a preliminary left
        // rotation first (the left-right case).
        if ( avl_node_balance( node->left ) < 0 ) {
            node->left = avl_node_left_rotate( node->left );
        }
        top = avl_node_right_rotate( node );
    } else if ( bal < -1 ) {
        // Right-heavy; mirror of the case above (the right-left case).
        if ( avl_node_balance( node->right ) > 0 ) {
            node->right = avl_node_right_rotate( node->right );
        }
        top = avl_node_left_rotate( node );
    } else {
        top = node;
    }
    const int hl = avl_node_height( top->left );
    const int hr = avl_node_height( top->right );
    top->height = 1 + max( hl, hr );
    return top;
}
// Insert `value` into the subtree rooted at `node`.
// Returns the new (rebalanced) subtree root, or NULL when nothing was
// inserted: an object with an equal key is already present, or the recursion
// below returned NULL. In that case no node has been modified.
static avl_node avl_node_insert( avl_node node
                               , avl_key_func keyf
                               , avl_comp_func compf
                               , avl_object value
                               ) {
    if ( node == NULL ) {
        return avl_node_new( value );
    }
    int cmp = compf( keyf( value ), keyf( node->value ) );
    if ( cmp == 0 ) {
        return NULL;  // duplicate key: reject
    }
    // Pick the child pointer to descend into; only written on success.
    avl_node * slot = cmp > 0 ? &node->right : &node->left;
    avl_node child = avl_node_insert( *slot, keyf, compf, value );
    if ( child == NULL ) {
        return NULL;
    }
    *slot = child;
    return avl_node_rebalance( node );
}
// Result of one recursive removal step.
// new_root:       root of the subtree after removal and rebalancing
//                 (NULL if the subtree became empty).
// removed_object: the object unlinked from the tree, or NULL when
//                 avl_node_remove did not find the key.
typedef struct remove_return_t {
avl_node new_root;
avl_object removed_object;
} remove_return;
// Unlink the leftmost (minimum-key) node of the non-empty subtree `current`.
// Its value is moved into `root`, and root's former value is reported as
// removed_object. avl_node_remove calls this with current == root->right,
// which makes the moved value root's in-order successor.
// Returns the rebalanced subtree that should replace `current`.
static remove_return avl_node_remove_successor( avl_node current, avl_node root ) {
    if ( current->left != NULL ) {
        remove_return sub = avl_node_remove_successor( current->left, root );
        current->left = sub.new_root;
        current->height = 1 + max( avl_node_height( current->left )
                                 , avl_node_height( current->right )
                                 );
        sub.new_root = avl_node_rebalance( current );
        return sub;
    }
    // `current` is the minimum: hand its value to `root` and splice it out.
    remove_return done;
    done.removed_object = root->value;
    root->value = current->value;
    done.new_root = current->right;  // the minimum has no left child
    free( current );
    return done;
}
// Mirror image of avl_node_remove_successor: unlink the rightmost
// (maximum-key) node of the non-empty subtree `current`. Its value is moved
// into `root`, and root's former value is reported as removed_object.
// avl_node_remove calls this with current == root->left, which makes the
// moved value root's in-order predecessor.
// Returns the rebalanced subtree that should replace `current`.
static remove_return avl_node_remove_predecessor( avl_node current, avl_node root ) {
    if ( current->right != NULL ) {
        remove_return sub = avl_node_remove_predecessor( current->right, root );
        current->right = sub.new_root;
        current->height = 1 + max( avl_node_height( current->left )
                                 , avl_node_height( current->right )
                                 );
        sub.new_root = avl_node_rebalance( current );
        return sub;
    }
    // `current` is the maximum: hand its value to `root` and splice it out.
    remove_return done;
    done.removed_object = root->value;
    root->value = current->value;
    done.new_root = current->left;  // the maximum has no right child
    free( current );
    return done;
}
// Remove the object whose key compares equal to `key` from the subtree rooted
// at `node`. Returns the rebalanced subtree root and the unlinked object
// (removed_object is NULL when the key is absent). The object itself is not
// freed; only the tree node is.
static remove_return avl_node_remove( avl_node node
                                    , avl_key_func keyf
                                    , avl_comp_func compf
                                    , avl_key key
                                    ) {
    remove_return ret = { NULL, NULL };
    if ( node == NULL ) {
        return ret;  // key not present in this subtree
    }
    int cmp = compf( key, keyf( node->value ) );
    if ( cmp < 0 ) {
        ret = avl_node_remove( node->left, keyf, compf, key );
        node->left = ret.new_root;
    } else if ( cmp > 0 ) {
        ret = avl_node_remove( node->right, keyf, compf, key );
        node->right = ret.new_root;
    } else if ( node->left == NULL || node->right == NULL ) {
        // Zero or one child: the (possibly empty) child takes our place.
        ret.removed_object = node->value;
        ret.new_root = node->left != NULL ? node->left : node->right;
        free( node );
        return ret;
    } else if ( avl_node_balance( node ) < 0 ) {
        // Two children, right side taller: pull the successor up from the right.
        ret = avl_node_remove_successor( node->right, node );
        node->right = ret.new_root;
    } else {
        // Two children, left side at least as tall: pull the predecessor up.
        ret = avl_node_remove_predecessor( node->left, node );
        node->left = ret.new_root;
    }
    node->height = 1 + max( avl_node_height( node->left )
                          , avl_node_height( node->right )
                          );
    ret.new_root = avl_node_rebalance( node );
    return ret;
}
//===== AVL Tree =====//
// Tree handle: the root node plus the key-extraction and comparison callbacks
// given to avl_tree_new, which are forwarded to every node operation.
struct avl_tree_t {
avl_node root;        // NULL when the tree is empty
avl_key_func keyf;
avl_comp_func compf;
};
// Create an empty tree. Objects are ordered by
// `compf( keyf( a ), keyf( b ) )`.
// Returns NULL if the tree handle cannot be allocated; previously the
// unchecked malloc result was dereferenced.
avl_tree avl_tree_new( avl_key_func keyf, avl_comp_func compf ) {
    avl_tree t = malloc( sizeof *t );
    if ( t == NULL ) {
        return NULL;
    }
    t->root = NULL;
    t->keyf = keyf;
    t->compf = compf;
    return t;
}
// Destroy `tree`. avl_node_delete hands every stored object to `delf` and
// frees the nodes; the handle itself is then freed, so `tree` must not be
// used afterwards.
void avl_tree_delete( avl_tree tree, avl_del_func delf ) {
    assert( tree != NULL );
    avl_node root = tree->root;
    avl_node_delete( root, delf );
    tree->root = NULL;
    free( tree );
}
// Insert `value`. On success the tree references `value` and true is
// returned. On failure (e.g. an object with an equal key is already stored)
// the tree is unchanged, false is returned, and `value` remains the caller's.
bool avl_tree_insert( avl_tree tree, avl_object value ) {
    assert( tree != NULL );
    avl_node new_root = avl_node_insert( tree->root, tree->keyf, tree->compf, value );
    if ( new_root == NULL ) {
        return false;
    }
    tree->root = new_root;
    return true;
}
// Return the stored object whose key compares equal to `key`, or NULL.
// The object stays in the tree.
avl_object avl_tree_find( avl_tree tree, avl_key key ) {
    assert( tree != NULL );
    avl_node hit = avl_node_find( tree->root, tree->keyf, tree->compf, key );
    if ( hit == NULL ) {
        return NULL;
    }
    return hit->value;
}
// Unlink the object whose key compares equal to `key` and return it (NULL if
// absent). Ownership of the returned object passes back to the caller; no
// delete function is invoked on it.
avl_object avl_tree_remove( avl_tree tree, avl_key key ) {
    assert( tree != NULL );
    remove_return outcome = avl_node_remove( tree->root, tree->keyf, tree->compf, key );
    tree->root = outcome.new_root;
    return outcome.removed_object;
}
#ifdef TESTS
//===== Tests =====//
// Payload used by the test driver: `index` is the key (test_object_key
// returns its address) and `character` is a label that is only printed.
typedef struct test_object_t {
int index;
char character;
} * test_object;
static test_object test_object_new( int index, char character ) {
test_object to = (test_object)malloc( sizeof(*to) );
to->index = index;
to->character = character;
return to;
}
// avl_del_func for test payloads: releases the object allocated by
// test_object_new (the implicit void* conversion needs no cast).
static void test_object_delete( avl_object to ) {
    free( to );
}
// avl_key_func for test payloads: the key is a pointer to the `index` field.
static avl_key test_object_key( avl_object to ) {
    test_object obj = to;
    return &obj->index;
}
// avl_comp_func for int keys: returns -1, 0 or 1 as *a is less than, equal
// to, or greater than *b. The former `*a - *b` has undefined behavior on
// signed overflow (e.g. INT_MAX vs -1); comparing avoids the subtraction.
static int int_key_compare( avl_key a, avl_key b ) {
    const int x = *(const int *)a;
    const int y = *(const int *)b;
    return ( x > y ) - ( x < y );
}
// Print a payload as `"index: character"` (with the quotes, no newline).
static void test_object_printer( test_object to ) {
    const int idx = to->index;
    const char ch = to->character;
    printf( "\"%d: %c\"", idx, ch );
}
// Print a subtree in nested-list form: "(<left> <value> <right>)", where an
// empty subtree prints nothing.
static void avl_node_print_list( avl_node node ) {
    if ( node == NULL ) {
        return;
    }
    putchar( '(' );
    avl_node_print_list( node->left );
    putchar( ' ' );
    test_object_printer( node->value );
    putchar( ' ' );
    avl_node_print_list( node->right );
    putchar( ')' );
}
// Print the whole tree in nested-list form followed by a newline.
static void avl_tree_print_list( avl_tree tree ) {
    avl_node root = tree->root;
    avl_node_print_list( root );
    putchar( '\n' );
}
// Try to insert a fresh payload, report the outcome, then dump the tree.
// A rejected payload is still owned by us and is freed here.
static void test_insert( avl_tree tree, int index, char character ) {
    test_object obj = test_object_new( index, character );
    const bool inserted = avl_tree_insert( tree, obj );
    fputs( inserted ? "++ Successfully inserted " : "-- Failed to insert ", stdout );
    test_object_printer( obj );
    printf( "\n : " );
    if ( !inserted ) {
        test_object_delete( obj );
    }
    avl_tree_print_list( tree );
}
// Look up `key`, report the outcome, then dump the tree.
static void test_find( avl_tree tree, int key ) {
    avl_object found = avl_tree_find( tree, &key );
    if ( found != NULL ) {
        printf( "++ Successfully found " );
        test_object_printer( found );
        printf( "\n : " );
    } else {
        printf( "-- Failed to find %d\n : ", key );
    }
    avl_tree_print_list( tree );
}
// Remove `key`, report the outcome, free the unlinked payload (the tree hands
// ownership back), then dump the tree.
static void test_remove( avl_tree tree, int key ) {
    avl_object removed = avl_tree_remove( tree, &key );
    if ( removed != NULL ) {
        printf( "++ Successfully removed " );
        test_object_printer( removed );
        printf( "\n : " );
        test_object_delete( removed );
    } else {
        printf( "-- Failed to remove %d\n : ", key );
    }
    avl_tree_print_list( tree );
}
// Test driver: a fixed script of inserts (including duplicate keys 1 and 3),
// lookups (including the missing key 7) and removals, followed by lookups on
// the shrunken tree. Every step prints its outcome and the resulting tree.
int main( int argc, char * argv[] ) {
    (void)argc;
    (void)argv;

    static const struct { int index; char character; } inserts[] = {
        { 1, 'a'+0 }, { 4, 'a'+3 }, { 5, 'a'+4 }, { 1, 'v' },
        { 3, 'a'+2 }, { 6, 'a'+5 }, { 2, 'a'+1 }, { 3, 'w' },
    };
    static const int finds_before[] = { 2, 4, 7 };
    static const int removals[]     = { 5, 6, 2, 7 };
    static const int finds_after[]  = { 2, 4, 0 };

    avl_tree tree = avl_tree_new( &test_object_key, &int_key_compare );
    for ( size_t i = 0; i < sizeof inserts / sizeof inserts[0]; ++i ) {
        test_insert( tree, inserts[i].index, inserts[i].character );
    }
    for ( size_t i = 0; i < sizeof finds_before / sizeof finds_before[0]; ++i ) {
        test_find( tree, finds_before[i] );
    }
    for ( size_t i = 0; i < sizeof removals / sizeof removals[0]; ++i ) {
        test_remove( tree, removals[i] );
    }
    for ( size_t i = 0; i < sizeof finds_after / sizeof finds_after[0]; ++i ) {
        test_find( tree, finds_after[i] );
    }
    avl_tree_delete( tree, &test_object_delete );
    return EXIT_SUCCESS;
}
#endif
#ifndef AVL_TREE_H
#define AVL_TREE_H
// Generic AVL tree of opaque object pointers. Each object's key is obtained
// with an avl_key_func and keys are ordered with an avl_comp_func. Keys are
// unique: an object whose key compares equal to a stored one is not inserted.
#include <stdbool.h>
typedef void * avl_key;
typedef void * avl_object;
// Applied to every stored object by avl_tree_delete.
typedef void (*avl_del_func)( avl_object );
// Returns the key of an object; the result is only ever passed to avl_comp_func.
typedef avl_key (*avl_key_func) ( avl_object );
// Three-way comparison: negative if a < b, zero if equal, positive if a > b.
typedef int (*avl_comp_func)( avl_key, avl_key );
// Opaque tree handle.
typedef struct avl_tree_t * avl_tree;
// Create an empty tree that extracts keys with `keyf` and orders them with `compf`.
avl_tree avl_tree_new( avl_key_func keyf, avl_comp_func compf );
// Destroy the tree: `delf` is applied to every stored object, then all nodes
// and the handle itself are freed.
void avl_tree_delete( avl_tree t, avl_del_func delf );
// Insert `value`. Returns true if it was inserted (the tree now references it).
// Returns false if it was not inserted, e.g. because an object with an equal
// key is already stored; the tree is then unchanged and `value` is still the
// caller's to free.
bool avl_tree_insert( avl_tree t, avl_object value );
// Return the stored object whose key compares equal to `key`, or NULL.
avl_object avl_tree_find( avl_tree t, avl_key key );
// Unlink and return the object whose key compares equal to `key`, or NULL if
// there is none. No delete function is called on it; the caller owns it again.
avl_object avl_tree_remove( avl_tree t, avl_key key );
#endif // AVL_TREE_H
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment