Skip to content

Instantly share code, notes, and snippets.

@shanehou
Created February 18, 2019 14:41
Show Gist options
  • Save shanehou/6527808c2084ecf2e476881148236238 to your computer and use it in GitHub Desktop.
Save shanehou/6527808c2084ecf2e476881148236238 to your computer and use it in GitHub Desktop.
Segment Tree Implementation
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>
using namespace std;
// Returns the smaller of `a` and `b` (when they are equal, `a` is returned).
// The signature matches the combine-function pointer type expected by
// BuildSegmentTree and SegmentTree.
int min(const int a, const int b) {
    if (b < a) {
        return b;
    }
    return a;
}
// Returns a + b (plain int addition, no overflow check).
// The signature matches the combine-function pointer type expected by
// BuildSegmentTree and SegmentTree.
int sum(const int a, const int b) {
    const int total = a + b;
    return total;
}
// Recursively fills node `index` of `tree`, which covers input[from..to]
// (inclusive), and returns the value stored there.
// Children of node i live at 2*i+1 (covering [from, mid]) and 2*i+2
// (covering [mid+1, to]); `tree` must already be large enough for every
// index this split produces.
int BuildSegmentTree(const std::vector<int> &input, int index, int from, int to,
                     std::vector<int> &tree, int(*func)(const int, const int)) {
    if (from == to) {
        tree[index] = input[from];
        return tree[index];
    }
    const int mid = from + (to - from) / 2;  // overflow-safe midpoint
    const int left = BuildSegmentTree(input, 2 * index + 1, from, mid, tree, func);
    const int right = BuildSegmentTree(input, 2 * index + 2, mid + 1, to, tree, func);
    tree[index] = func(left, right);
    return tree[index];
}

// Builds a segment tree over `input`, combining the two children of each
// node with `func`. Range results only match a left-to-right combination
// of the elements if `func` is associative (as sum and min are).
//
// Returns the flat node array with the root at index 0. Slots not used by
// any node hold -1; note that -1 padding is indistinguishable from a genuine
// node value of -1. Returns an empty vector for empty input.
std::vector<int> BuildSegmentTree(const std::vector<int> &input, int(*func)(const int, const int)) {
    if (input.empty()) return std::vector<int>();

    // leaves = 2^(bit width of input.size()). This is always >= the smallest
    // power of two covering the input, so every index produced by the
    // midpoint split fits in 2 * leaves - 1 slots. Computed in size_t with an
    // explicit loop to avoid int overflow and shift/compare precedence traps.
    std::size_t leaves = 1;
    for (std::size_t remaining = input.size(); remaining > 0; remaining >>= 1) {
        leaves <<= 1;
    }

    std::vector<int> tree(2 * leaves - 1, -1);  // -1 marks unused slots
    (void)BuildSegmentTree(input, 0, 0, static_cast<int>(input.size()) - 1, tree, func);
    return tree;
}
// Segment tree over a copy of an int array, combining elements with `func`.
//
// Range results only match a left-to-right combination of the elements if
// `func` is associative (as sum and min are). Node i has children 2*i+1 and
// 2*i+2, and the root (index 0) covers the whole input. Slots not used by
// any node hold -1.
template<int(*func)(const int, const int)>
class SegmentTree {
public:
    // Builds the tree from `input` (a copy is kept). An empty input yields an
    // empty tree on which every range query is rejected.
    SegmentTree(const std::vector<int> &input) : orig(input) {
        // With no elements, the build below would recurse on [0, -1] and
        // index out of bounds.
        if (orig.empty()) return;

        // leaves = 2^(bit width of orig.size()), which is >= the smallest power
        // of two covering the input, so 2 * leaves - 1 slots hold every node.
        std::size_t leaves = 1;
        for (std::size_t remaining = orig.size(); remaining > 0; remaining >>= 1) {
            leaves <<= 1;
        }
        tree.assign(2 * leaves - 1, -1);  // -1 marks unused slots
        (void)BuildSegmentTree(0, 0, static_cast<int>(orig.size()) - 1);
    }

    // Returns the flat node array (read-only view of internal storage).
    const std::vector<int> &GetTree() const {
        return tree;
    }

    // Combines orig[start..end] (inclusive) with `func`.
    // Returns -1 if the range is empty or out of bounds. Note that -1 is
    // ambiguous when -1 is also a legitimate result.
    int GetRangeValue(int start, int end) const {
        // After the first two checks, end >= start >= 0, so the cast is safe.
        if (start > end || start < 0 || static_cast<std::size_t>(end) >= orig.size()) return -1;
        return GetRangeValue(0, start, end, 0, static_cast<int>(orig.size()) - 1);
    }

private:
    std::vector<int> tree, orig;

    // Fills node `index`, which covers orig[from..to], and returns its value.
    int BuildSegmentTree(int index, int from, int to) {
        if (from == to) {
            tree[index] = orig[from];
            return tree[index];
        }
        const int mid = from + (to - from) / 2;  // overflow-safe midpoint
        const int left = BuildSegmentTree(2 * index + 1, from, mid);
        const int right = BuildSegmentTree(2 * index + 2, mid + 1, to);
        tree[index] = func(left, right);
        return tree[index];
    }

    // Combines the part of [start, end] that lies inside node `index`'s span
    // [from, to]. Precondition: the two ranges overlap.
    //
    // Only overlapping children are visited, so no identity element for
    // `func` is needed. Returning a fixed value such as 0 for disjoint nodes
    // would corrupt results for combiners like min.
    int GetRangeValue(int index, int start, int end, int from, int to) const {
        if (start <= from && to <= end) return tree[index];
        const int mid = from + (to - from) / 2;
        if (end <= mid) return GetRangeValue(2 * index + 1, start, end, from, mid);
        if (start > mid) return GetRangeValue(2 * index + 2, start, end, mid + 1, to);
        return func(GetRangeValue(2 * index + 1, start, end, from, mid),
                    GetRangeValue(2 * index + 2, start, end, mid + 1, to));
    }
};
// Demo: builds a sum segment tree over six values, prints every slot of the
// flat node array (-1 marks unused padding), then prints the sum of
// input[2..4].
int main() {
    const std::vector<int> input = {1, 3, 5, 7, 9, 11};
    SegmentTree<sum> tree(input);
    // Bind by reference: GetTree() exposes internal storage, so no copy is needed.
    for (const int node : tree.GetTree()) {
        std::cout << node << '\n';
    }
    // Terminate the final line so the output ends with a newline.
    std::cout << "Sum: " << tree.GetRangeValue(2, 4) << '\n';
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment