Skip to content

Instantly share code, notes, and snippets.

@denvaar
Created February 12, 2017 02:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save denvaar/62ac0e5636b54f76687cd48be9a86e8e to your computer and use it in GitHub Desktop.
Save denvaar/62ac0e5636b54f76687cd48be9a86e8e to your computer and use it in GitHub Desktop.
Binary search tree (BST) search, insertion, and deletion operations, with unit tests.
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include <vector>
using namespace std;
/* A binary search tree node. Every member has a default initializer, so a
   Node is fully initialized however it is constructed (`Node n;`, `new Node`,
   `new Node()`), and a fresh node is always a childless leaf. */
struct Node {
    int value = 0;
    Node *left = nullptr;   // subtree with values <= value
    Node *right = nullptr;  // subtree with values > value
};
/* recursive search: returns the first node on the search path whose value
   equals `value`, or NULL when there is none (including when root is NULL). */
Node* SearchRecursive(Node* root, int value) {
    // Empty subtree: nothing to find.
    if (!root) return NULL;
    if (value == root->value) return root;
    // Equality was handled above, so smaller values go left, larger go right.
    Node* child = (value < root->value) ? root->left : root->right;
    return SearchRecursive(child, value);
}
/* iterative search: returns the first node on the search path whose value
   equals `value`, or NULL when there is none (including when root is NULL).
   A BST search follows a single root-to-leaf path, so one cursor pointer is
   enough; no auxiliary stack or heap allocation is needed. */
Node* SearchIterative(Node* root, int value) {
    Node* cursor = root;
    while (cursor) {
        if (cursor->value == value) return cursor;
        // Equality was handled above, so smaller values go left, larger go right.
        cursor = (value < cursor->value) ? cursor->left : cursor->right;
    }
    return NULL;
}
/* recursive insert: adds `value` to the subtree rooted at `root` and returns
   the (possibly new) subtree root. Values equal to a node's value go into its
   left subtree. New nodes are allocated with `new Node()`. */
Node* InsertRecursive(Node *root, int value) {
    if (root == NULL) {
        // Reached an empty spot: the new leaf becomes this subtree's root.
        Node* leaf = new Node();
        leaf->value = value;
        return leaf;
    }
    if (value > root->value) {
        root->right = InsertRecursive(root->right, value);
    } else {
        root->left = InsertRecursive(root->left, value);
    }
    return root;
}
/* iterative insert: walks down from `root` to the empty child slot where
   `value` belongs and attaches a new leaf there. Values equal to a node's
   value go into its left subtree. Returns the tree root; when `root` is NULL
   the new leaf itself is returned.
   Only one path is followed, so walking a pointer to the child slot replaces
   an auxiliary stack and avoids an extra heap allocation per call. */
Node* InsertIterative(Node *root, int value) {
    Node* leaf = new Node();
    leaf->value = value;
    // `slot` always points at the child pointer (or `root` itself) that will
    // receive the new leaf once an empty position is reached.
    Node** slot = &root;
    while (*slot) {
        slot = (value <= (*slot)->value) ? &(*slot)->left : &(*slot)->right;
    }
    *slot = leaf;
    return root;
}
/* find the minimum value in a tree: returns the leftmost node of the tree
   rooted at `root`, or NULL when the tree is empty (root is NULL). */
Node* MinNode(Node* root) {
    // Guard against an empty tree; the loop below dereferences root.
    if (!root) return NULL;
    while (root->left) {
        root = root->left;
    }
    return root;
}
/* recursive removal: deletes one node holding `value` from the subtree rooted
   at `root` and returns the new subtree root (NULL if the subtree becomes
   empty). If `value` is not present, the subtree is returned unchanged.
   The unlinked node is freed with delete. */
Node* Remove(Node* root, int value) {
    if (!root) return root;

    // Step 1: descend toward the node that holds `value`.
    if (value < root->value) {
        root->left = Remove(root->left, value);
        return root;
    }
    if (value > root->value) {
        root->right = Remove(root->right, value);
        return root;
    }

    // Step 2: `root` holds `value`.
    if (root->left && root->right) {
        // Two subtrees: copy in the minimum of the right subtree (its leftmost
        // node, which has no left child), then remove that value from the
        // right subtree, reducing the problem to a zero/one-child case.
        Node* successor = root->right;
        while (successor->left) successor = successor->left;
        root->value = successor->value;
        root->right = Remove(root->right, successor->value);
        return root;
    }

    // Zero or one subtree: splice the (possibly NULL) child into root's place.
    Node* replacement = root->left ? root->left : root->right;
    delete root;
    return replacement;
}
/* Unit tests for insert, search and removal. Each section frees the tree it
   builds by repeatedly removing the current root value until the tree is
   empty, using only the functions under test. */
TEST_CASE("BST operations", "[Insert]") {
    // Single-node trees; free them so the test itself does not leak.
    Node *zero = InsertRecursive(NULL, 0);
    REQUIRE(zero->value == 0);
    delete zero;
    Node *one = InsertRecursive(NULL, 1);
    REQUIRE(one->value == 1);
    delete one;
    SECTION("Nodes get InsertRecursive'ed into right spot") {
        Node *root = NULL;
        root = InsertRecursive(root, 5);
        REQUIRE(root->value == 5);
        root = InsertRecursive(root, 4);
        root = InsertRecursive(root, 6);
        REQUIRE(root->left->value == 4);
        REQUIRE(root->right->value == 6);
        root = InsertRecursive(root, 10);
        root = InsertRecursive(root, 7);
        REQUIRE(root->right->right->value == 10);
        REQUIRE(root->right->right->left->value == 7);
        while (root) root = Remove(root, root->value);
    }
    SECTION("Nodes get InsertIteratively'ed into right spot") {
        Node *root = NULL;
        root = InsertIterative(root, 5);
        REQUIRE(root->value == 5);
        root = InsertIterative(root, 4);
        root = InsertIterative(root, 6);
        REQUIRE(root->left->value == 4);
        REQUIRE(root->right->value == 6);
        root = InsertIterative(root, 10);
        root = InsertIterative(root, 7);
        REQUIRE(root->right->right->value == 10);
        REQUIRE(root->right->right->left->value == 7);
        while (root) root = Remove(root, root->value);
    }
    SECTION("Searching") {
        Node *root = NULL;
        root = InsertIterative(root, 10);
        root = InsertIterative(root, 20);
        root = InsertIterative(root, 0);
        root = InsertRecursive(root, 7);
        root = InsertRecursive(root, 78);
        REQUIRE(SearchRecursive(root, 10)->value == 10);
        REQUIRE(SearchRecursive(root, 20)->value == 20);
        REQUIRE(SearchRecursive(root, 0)->value == 0);
        REQUIRE(SearchRecursive(root, 78)->value == 78);
        REQUIRE(SearchRecursive(root, 7)->value == 7);
        REQUIRE(SearchRecursive(root, 50) == NULL);
        REQUIRE(SearchRecursive(NULL, 50) == NULL);
        REQUIRE(SearchIterative(root, 10)->value == 10);
        REQUIRE(SearchIterative(root, 20)->value == 20);
        REQUIRE(SearchIterative(root, 0)->value == 0);
        REQUIRE(SearchIterative(root, 78)->value == 78);
        REQUIRE(SearchIterative(root, 7)->value == 7);
        REQUIRE(SearchIterative(root, 50) == NULL);
        REQUIRE(SearchIterative(NULL, 50) == NULL);
        while (root) root = Remove(root, root->value);
    }
    SECTION("Removal") {
        Node* root = NULL;
        root = InsertIterative(root, 5);
        root = InsertIterative(root, 1);
        root = InsertIterative(root, 9);
        root = InsertIterative(root, 10);
        root = InsertIterative(root, 0);
        REQUIRE(MinNode(root)->value == 0);
        // Removing a value that is not in the tree leaves it unchanged.
        root = Remove(root, 42);
        REQUIRE(root->value == 5);
        root = Remove(root, 5);
        REQUIRE(SearchRecursive(root, 5) == NULL);
        // Two-subtree removal: 5 is replaced by the right subtree's minimum, 9.
        REQUIRE(root->value == 9);
        REQUIRE(root->left->value == 1);
        REQUIRE(root->right->value == 10);
        root = Remove(root, 0);
        REQUIRE(SearchRecursive(root, 0) == NULL);
        root = Remove(root, 1);
        REQUIRE(SearchRecursive(root, 1) == NULL);
        // The values that were not removed must still be in place.
        REQUIRE(SearchRecursive(root, 9) == root);
        REQUIRE(root->left == NULL);
        REQUIRE(root->right->value == 10);
        REQUIRE(SearchIterative(root, 10) == root->right);
        root = Remove(root, 9);
        root = Remove(root, 10);
        REQUIRE(root == NULL);
    }
}
@denvaar
Copy link
Author

denvaar commented Feb 12, 2017

catch.hpp can be found at https://github.com/philsquared/Catch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment