Skip to content

Instantly share code, notes, and snippets.

@zhenghaoz
Last active February 24, 2018 07:41
Show Gist options
  • Save zhenghaoz/1743e089bfa3b8ba3074bf45d7cdda19 to your computer and use it in GitHub Desktop.
Save zhenghaoz/1743e089bfa3b8ba3074bf45d7cdda19 to your computer and use it in GitHub Desktop.
B Tree
#include <memory>
#include <vector>
#include <iostream>
/// In-memory B-tree (CLRS-style, minimum degree N) mapping unique keys to values.
///
/// Every node except the root holds between N-1 and 2N-1 keys in ascending
/// order; an internal node with k keys has k+1 children. Insertion splits full
/// children before descending into them and removal tops up minimal children
/// before descending into them, so no fix-up pass runs after the target node
/// is reached.
///
/// @tparam N     Minimum degree; must be at least 2.
/// @tparam Key   Key type; compared with `>` and `==`, must be copyable and
///               default-constructible (node slots are pre-allocated).
/// @tparam Value Mapped type; must be copyable and default-constructible.
///
/// Copying a BTree deep-copies all nodes. print() additionally requires
/// `operator<<` for Key and Value.
template <unsigned N, typename Key, typename Value>
class BTree
{
    static_assert(N >= 2, "BTree: minimum degree N must be at least 2");

    template <typename T> using vector = std::vector<T>;
    template <typename T> using shared_ptr = std::shared_ptr<T>;

    // Minimum degree as a signed int, so comparisons with the int node sizes
    // below never mix signed and unsigned operands.
    static constexpr int kMinDegree = static_cast<int>(N);
    // Maximum number of keys per node; a node holding this many is "full".
    static constexpr int kMaxKeys = 2 * kMinDegree - 1;

    // A tree node. Storage is pre-allocated at full capacity: only the first
    // _size keys/values, and for internal nodes the first _size+1 children,
    // are meaningful.
    struct Node
    {
        bool _leaf = true;
        int _size = 0;
        vector<Key> _keys = vector<Key>(kMaxKeys);
        vector<Value> _values = vector<Value>(kMaxKeys);
        vector<shared_ptr<Node>> _children = vector<shared_ptr<Node>>(kMaxKeys + 1);

        Node() = default;

        // Deep copy: each live child is cloned recursively, so the copy shares
        // no nodes with the source.
        Node(const Node &node): _leaf(node._leaf), _size(node._size), _keys(node._keys), _values(node._values)
        {
            if (!_leaf)
                for (int i = 0; i <= _size; i++)
                    _children[i] = std::make_shared<Node>(*node._children[i]);
        }
    };

    // Always non-null; an empty tree is a single leaf with no keys.
    shared_ptr<Node> root;

    // Look up key in the subtree rooted at node.
    // Returns a pointer to the stored value, or nullptr if key is absent.
    Value* find(const shared_ptr<Node> &node, const Key &key)
    {
        // Descend with raw pointers: ownership never changes here, so there is
        // no reason to pay for a shared_ptr refcount update at every level.
        Node *cur = node.get();
        while (cur) {
            int i = 0;
            while (i < cur->_size && key > cur->_keys[i])
                i++;
            if (i < cur->_size && key == cur->_keys[i])
                return &cur->_values[i];
            if (cur->_leaf)
                return nullptr;
            cur = cur->_children[i].get();
        }
        return nullptr;
    }

    // Split the full node child (kMaxKeys keys), which sits at
    // parent->_children[i], around its median key. The upper N-1 keys (and N
    // children) move into a new right sibling, and the median moves up into
    // parent at position i. Callers guarantee parent is not full.
    void split(shared_ptr<Node> parent, int i, shared_ptr<Node> child)
    {
        shared_ptr<Node> nchild = std::make_shared<Node>();
        nchild->_leaf = child->_leaf;
        nchild->_size = child->_size = kMinDegree - 1;
        // move the upper half of the k-v pairs into the new sibling
        for (int j = 0; j < kMinDegree - 1; j++) {
            nchild->_keys[j] = child->_keys[j + kMinDegree];
            nchild->_values[j] = child->_values[j + kMinDegree];
        }
        // move the upper half of the children; clear the vacated slots so child
        // stops referencing subtrees that now belong to nchild
        if (!child->_leaf)
            for (int j = 0; j < kMinDegree; j++) {
                nchild->_children[j] = child->_children[j + kMinDegree];
                child->_children[j + kMinDegree].reset();
            }
        // open a gap at position i in parent, then move the median up into it
        for (int j = parent->_size; j > i; j--) {
            parent->_keys[j] = parent->_keys[j-1];
            parent->_values[j] = parent->_values[j-1];
            parent->_children[j+1] = parent->_children[j];
        }
        parent->_keys[i] = child->_keys[kMinDegree - 1];
        parent->_values[i] = child->_values[kMinDegree - 1];
        parent->_children[i+1] = nchild;
        parent->_size++;
    }

    // Insert a k-v pair into the subtree rooted at node. node must not be full
    // and key must not already be present (put() checks this).
    void insert(shared_ptr<Node> node, Key key, Value value)
    {
        // find insert position
        int i = 0;
        while (i < node->_size && key > node->_keys[i])
            i++;
        if (node->_leaf) { // insert k-v in a leaf, shifting larger keys right
            for (int j = node->_size; j > i; j--) {
                node->_keys[j] = node->_keys[j-1];
                node->_values[j] = node->_values[j-1];
            }
            node->_keys[i] = key;
            node->_values[i] = value;
            node->_size++;
        } else { // descend, splitting a full child first so it has room
            shared_ptr<Node> ptr = node->_children[i];
            if (ptr->_size == kMaxKeys) {
                split(node, i, ptr);
                // the promoted median now sits at _keys[i]; pick the correct half
                if (key > node->_keys[i])
                    i++;
            }
            insert(node->_children[i], key, value);
        }
    }

    // Insert a k-v pair starting at the root. A full root is split first under
    // a fresh root; this is the only way the tree grows in height.
    void insert(Key key, Value value)
    {
        shared_ptr<Node> ptr = root;
        if (ptr->_size == kMaxKeys) { // split root node
            root = std::make_shared<Node>();
            root->_leaf = false;
            root->_size = 0;
            root->_children[0] = ptr;
            split(root, 0, ptr);
            insert(root, key, value);
        } else insert(root, key, value);
    }

    // Merge parent->_children[i+1] and the separator parent->_keys[i] into
    // parent->_children[i], removing the separator and the right child from
    // parent. Callers only do this when both children hold N-1 keys, so the
    // merged node holds exactly kMaxKeys keys.
    void combine(shared_ptr<Node> parent, int i)
    {
        shared_ptr<Node> prev = parent->_children[i];
        shared_ptr<Node> next = parent->_children[i+1];
        // move parent->key[i] down
        prev->_keys[prev->_size] = parent->_keys[i];
        prev->_values[prev->_size] = parent->_values[i];
        prev->_size++;
        // append next's k-v pairs and children after it
        for (int j = 0; j < next->_size; j++) {
            prev->_keys[j + prev->_size] = next->_keys[j];
            prev->_values[j + prev->_size] = next->_values[j];
        }
        if (!prev->_leaf)
            for (int j = 0; j <= next->_size; j++)
                prev->_children[j + prev->_size] = next->_children[j];
        prev->_size += next->_size;
        // remove parent->key[i] and the pointer to next
        parent->_size--;
        for (int j = i; j < parent->_size; j++) {
            parent->_keys[j] = parent->_keys[j+1];
            parent->_values[j] = parent->_values[j+1];
            parent->_children[j+1] = parent->_children[j+2];
        }
        // the last child slot is now stale; drop it so it does not keep a
        // detached node alive
        parent->_children[parent->_size + 1].reset();
    }

    // Rightmost leaf of the subtree rooted at node (holds the subtree's max key).
    shared_ptr<Node> max(const shared_ptr<Node> &node)
    {
        shared_ptr<Node> ptr = node;
        while (!ptr->_leaf)
            ptr = ptr->_children[ptr->_size];
        return ptr;
    }

    // Leftmost leaf of the subtree rooted at node (holds the subtree's min key).
    shared_ptr<Node> min(const shared_ptr<Node> &node)
    {
        shared_ptr<Node> ptr = node;
        while (!ptr->_leaf)
            ptr = ptr->_children[0];
        return ptr;
    }

    // Remove key from the subtree rooted at node. The key must be present (the
    // public remove() checks this first). Every node this is called on holds at
    // least N keys unless it is the root, so a key can always be taken from it.
    //
    // key is taken by value on purpose: callers pass references into node
    // storage (a predecessor/successor slot) that this function shifts around.
    void remove(shared_ptr<Node> node, Key key)
    {
        // find delete position
        int i = 0;
        while (i < node->_size && key > node->_keys[i])
            i++;
        if (node->_leaf) { // case 1: remove k-v from leaf
            node->_size--;
            for (int j = i; j < node->_size; j++) {
                node->_keys[j] = node->_keys[j+1];
                node->_values[j] = node->_values[j+1];
            }
        } else if (i < node->_size && key == node->_keys[i]) { // case 2: key is in this internal node
            shared_ptr<Node> prevChild = node->_children[i];
            shared_ptr<Node> nextChild = node->_children[i+1];
            if (prevChild->_size >= kMinDegree) { // case 2a: replace key with its predecessor
                shared_ptr<Node> maxNode = max(prevChild);
                node->_keys[i] = maxNode->_keys[maxNode->_size-1];
                node->_values[i] = maxNode->_values[maxNode->_size-1];
                remove(prevChild, maxNode->_keys[maxNode->_size-1]);
            } else if (nextChild->_size >= kMinDegree) { // case 2b: replace key with its successor
                shared_ptr<Node> minNode = min(nextChild);
                node->_keys[i] = minNode->_keys[0];
                node->_values[i] = minNode->_values[0];
                remove(nextChild, minNode->_keys[0]);
            } else { // case 2c: merge both children around key, then recurse
                combine(node, i);
                remove(node->_children[i], key);
            }
        } else { // case 3: key lives in subtree children[i]; top it up to >= N keys first
            shared_ptr<Node> subNode = node->_children[i];
            if (subNode->_size < kMinDegree) {
                shared_ptr<Node> prevBrother, nextBrother;
                if (i > 0) prevBrother = node->_children[i-1];
                if (i < node->_size) nextBrother = node->_children[i+1];
                if (prevBrother && prevBrother->_size >= kMinDegree) { // case 3a: borrow from left sibling
                    // shift subNode right by one to make room at the front
                    for (int j = subNode->_size; j > 0; j--) {
                        subNode->_keys[j] = subNode->_keys[j-1];
                        subNode->_values[j] = subNode->_values[j-1];
                    }
                    if (!subNode->_leaf)
                        for (int j = subNode->_size; j >= 0; j--)
                            subNode->_children[j+1] = subNode->_children[j];
                    // move separator node->key[i-1] down into subNode
                    subNode->_keys[0] = node->_keys[i-1];
                    subNode->_values[0] = node->_values[i-1];
                    subNode->_children[0] = prevBrother->_children[prevBrother->_size];
                    subNode->_size++;
                    // move prevBrother's last key up into node as the new separator
                    node->_keys[i-1] = prevBrother->_keys[prevBrother->_size-1];
                    node->_values[i-1] = prevBrother->_values[prevBrother->_size-1];
                    // the last child was handed to subNode; clear the stale slot
                    prevBrother->_children[prevBrother->_size].reset();
                    prevBrother->_size--;
                } else if (nextBrother && nextBrother->_size >= kMinDegree) { // case 3a: borrow from right sibling
                    // move separator node->key[i] down to the end of subNode
                    subNode->_keys[subNode->_size] = node->_keys[i];
                    subNode->_values[subNode->_size] = node->_values[i];
                    subNode->_children[subNode->_size+1] = nextBrother->_children[0];
                    subNode->_size++;
                    // move nextBrother's first key up into node as the new separator
                    node->_keys[i] = nextBrother->_keys[0];
                    node->_values[i] = nextBrother->_values[0];
                    nextBrother->_size--;
                    for (int j = 0; j < nextBrother->_size; j++) {
                        nextBrother->_keys[j] = nextBrother->_keys[j+1];
                        nextBrother->_values[j] = nextBrother->_values[j+1];
                    }
                    if (!nextBrother->_leaf) {
                        for (int j = 0; j <= nextBrother->_size; j++)
                            nextBrother->_children[j] = nextBrother->_children[j+1];
                        // the old last child slot is now a duplicate; clear it
                        nextBrother->_children[nextBrother->_size + 1].reset();
                    }
                } else if (nextBrother) { // case 3b: merge child[i] and child[i+1]
                    combine(node, i);
                } else { // case 3b: merge child[i-1] and child[i]
                    i--;
                    combine(node, i);
                }
            }
            remove(node->_children[i], key);
        }
    }

    // Print the subtree rooted at node: one line per node, indented one space
    // per level, followed by its children in order.
    void print(shared_ptr<Node> node, int level)
    {
        for (int i = 0; i < level; i++)
            std::cout << ' ';
        std::cout << "{";
        for (int i = 0; i < node->_size; i++) {
            if (i) std::cout << ",";
            std::cout << node->_keys[i] << ":" << node->_values[i];
        }
        std::cout << "}" << std::endl;
        if (!node->_leaf)
            for (int i = 0; i <= node->_size; i++)
                print(node->_children[i], level+1);
    }
public:
    /// Create an empty tree (a single leaf root with no keys).
    BTree()
    {
        root = std::make_shared<Node>();
        root->_leaf = true;
        root->_size = 0;
    }

    /// Deep-copy another tree; the copy shares no nodes with it.
    BTree(const BTree &tree)
    {
        root = std::make_shared<Node>(*tree.root);
    }

    /// Copy-and-swap assignment: tree is already a private copy, so
    /// self-assignment is safe.
    /// @return *this
    BTree& operator=(BTree tree)
    {
        root.swap(tree.root);
        return *this;
    }

    /// Look up key.
    /// @return pointer to the stored value, or nullptr if key is absent. The
    ///         pointer is only valid until the next put() or remove(), which
    ///         may move entries between node slots.
    Value* get(const Key &key)
    {
        return find(root, key);
    }

    /// Insert key -> value, or overwrite the value if key is already present.
    void put(const Key &key, const Value &value)
    {
        Value *val = find(root, key);
        if (val)
            *val = value;
        else
            insert(key, value);
    }

    /// Remove key if present; removing an absent key is a no-op.
    void remove(const Key &key)
    {
        if (find(root, key))
            remove(root, key);
        // A root emptied by a merge is replaced by its only child, shrinking
        // the height; an emptied leaf root has no child and is recreated.
        if (root->_size == 0)
            root = root->_children[0];
        if (root == nullptr) {
            root = std::make_shared<Node>();
            root->_leaf = true;
            root->_size = 0;
        }
    }

    /// Dump the tree structure to std::cout.
    void print()
    {
        print(root, 0);
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment