Created
February 18, 2019 14:41
-
-
Save shanehou/6527808c2084ecf2e476881148236238 to your computer and use it in GitHub Desktop.
Segment Tree Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <vector> | |
using namespace std; | |
// Binary combiner for minimum queries: yields b when it is strictly smaller
// than a, otherwise a.
int min(const int a, const int b) {
    if (b < a) {
        return b;
    }
    return a;
}
// Binary combiner for range-sum queries: yields a + b.
int sum(const int a, const int b) {
    const int total = a + b;
    return total;
}
int BuildSegmentTree(const vector<int> &input, int index, int from, int to, vector<int> &tree, int(*func)(const int, const int)) { | |
if (from == to) { | |
tree[index] = input[from]; | |
return input[from]; | |
} | |
int mid = from + (to - from) / 2; | |
tree[index] = (*func)(BuildSegmentTree(input, 2*index+1, from, mid, tree, func), BuildSegmentTree(input, 2*index+2, mid+1, to, tree, func)); | |
return tree[index]; | |
} | |
vector<int> BuildSegmentTree(const vector<int> &input, int(*func)(const int, const int)) { | |
if (input.empty()) return vector<int>(); | |
int height = 1; | |
size_t size = input.size(); | |
while (size >>= 1 > 0) height++; | |
int n = 1; | |
for (int i = 0; i < height; i++) n *= 2; | |
vector<int> tree; | |
tree.resize(2 * n - 1); | |
fill_n(tree.begin(), tree.size(), -1); | |
(void)BuildSegmentTree(input, 0, 0, input.size()-1, tree, func); | |
return tree; | |
} | |
// Segment tree over a fixed array of ints, parameterised at compile time by a
// binary combiner `func` (e.g. sum or min). Node i stores `func` folded over its
// segment; its children are at 2*i+1 and 2*i+2. Slots the build never reaches
// hold -1.
template<int(*func)(const int, const int)>
class SegmentTree {
public:
    // Copies `input` and builds the tree. An empty input yields an empty tree.
    SegmentTree(const std::vector<int> &input) : orig(input) {
        // Nothing to build; the recursion below would otherwise index orig[0].
        if (orig.empty()) return;
        // Leaf capacity: the smallest power of two strictly greater than
        // orig.size(). This sizing is kept so GetTree() layouts stay unchanged.
        std::size_t leaves = 1;
        for (std::size_t remaining = orig.size(); remaining > 0; remaining >>= 1) {
            leaves <<= 1;
        }
        tree.assign(2 * leaves - 1, -1);
        (void)BuildSegmentTree(0, 0, static_cast<int>(orig.size()) - 1);
    }

    // Returns the flat node array (root at index 0).
    const std::vector<int> &GetTree() const {
        return tree;
    }

    // Returns `func` folded over orig[start..end] (inclusive). Returns -1 when
    // the range is reversed or out of bounds. NOTE(review): -1 can also be a
    // genuine result, so callers cannot always tell the two apart.
    int GetRangeValue(int start, int end) const {
        // After the first two checks, end >= start >= 0, so the cast is safe.
        if (start > end || start < 0 || static_cast<std::size_t>(end) >= orig.size()) return -1;
        return GetRangeValue(0, start, end, 0, static_cast<int>(orig.size()) - 1);
    }

private:
    std::vector<int> tree;
    std::vector<int> orig;

    // Fills tree[index] for the segment orig[from..to] and returns its value.
    int BuildSegmentTree(int index, int from, int to) {
        if (from == to) {
            tree[index] = orig[from];
            return orig[from];
        }
        int mid = from + (to - from) / 2;
        tree[index] = (*func)(BuildSegmentTree(2*index+1, from, mid), BuildSegmentTree(2*index+2, mid+1, to));
        return tree[index];
    }

    // Folds `func` over the intersection of [start, end] with node `index`'s
    // segment [from, to]. Precondition: the two ranges overlap.
    //
    // Only children that overlap the query are visited. No "neutral" value is
    // therefore ever fed into `func`. Returning 0 for disjoint children would
    // corrupt combiners whose identity is not 0, such as min.
    int GetRangeValue(int index, int start, int end, int from, int to) const {
        if (start <= from && end >= to) return tree[index];
        int mid = from + (to - from) / 2;
        if (end <= mid) return GetRangeValue(2*index+1, start, end, from, mid);
        if (start > mid) return GetRangeValue(2*index+2, start, end, mid+1, to);
        return (*func)(GetRangeValue(2*index+1, start, end, from, mid),
                       GetRangeValue(2*index+2, start, end, mid+1, to));
    }
};
// Demo driver. It builds a sum segment tree over a small array and prints
// every slot of the flat tree, one per line (-1 marks unused slots). It then
// prints the sum of input[2..4].
int main() {
    const std::vector<int> input = {1, 3, 5, 7, 9, 11};
    SegmentTree<sum> tree(input);
    // GetTree() returns a const reference; iterate it directly instead of copying.
    for (const int node : tree.GetTree()) {
        std::cout << node << '\n';
    }
    std::cout << "Sum: " << tree.GetRangeValue(2, 4) << '\n';
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment