Skip to content

Instantly share code, notes, and snippets.

@santa4nt
Last active August 29, 2015 14:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save santa4nt/5ba0c75836294a8bf315 to your computer and use it in GitHub Desktop.
Save santa4nt/5ba0c75836294a8bf315 to your computer and use it in GitHub Desktop.
In-order traversal on a BST without recursion.
#include <cassert>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <stack>
#include <stdexcept>
#include <utility>
#include <vector>
template <typename T>
class TreeNode
{
public:
    using TreeNodePtr = std::shared_ptr<TreeNode<T>>;

    /**
     * Create a node holding `data`, optionally linked to existing children.
     * The rank starts at 0; ranks are maintained by insert_value() (and hence
     * create_from_array()), not by this constructor.
     */
    TreeNode(T data, TreeNodePtr left = nullptr, TreeNodePtr right = nullptr);

    /** The value stored in this node (returned by copy). */
    T data() const;
    /** The left child, or nullptr if there is none. */
    const TreeNodePtr left() const;
    /** The right child, or nullptr if there is none. */
    const TreeNodePtr right() const;
    /** This node's rank; see the `_rank` member below. */
    int rank() const;

    /**
     * Given a root to a binary search tree, insert a new value and update the
     * ranks of the nodes affected by the insertion.
     * If successful, returns a pointer to the newly created node.
     * Otherwise (e.g. the value is already in the tree), nullptr is returned.
     */
    static TreeNodePtr insert_value(TreeNodePtr root, T data);

    /**
     * Given an (unordered) array of values, create a binary search tree and return
     * a pointer to its root. The first element becomes the root; an empty array
     * yields nullptr.
     */
    static TreeNodePtr create_from_array(const std::vector<T> &arr);

    /**
     * Given a root to a binary search tree, perform an in-order traversal
     * of all its nodes. Additionally, an optional callback object can be given
     * to be notified when each node is visited, and when all nodes have been visited.
     *
     * The functor types default to std::function so that either callback can be
     * omitted entirely (e.g. `traverse_in_order(root)`): template parameters are
     * never deduced from default arguments, so without these defaults the
     * `= nullptr` defaults below could not actually be used.
     *
     * Return true if all nodes were visited, false if the callback forced an early stop.
     */
    template <typename VisitFunctor = std::function<bool(const TreeNodePtr &)>,
              typename EndFunctor = std::function<void()>>
    static bool traverse_in_order(const TreeNodePtr root, VisitFunctor *vf = nullptr, EndFunctor *ef = nullptr);

    /**
     * Using traverse_in_order() above, visit each node and print "[rank]value "
     * to stdout, followed by a newline.
     */
    static void print_in_order(const TreeNodePtr root);

    /**
     * Given a root to a binary search tree, find the k-th smallest value.
     * e.g. if k == 1, the returned value will be the minimum, extreme-left value.
     * Throws std::invalid_argument if k == 0 or root is null.
     * Throws std::out_of_range if k > number of nodes in the tree.
     * Relies on the ranks maintained by insert_value().
     */
    static T find_kth_smallest(const TreeNodePtr root, unsigned int k);

    /**
     * Given a root to a binary search tree, unconditionally increment the "rank"
     * of all of its nodes.
     */
    static void increment_rank(TreeNodePtr root);

private:
    T _data;
    TreeNodePtr _left;
    TreeNodePtr _right;
    // The rank of a node is the number of nodes in the whole tree holding a
    // smaller value, i.e. its 0-based position in an in-order traversal
    // (assuming the tree was built only through insert_value()). For a right
    // child this includes its parent and the parent's left subtree.
    int _rank;
};
/**
 * Construct a node holding `data` with the given (possibly null) children and
 * an initial rank of 0.
 */
template <typename T>
TreeNode<T>::TreeNode(T data, TreeNodePtr left, TreeNodePtr right)
    // the parameters are taken by value, so move them into place instead of
    // copying again (saves a copy of T and two shared_ptr refcount round trips)
    : _data(std::move(data))
    , _left(std::move(left))
    , _right(std::move(right))
    , _rank(0)
{
}
// Read-only accessors. Each one hands back a copy of the corresponding member:
// the stored value, a child pointer (possibly null), or the cached rank.
template <typename T>
auto TreeNode<T>::data() const -> T
{
    return this->_data;
}

template <typename T>
auto TreeNode<T>::left() const -> const TreeNodePtr
{
    return this->_left;
}

template <typename T>
auto TreeNode<T>::right() const -> const TreeNodePtr
{
    return this->_right;
}

template <typename T>
auto TreeNode<T>::rank() const -> int
{
    return this->_rank;
}
/**
 * Insert `data` into the BST rooted at `root`, updating ranks along the way.
 *
 * Returns the newly created node, or nullptr if `root` is null or `data` is
 * already present. In both failure cases no rank is modified.
 */
template <typename T>
typename TreeNode<T>::TreeNodePtr TreeNode<T>::insert_value(TreeNodePtr root, T data)
{
    if (root == nullptr)
    {
        // there is no tree to insert into; report failure as documented
        return nullptr;
    }
    if (data == root->data())
    {
        // duplicate value detected!
        return nullptr;
    }
    if (data < root->data())
    {
        TreeNodePtr newnode = nullptr;
        if (root->left() == nullptr)
        {
            root->_left = std::make_shared<TreeNode<T>>(data);
            newnode = root->left();
            // a new left leaf sits immediately before this node in order, so its rank
            // is the same as the parent's current (not yet incremented) rank
            newnode->_rank = root->rank();
        }
        else
        {
            newnode = insert_value(root->left(), data);
            if (newnode == nullptr)
            {
                // the value was a duplicate somewhere in the left subtree: nothing was
                // inserted, so neither this node's rank nor its right subtree's may change
                return nullptr;
            }
        }
        // a node was inserted _somewhere_ to the left, so this node's rank goes up by one
        root->_rank++;
        // every node in this node's right subtree is also larger than the new value
        if (root->right() != nullptr)
        {
            increment_rank(root->right());
        }
        return newnode;
    }
    // we're adding a new node in the right subtree of this node;
    // as such, this node's rank itself does not change
    if (root->right() == nullptr)
    {
        root->_right = std::make_shared<TreeNode<T>>(data);
        TreeNodePtr newnode = root->right();
        // a new right leaf comes right after this node in order: everything smaller
        // than this node, plus this node itself, is smaller than it
        newnode->_rank = root->rank() + 1;
        return newnode;
    }
    return insert_value(root->right(), data);
}
/**
 * Add one to the rank of every node in the subtree rooted at `root`.
 * A null `root` is a no-op. Nodes are visited in pre-order (node, left, right)
 * using an explicit stack rather than recursion.
 */
template <typename T>
void TreeNode<T>::increment_rank(TreeNodePtr root)
{
    std::stack<TreeNodePtr> todo;
    if (root != nullptr)
    {
        todo.push(root);
    }
    while (!todo.empty())
    {
        TreeNodePtr node = todo.top();
        todo.pop();
        ++node->_rank;
        // push right first so that the left child is processed next (pre-order)
        if (node->right() != nullptr)
        {
            todo.push(node->right());
        }
        if (node->left() != nullptr)
        {
            todo.push(node->left());
        }
    }
}
/**
 * Build a BST from `arr`, inserting elements in array order.
 * The first element becomes the root; duplicates are skipped (insert_value
 * rejects them). Returns nullptr for an empty array.
 */
template <typename T>
typename TreeNode<T>::TreeNodePtr TreeNode<T>::create_from_array(const std::vector<T>& arr)
{
    if (arr.empty())
    {
        return nullptr;
    }
    TreeNodePtr root = std::make_shared<TreeNode<T>>(arr.front());
    // use the container's own size type: no C-style cast, no signed/unsigned mix,
    // and no truncation for very large inputs
    for (typename std::vector<T>::size_type i = 1; i < arr.size(); ++i)
    {
        insert_value(root, arr[i]);
    }
    return root;
}
/**
 * Print every node in ascending order as "[rank]value " on a single line, then
 * terminate the line with std::endl once the traversal is exhausted.
 */
template <typename T>
void TreeNode<T>::print_in_order(const TreeNodePtr root)
{
    auto emit_node = [](const TreeNodePtr &node) -> bool
    {
        std::cout << "[" << node->rank() << "]" << node->data() << " ";
        return true; // printing never asks the traversal to stop early
    };
    auto finish_line = []() -> void
    {
        std::cout << std::endl;
    };
    traverse_in_order(root, &emit_node, &finish_line);
}
/**
 * Return the k-th smallest value (k is 1-based) of the BST rooted at `root`.
 *
 * Throws std::invalid_argument if k == 0 or root is null, and
 * std::out_of_range if k exceeds the number of nodes.
 * Relies on the ranks maintained by insert_value(): a node's rank is the count
 * of smaller values in the whole tree, so the answer is the node with rank k-1.
 */
template <typename T>
T TreeNode<T>::find_kth_smallest(const TreeNodePtr root, unsigned int k)
{
    if (k == 0)
    {
        throw std::invalid_argument("Invalid k(0)");
    }
    if (root == nullptr)
    {
        throw std::invalid_argument("Root node is null!");
    }
    // k >= 1 here, so this cannot underflow
    const unsigned int target_rank = k - 1;
    TreeNodePtr current = root;
    while (current != nullptr)
    {
        // ranks start at 0 and are only ever incremented, so this conversion is
        // value-preserving; it avoids comparing a signed int with an unsigned k
        const unsigned int current_rank = static_cast<unsigned int>(current->rank());
        if (current_rank == target_rank)
        {
            // found it!
            return current->data();
        }
        // ranks increase left-to-right, so descend by rank just like a BST search
        current = (current_rank < target_rank) ? current->right() : current->left();
    }
    // we've exhausted the search space without finding the rank we want;
    // this can only mean that k was outside the range of |nodes|
    throw std::out_of_range("k is out of range!");
}
/**
 * Iterative in-order traversal driven by an explicit stack.
 *
 * `vf` (if non-null) is called on each node in ascending order; returning false
 * aborts the walk and makes this function return false without calling `ef`.
 * Once every node has been visited (or immediately, for a null root) `ef` is
 * called if non-null and true is returned.
 */
template <typename T>
template <typename VisitFunctor, typename EndFunctor>
bool TreeNode<T>::traverse_in_order(const TreeNodePtr root, VisitFunctor *vf, EndFunctor *ef)
{
    std::stack<TreeNodePtr> pending;

    // push `node` and then each of its left-only descendants; afterwards the
    // top of the stack is the smallest not-yet-visited node of that subtree
    auto descend_left = [&pending](TreeNodePtr node)
    {
        for (; node != nullptr; node = node->left())
        {
            pending.push(node);
        }
    };

    descend_left(root);
    while (!pending.empty())
    {
        TreeNodePtr visiting = pending.top();
        pending.pop();
        if (vf != nullptr && !(*vf)(visiting))
        {
            // the visitor asked to stop early
            return false;
        }
        // the in-order successor is the extreme-left node of the right subtree,
        // or, if there is none, whatever is already on the stack
        descend_left(visiting->right());
    }

    if (ef != nullptr)
    {
        (*ef)();
    }
    return true;
}
int main()
{
/**
* The test tree should look like:
*
* 3
* / \
* / \
* 1 5
* \ / \
* 2 4 6
*/
std::vector<int> arr = { 3, 1, 5, 4, 2, 6 };
TreeNode<int>::TreeNodePtr root = TreeNode<int>::create_from_array(arr);
std::cout << "Traverse in-order with [rank]: ";
TreeNode<int>::print_in_order(root);
std::cout << std::endl;
std::cout << "1st smallest node: " << TreeNode<int>::find_kth_smallest(root, 1) << std::endl;
std::cout << "3rd smallest node: " << TreeNode<int>::find_kth_smallest(root, 3) << std::endl;
std::cout << "6th smallest node: " << TreeNode<int>::find_kth_smallest(root, 6) << std::endl;
std::cout << "7th smallest node: ";
try
{
std::cout << TreeNode<int>::find_kth_smallest(root, 7);
}
catch (const std::out_of_range&)
{
std::cout << "out of range!";
}
std::cout << std::endl;
return 0;
}
#include <stack>
#include <vector>
#include <memory>
#include <functional>
#include <exception>
#include <iostream>
#include <cassert>
template <typename T>
class TreeNode
{
public:
    using TreeNodePtr = std::shared_ptr<TreeNode<T>>;

    /** Create a node holding `data`, optionally linked to existing children. */
    TreeNode(T data, TreeNodePtr left = nullptr, TreeNodePtr right = nullptr);

    /** The value stored in this node (returned by copy). */
    T data() const;
    /** The left child, or nullptr if there is none. */
    const TreeNodePtr left() const;
    /** The right child, or nullptr if there is none. */
    const TreeNodePtr right() const;

    /**
     * Given a root to a binary search tree, insert a new value.
     * If successful, returns a pointer to the newly created node.
     * Otherwise (e.g. the value is already in the tree), nullptr is returned.
     */
    static TreeNodePtr insert_value(TreeNodePtr root, T data);

    /**
     * Given an (unordered) array of values, create a binary search tree and return
     * a pointer to its root. The first element becomes the root; an empty array
     * yields nullptr.
     */
    static TreeNodePtr create_from_array(const std::vector<T> &arr);

    /**
     * Given a root to a binary search tree, perform an in-order traversal
     * of all its nodes. Additionally, an optional callback object can be given
     * to be notified when each node is visited, and when all nodes have been visited.
     *
     * The functor types default to std::function so that either callback can be
     * omitted entirely (e.g. `traverse_in_order(root)`): template parameters are
     * never deduced from default arguments, so without these defaults the
     * `= nullptr` defaults below could not actually be used.
     *
     * Return true if all nodes were visited, false if the callback forced an early stop.
     */
    template <typename VisitFunctor = std::function<bool(const TreeNodePtr &)>,
              typename EndFunctor = std::function<void()>>
    static bool traverse_in_order(const TreeNodePtr root, VisitFunctor *vf = nullptr, EndFunctor *ef = nullptr);

    /**
     * Using traverse_in_order() above, visit each node and print its value to stdout.
     */
    static void print_in_order(const TreeNodePtr root);

    /**
     * Given a root to a binary search tree, find the k-th smallest value.
     * e.g. if k == 1, the returned value will be the minimum, extreme-left value.
     * Throws std::invalid_argument if k == 0.
     * Throws std::out_of_range if k > number of nodes in the tree (including
     * when root is null).
     */
    static T find_kth_smallest(const TreeNodePtr root, unsigned int k);

private:
    T _data;
    TreeNodePtr _left;
    TreeNodePtr _right;
};
/**
 * Construct a node holding `data` with the given (possibly null) children.
 */
template<typename T>
TreeNode<T>::TreeNode(T data, TreeNodePtr left, TreeNodePtr right)
    // the parameters are taken by value, so move them into place instead of
    // copying again (saves a copy of T and two shared_ptr refcount round trips)
    : _data(std::move(data))
    , _left(std::move(left))
    , _right(std::move(right))
{
}
// Read-only accessors. Each one hands back a copy of the corresponding member:
// the stored value or a child pointer (possibly null).
template<typename T>
auto TreeNode<T>::data() const -> T
{
    return this->_data;
}

template<typename T>
auto TreeNode<T>::left() const -> const TreeNodePtr
{
    return this->_left;
}

template<typename T>
auto TreeNode<T>::right() const -> const TreeNodePtr
{
    return this->_right;
}
/**
 * Insert `data` into the BST rooted at `root`.
 *
 * Smaller values descend to the left, larger ones to the right. Returns the
 * newly created leaf, or nullptr if `root` is null or `data` is already present.
 */
template<typename T>
typename TreeNode<T>::TreeNodePtr TreeNode<T>::insert_value(TreeNodePtr root, T data)
{
    if (root == nullptr)
    {
        // there is no tree to insert into; report failure as documented
        // (previously this dereferenced a null pointer)
        return nullptr;
    }
    if (data == root->data())
    {
        // duplicate value detected!
        return nullptr;
    }
    // pick the side the value belongs on; `child` aliases the member itself so a
    // new leaf can be attached through it
    TreeNodePtr &child = (data < root->data()) ? root->_left : root->_right;
    if (child == nullptr)
    {
        child = std::make_shared<TreeNode<T>>(data);
        return child;
    }
    return insert_value(child, data);
}
/**
 * Build a BST from `arr`, inserting elements in array order.
 * The first element becomes the root; duplicates are skipped (insert_value
 * rejects them). Returns nullptr for an empty array.
 */
template<typename T>
typename TreeNode<T>::TreeNodePtr TreeNode<T>::create_from_array(const std::vector<T>& arr)
{
    if (arr.empty())
    {
        return nullptr;
    }
    TreeNodePtr root = std::make_shared<TreeNode<T>>(arr.front());
    // use the container's own size type: no C-style cast, no signed/unsigned mix,
    // and no truncation for very large inputs
    for (typename std::vector<T>::size_type i = 1; i < arr.size(); ++i)
    {
        insert_value(root, arr[i]);
    }
    return root;
}
/**
 * Print every value in ascending order, each followed by a space, then end the
 * line with std::endl once the traversal is exhausted.
 */
template<typename T>
void TreeNode<T>::print_in_order(const TreeNodePtr root)
{
    auto emit_value = [](const TreeNodePtr &node) -> bool
    {
        std::cout << node->data() << " ";
        return true; // printing never asks the traversal to stop early
    };
    auto finish_line = []() -> void
    {
        std::cout << std::endl;
    };
    traverse_in_order(root, &emit_value, &finish_line);
}
/**
 * Return the k-th smallest value (k is 1-based) of the BST rooted at `root`,
 * by counting nodes during an in-order traversal and stopping at the k-th.
 *
 * Throws std::invalid_argument if k == 0, and std::out_of_range if the
 * traversal runs out of nodes first (k > node count, or a null root).
 */
template<typename T>
T TreeNode<T>::find_kth_smallest(const TreeNodePtr root, unsigned int k)
{
    if (k == 0)
    {
        throw std::invalid_argument("Invalid k(0)");
    }
    unsigned int visited = 0;
    // remember the matching node rather than a copy of its value: T then needs no
    // default constructor, and no uninitialized T ever exists
    TreeNodePtr found = nullptr;
    auto visit_node =
        [&visited, &found, k](const TreeNodePtr &node) -> bool
        {
            ++visited;
            if (visited == k)
            {
                // found the k-th smallest node!
                found = node;
                return false; // signal the traversal algorithm to stop now
            }
            return true;
        };
    auto end_node =
        []() -> void
        {
            // if we get called at the exhaustion of all nodes, then k was bigger than the number of nodes
            throw std::out_of_range("Invalid k(out_of_range)");
        };
    const bool exhausted = traverse_in_order(root, &visit_node, &end_node);
    // end_node throws whenever the traversal completes, so getting here means the
    // visitor stopped it early and `found` is set
    assert(!exhausted && found != nullptr);
    (void) exhausted; // silence unused-variable warnings when NDEBUG disables assert
    return found->data();
}
/**
 * Iterative in-order traversal driven by an explicit stack.
 *
 * `vf` (if non-null) is called on each node in ascending order; returning false
 * aborts the walk and makes this function return false without calling `ef`.
 * Once every node has been visited (or immediately, for a null root) `ef` is
 * called if non-null and true is returned.
 */
template<typename T>
template<typename VisitFunctor, typename EndFunctor>
bool TreeNode<T>::traverse_in_order(const TreeNodePtr root, VisitFunctor *vf, EndFunctor *ef)
{
    std::stack<TreeNodePtr> pending;

    // push `node` and then each of its left-only descendants; afterwards the
    // top of the stack is the smallest not-yet-visited node of that subtree
    auto descend_left = [&pending](TreeNodePtr node)
    {
        for (; node != nullptr; node = node->left())
        {
            pending.push(node);
        }
    };

    descend_left(root);
    while (!pending.empty())
    {
        TreeNodePtr visiting = pending.top();
        pending.pop();
        if (vf != nullptr && !(*vf)(visiting))
        {
            // the visitor asked to stop early
            return false;
        }
        // the in-order successor is the extreme-left node of the right subtree,
        // or, if there is none, whatever is already on the stack
        descend_left(visiting->right());
    }

    if (ef != nullptr)
    {
        (*ef)();
    }
    return true;
}
int main()
{
/**
* The test tree should look like:
*
* 3
* / \
* / \
* 1 5
* \ / \
* 2 4 6
*/
std::vector<int> arr = { 3, 1, 5, 4, 2, 6 };
TreeNode<int>::TreeNodePtr root = TreeNode<int>::create_from_array(arr);
std::cout << "Traverse in-order: ";
TreeNode<int>::print_in_order(root);
std::cout << std::endl;
std::cout << "1st smallest node: " << TreeNode<int>::find_kth_smallest(root, 1) << std::endl;
std::cout << "3rd smallest node: " << TreeNode<int>::find_kth_smallest(root, 3) << std::endl;
std::cout << "6th smallest node: " << TreeNode<int>::find_kth_smallest(root, 6) << std::endl;
std::cout << "7th smallest node: ";
try
{
std::cout << TreeNode<int>::find_kth_smallest(root, 7);
}
catch (const std::out_of_range&)
{
std::cout << "out of range!";
}
std::cout << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment