Skip to content

Instantly share code, notes, and snippets.

@Aszarsha
Last active April 16, 2020 16:05
Show Gist options
  • Save Aszarsha/10900899 to your computer and use it in GitHub Desktop.
Save Aszarsha/10900899 to your computer and use it in GitHub Desktop.
Splay tree (model solution for Algo3 students)
#include "splay.h"
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <limits.h>
//===== Splay Node =====//
/* One node of the splay tree. A node stores no key of its own: the key is
 * recomputed from `value` through the tree's key function whenever needed. */
struct splay_node_t {
    struct splay_node_t *left;   /* subtree of smaller keys, or NULL */
    struct splay_node_t *right;  /* subtree of greater keys, or NULL */
    splay_object         value;  /* user object held by this node */
};
typedef struct splay_node_t *splay_node;
// Allocates a node holding `value`, with children `left` and `right`.
// Returns NULL when memory is exhausted; callers must check the result
// before linking the node into a tree.
static inline splay_node splay_node_new( splay_node left, splay_node right
                                       , splay_object value
                                       ) {
    splay_node node = malloc( sizeof *node );
    if ( !node ) {
        return NULL;
    }
    node->left = left;
    node->right = right;
    node->value = value;
    return node;
}
// Frees every node of the subtree rooted at `node`. `delf` is called on each
// stored value in key order (left subtree, node, right subtree). `delf` may
// be NULL, in which case only the nodes are freed and the values are left to
// the caller.
//
// Iterative on purpose: while the current node has a left child, that child
// is rotated above it, so the leftmost remaining node eventually becomes the
// current one and can be freed. This needs O(1) stack even on the chain-shaped
// trees that splaying produces (e.g. after sorted insertions), where a
// recursive walk would recurse once per level and could overflow the stack.
static void splay_node_delete( splay_node node, splay_del_func delf ) {
    while ( node ) {
        splay_node left = node->left;
        if ( left ) {
            // Right rotation: `left` becomes the current node, `node` its right child.
            node->left = left->right;
            left->right = node;
            node = left;
        } else {
            // `node` is the smallest remaining key: release it, continue to the right.
            splay_node right = node->right;
            if ( delf ) {
                delf( node->value );
            }
            free( node );
            node = right;
        }
    }
}
/* Rotates left around `node`: its right child (the pivot) is lifted to the
 * top of the subtree and `node` becomes the pivot's left child. The pivot's
 * former left subtree is re-attached as `node`'s right subtree.
 * Precondition: node->right != NULL. Returns the new subtree root. */
static inline splay_node splay_node_left_rotate( splay_node node ) {
    splay_node pivot = node->right;
    splay_node inner = pivot->left;
    pivot->left = node;
    node->right = inner;
    return pivot;
}
/* Rotates right around `node`: its left child (the pivot) is lifted to the
 * top of the subtree and `node` becomes the pivot's right child. The pivot's
 * former right subtree is re-attached as `node`'s left subtree.
 * Precondition: node->left != NULL. Returns the new subtree root. */
static inline splay_node splay_node_right_rotate( splay_node node ) {
    splay_node pivot = node->left;
    splay_node inner = pivot->right;
    pivot->right = node;
    node->left = inner;
    return pivot;
}
// Splays the subtree rooted at `node` with respect to `key` and returns the
// new subtree root.
//
// The returned root is the node whose key compares equal to `key` if one
// exists. Otherwise it is the last node visited on the search path for
// `key`, i.e. the in-order predecessor or successor of `key`. Insert and
// remove rely on this: the root then splits the tree around `key`.
//
// This is Sleator & Tarjan's top-down splay. While descending, nodes known
// to be smaller than `key` are appended to a "left" tree and nodes known to
// be greater to a "right" tree; on zig-zig steps a rotation is done first.
// At the end the three pieces are reassembled around the final node. The
// loop uses O(1) stack, whereas a recursive splay needs stack proportional
// to the tree height, which reaches n after sorted insertions.
//
// `node` may be NULL, in which case NULL is returned.
static splay_node splay_node_splay( splay_node node
                                  , splay_key_func keyf
                                  , splay_comp_func compf
                                  , splay_key key
                                  ) {
    if ( !node ) {
        return NULL;
    }
    // Sentinel: header.right accumulates the left tree, header.left the right tree.
    struct splay_node_t header = { .left = NULL, .right = NULL };
    splay_node left_max = &header;   // largest node of the left tree so far
    splay_node right_min = &header;  // smallest node of the right tree so far
    for (;;) {
        int cmp = compf( key, keyf( node->value ) );
        if ( cmp < 0 ) {
            if ( !node->left ) {
                break;  // search path ends here: closest key found
            }
            if ( compf( key, keyf( node->left->value ) ) < 0 ) {
                // zig-zig: rotate first so the path length is halved.
                node = splay_node_right_rotate( node );
                if ( !node->left ) {
                    break;
                }
            }
            // Link right: `node` and its right subtree are all greater than key.
            right_min->left = node;
            right_min = node;
            node = node->left;
        } else if ( cmp > 0 ) {
            if ( !node->right ) {
                break;  // search path ends here: closest key found
            }
            if ( compf( key, keyf( node->right->value ) ) > 0 ) {
                // zag-zag: rotate first so the path length is halved.
                node = splay_node_left_rotate( node );
                if ( !node->right ) {
                    break;
                }
            }
            // Link left: `node` and its left subtree are all smaller than key.
            left_max->right = node;
            left_max = node;
            node = node->right;
        } else {
            break;  // exact match
        }
    }
    // Reassemble: final node's subtrees go to the inner edges of the
    // left/right trees, which then become its children.
    left_max->right = node->left;
    right_min->left = node->right;
    node->left = header.right;
    node->right = header.left;
    return node;
}
/*
static splay_node splay_node_cheap_splay( splay_node node
, splay_key_func keyf
, splay_comp_func compf
, splay_key key
) {
splay_key localkey = keyf( node->value );
int cmp = compf( key, localkey );
if ( cmp > 0) {
node->right = splay_node_cheap_splay( node->right, keyf, compf, key );
return splay_node_left_rotate( node );
} else if ( cmp < 0 ) {
node->left = splay_node_cheap_splay( node->left, keyf, compf, key );
return splay_node_right_rotate( node );
}
return node;
}
*/
//===== Splay Tree =====//
/* Tree handle behind `splay_tree`: the root node plus the two callbacks
 * passed to splay_tree_new(). */
struct splay_tree_t {
    splay_node      root;   /* NULL when the tree is empty */
    splay_key_func  keyf;   /* extracts the key of a stored object */
    splay_comp_func compf;  /* three-way comparison of two keys */
};
// Creates an empty tree that orders objects by compf( keyf( a ), keyf( b ) ).
// Returns NULL when memory is exhausted instead of writing through a NULL
// pointer.
splay_tree splay_tree_new( splay_key_func keyf, splay_comp_func compf ) {
    splay_tree tree = malloc( sizeof *tree );
    if ( !tree ) {
        return NULL;
    }
    tree->root = NULL;
    tree->keyf = keyf;
    tree->compf = compf;
    return tree;
}
/* Destroys `tree`: every stored object is handed to `delf` while the nodes
 * are freed, then the tree handle itself is released. `tree` must not be
 * used afterwards. */
void splay_tree_delete( splay_tree tree, splay_del_func delf ) {
    assert( tree != NULL );
    splay_node all_nodes = tree->root;
    splay_node_delete( all_nodes, delf );
    free( tree );
}
// Inserts `object` into `tree`, keyed by tree->keyf( object ).
// Returns true on success; the tree then owns `object` (splay_tree_delete
// passes it to its del function). Returns false, with the caller keeping
// ownership, if an object with an equal key is already stored or if
// splay_node_new() returns NULL; the tree's contents are then unchanged
// (though it may have been splayed).
bool splay_tree_insert( splay_tree tree, splay_object object ) {
    assert( tree != NULL );
    if ( !tree->root ) {
        splay_node node = splay_node_new( NULL, NULL, object );
        if ( !node ) {
            return false;
        }
        tree->root = node;
        return true;
    }
    splay_key objectkey = tree->keyf( object );
    // After splaying, the root is either an equal key or the predecessor /
    // successor of objectkey, so it splits the tree around objectkey.
    splay_node root = splay_node_splay( tree->root, tree->keyf, tree->compf, objectkey );
    tree->root = root;
    int cmp = tree->compf( objectkey, tree->keyf( root->value ) );
    if ( cmp == 0 ) {
        return false;  // duplicate key
    }
    // Allocate before unlinking anything so a failure leaves the tree intact.
    splay_node node = cmp > 0
        ? splay_node_new( root, root->right, object )   // root key < object key
        : splay_node_new( root->left, root, object );   // root key > object key
    if ( !node ) {
        return false;
    }
    if ( cmp > 0 ) {
        root->right = NULL;
    } else {
        root->left = NULL;
    }
    tree->root = node;
    return true;
}
// Looks up the object whose key compares equal to `key`.
// Returns the stored object, or NULL if the tree holds no such object.
// Even a failed lookup restructures the tree: the last node visited on the
// search path is splayed to the root.
splay_object splay_tree_find( splay_tree tree, splay_key key ) {
    assert( tree != NULL );
    if ( !tree->root ) {
        return NULL;
    }
    tree->root = splay_node_splay( tree->root, tree->keyf, tree->compf, key );
    // Search key first, as in insert/remove/splay, so the comparator always
    // sees its arguments in the same order.
    if ( tree->compf( key, tree->keyf( tree->root->value ) ) != 0 ) {
        return NULL;
    }
    return tree->root->value;
}
/* Returns the right-most (largest-key) node of the non-empty subtree rooted
 * at `node`. splay_tree_remove() applies it to the left subtree of the node
 * being removed, where that maximum is the removed node's in-order
 * predecessor. */
static splay_node node_find_predecessor( splay_node node ) {
    splay_node cur = node;
    while ( cur->right != NULL ) {
        cur = cur->right;
    }
    return cur;
}
/* Detaches the object whose key compares equal to `key` and returns it.
 * Ownership passes back to the caller; no del function is called.
 * Returns NULL when no such object is stored; the tree is then left
 * holding the same objects (though splayed). */
splay_object splay_tree_remove( splay_tree tree, splay_key key ) {
    assert( tree != NULL );
    if ( tree->root == NULL ) {
        return NULL;
    }
    splay_node target = splay_node_splay( tree->root, tree->keyf, tree->compf, key );
    tree->root = target;
    if ( tree->compf( key, tree->keyf( target->value ) ) != 0 ) {
        return NULL;
    }
    splay_object removed = target->value;
    splay_node lower = target->left;
    splay_node upper = target->right;
    if ( lower == NULL ) {
        /* No smaller keys: the right subtree (possibly empty) is the new tree. */
        tree->root = upper;
    } else {
        /* Splay the largest key of the left subtree to its root; that root
         * then has no right child, so the right subtree can hang there. */
        splay_key maxkey = tree->keyf( node_find_predecessor( lower )->value );
        tree->root = splay_node_splay( lower, tree->keyf, tree->compf, maxkey );
        tree->root->right = upper;
    }
    free( target );
    return removed;
}
/*
splay_tree splay_tree_coupe( splay_tree tree, splay_key key ) {
assert( tree != NULL );
if ( !tree->root ) {
return splay_tree_new( tree->keyf, tree->compf );
}
tree->root = splay_node_splay( tree->root, tree->keyf, tree->compf, key );
splay_key rootkey = tree->keyf( tree->root->value );
splay_tree res = (splay_tree)malloc( sizeof(*res) );
res->keyf = tree->keyf;
res->compf = tree->compf;
if ( tree->compf( key, rootkey ) >= 0 ) {
res->root = tree->root->right;
tree->root->right = NULL;
} else {
res->root = tree->root;
tree->root = tree->root->left;
res->root->left = NULL;
}
return res;
}
static int splay_max_func( splay_key u, splay_key v ) { return 1; }
static int splay_min_func( splay_key u, splay_key v ) { return -1; }
splay_tree splay_tree_union( splay_tree A, splay_tree B ) {
assert( A != NULL && B != NULL );
A->root = splay_node_splay( A->root, A->keyf, splay_max_func, NULL );
B->root = splay_node_splay( B->root, B->keyf, splay_min_func, NULL );
splay_tree res = splay_tree_new( A->keyf, A->compf );
if ( !A->root ) {
res->root = B->root;
} else {
A->root->right = B->root;
res->root = A->root;
}
return res;
}
*/
#ifdef TESTS
//===== Tests =====//
/* Payload used by the test driver: an integer key plus a character tag. */
struct test_object_t {
    int  index;      /* key; test_object_key() returns its address */
    char character;  /* tag printed next to the key */
};
typedef struct test_object_t *test_object;
static test_object test_object_new( int index, char character ) {
test_object to = (test_object)malloc( sizeof(*to) );
to->index = index;
to->character = character;
return to;
}
/* splay_del_func for test objects: releases the object's memory.
 * The conversion from void * needs no cast in C. */
static void test_object_delete( splay_object to ) {
    free( to );
}
/* splay_key_func for test objects: the key is the address of the object's
 * `index` field, so it stays valid as long as the object lives. */
static splay_key test_object_key( splay_object to ) {
    test_object obj = to;
    return &obj->index;
}
// splay_comp_func for keys that point to ints: returns a negative value,
// zero or a positive value as *a is less than, equal to or greater than *b.
// Written as (x > y) - (x < y) instead of x - y, because the subtraction
// overflows (undefined behaviour) for operands such as INT_MIN and 1.
static int int_key_compare( splay_key a, splay_key b ) {
    int x = *(const int *)a;
    int y = *(const int *)b;
    return ( x > y ) - ( x < y );
}
/* Prints `to` as a quoted "index: character" pair, with no trailing newline. */
static void test_object_printer( test_object to ) {
    int idx = to->index;
    char ch = to->character;
    printf( "\"%d: %c\"", idx, ch );
}
/* Prints the subtree rooted at `node` as a parenthesised in-order list,
 * "(" left " " value " " right ")". An empty subtree prints nothing. */
static void splay_node_print_list( splay_node node ) {
    if ( node == NULL ) {
        return;
    }
    printf( "(" );
    splay_node_print_list( node->left );
    printf( " " );
    test_object_printer( node->value );
    printf( " " );
    splay_node_print_list( node->right );
    printf( ")" );
}
/* Prints the whole tree on a single line, terminated by a newline. */
static void splay_tree_print_list( splay_tree tree ) {
    splay_node root = tree->root;
    splay_node_print_list( root );
    printf( "\n" );
}
/* Inserts a new (index, character) object, reports the outcome and prints
 * the resulting tree. A rejected object was not taken over by the tree, so
 * it is freed here after being reported. */
static void test_insert( splay_tree tree, int index, char character ) {
    test_object obj = test_object_new( index, character );
    bool inserted = splay_tree_insert( tree, obj );
    if ( inserted ) {
        printf( "++ Successfully inserted " );
    } else {
        printf( "-- Failed to insert " );
    }
    test_object_printer( obj );
    printf( "\n : " );
    if ( !inserted ) {
        test_object_delete( obj );
    }
    splay_tree_print_list( tree );
}
/* Looks up `key`, reports whether it was found and prints the tree
 * (which the lookup may have restructured). */
static void test_find( splay_tree tree, int key ) {
    splay_object found = splay_tree_find( tree, &key );
    if ( found ) {
        printf( "++ Successfully found " );
        test_object_printer( found );
        printf( "\n : " );
    } else {
        printf( "-- Failed to find %d\n : ", key );
    }
    splay_tree_print_list( tree );
}
/* Removes `key`, reports the outcome, frees the removed object (ownership
 * came back from the tree) and prints the resulting tree. */
static void test_remove( splay_tree tree, int key ) {
    splay_object removed = splay_tree_remove( tree, &key );
    if ( removed ) {
        printf( "++ Successfully removed " );
        test_object_printer( removed );
        printf( "\n : " );
        test_object_delete( removed );
    } else {
        printf( "-- Failed to remove %d\n : ", key );
    }
    splay_tree_print_list( tree );
}
// Test driver: exercises insert (including duplicate keys), find and remove
// (including absent keys), printing the tree after every operation.
// Command-line arguments are not used.
int main( void ) {
    splay_tree tree = splay_tree_new( &test_object_key, &int_key_compare );
    if ( !tree ) {
        fprintf( stderr, "could not allocate the splay tree\n" );
        return EXIT_FAILURE;
    }
    test_insert( tree, 1, 'a'+0 );
    test_insert( tree, 4, 'a'+3 );
    test_insert( tree, 5, 'a'+4 );
    test_insert( tree, 1, 'v' );     // duplicate key: rejected
    test_insert( tree, 3, 'a'+2 );
    test_insert( tree, 6, 'a'+5 );
    test_insert( tree, 2, 'a'+1 );
    test_insert( tree, 3, 'w' );     // duplicate key: rejected
    test_find( tree, 2 );
    test_find( tree, 4 );
    test_find( tree, 7 );            // absent
    test_remove( tree, 5 );
    test_remove( tree, 6 );
    test_remove( tree, 2 );
    test_remove( tree, 7 );          // absent
    test_find( tree, 2 );
    test_find( tree, 4 );
    test_find( tree, 0 );
    splay_tree_delete( tree, &test_object_delete );
    return EXIT_SUCCESS;
}
#endif
#ifndef SPLAY_TREE_H
#define SPLAY_TREE_H

#include <stdbool.h>

/* Generic splay tree of user objects.
 *
 * The tree stores object pointers only; it never copies objects or keys.
 * Whenever it needs an object's key it calls the splay_key_func given to
 * splay_tree_new(), and it orders keys with the splay_comp_func given there.
 * Lookups may restructure the tree (splaying), so splay_tree_find() modifies
 * the tree just like insertion and removal do. */

/** Opaque key, obtained from an object through a splay_key_func. */
typedef void * splay_key;
/** Opaque user object stored in the tree. */
typedef void * splay_object;
/** Releases a stored object; splay_tree_delete() calls it on each object. */
typedef void (*splay_del_func)( splay_object );
/** Returns the key of an object. */
typedef splay_key (*splay_key_func) ( splay_object );
/** Three-way key comparison: returns a negative value, zero or a positive
 *  value when the first key is less than, equal to or greater than the
 *  second. */
typedef int (*splay_comp_func)( splay_key, splay_key );

/** Handle to a splay tree. */
typedef struct splay_tree_t * splay_tree;

/** Creates an empty tree ordering objects by compf( keyf( a ), keyf( b ) ). */
splay_tree splay_tree_new( splay_key_func keyf, splay_comp_func compf );

/** Destroys `tree`: `delf` is called on every stored object, then the nodes
 *  and the tree handle are freed. `tree` must not be used afterwards. */
void splay_tree_delete( splay_tree tree, splay_del_func delf );

/** Inserts `object`.
 *  @return true if it was inserted; the tree then owns it and will pass it
 *          to delf in splay_tree_delete(). false if it was not inserted, for
 *          example because an object with an equal key is already stored;
 *          the caller keeps ownership. */
bool splay_tree_insert( splay_tree tree, splay_object object );

/** @return the stored object whose key compares equal to `key`, or NULL. */
splay_object splay_tree_find( splay_tree tree, splay_key key );

/** Detaches the object whose key compares equal to `key`.
 *  @return that object, whose ownership passes back to the caller (no del
 *          function is called), or NULL if no such object is stored. */
splay_object splay_tree_remove( splay_tree tree, splay_key key );

#endif // SPLAY_TREE_H
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment