Skip to content

Instantly share code, notes, and snippets.

@kartikkukreja
Last active December 13, 2022 12:02
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save kartikkukreja/2e7685e1fc8dbca0001b to your computer and use it in GitHub Desktop.
Save kartikkukreja/2e7685e1fc8dbca0001b to your computer and use it in GitHub Desktop.
Segment tree template
// T is the type of input array elements
// V is the type of required aggregate statistic
template<class T, class V>
class SegmentTree {
SegmentTreeNode* nodes;
int N;
public:
SegmentTree(T arr[], int N) {
this->N = N;
nodes = new SegmentTreeNode[getSegmentTreeSize(N)];
buildTree(arr, 1, 0, N-1);
}
~SegmentTree() {
delete[] nodes;
}
V getValue(int lo, int hi) {
SegmentTreeNode result = getValue(1, 0, N-1, lo, hi);
return result.getValue();
}
void update(int index, T value) {
update(1, 0, N-1, index, value);
}
private:
void buildTree(T arr[], int stIndex, int lo, int hi) {
if (lo == hi) {
nodes[stIndex].assignLeaf(arr[lo]);
return;
}
int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
buildTree(arr, left, lo, mid);
buildTree(arr, right, mid + 1, hi);
nodes[stIndex].merge(nodes[left], nodes[right]);
}
SegmentTreeNode getValue(int stIndex, int left, int right, int lo, int hi) {
if (left == lo && right == hi)
return nodes[stIndex];
int mid = (left + right) / 2;
if (lo > mid)
return getValue(2*stIndex+1, mid+1, right, lo, hi);
if (hi <= mid)
return getValue(2*stIndex, left, mid, lo, hi);
SegmentTreeNode leftResult = getValue(2*stIndex, left, mid, lo, mid);
SegmentTreeNode rightResult = getValue(2*stIndex+1, mid+1, right, mid+1, hi);
SegmentTreeNode result;
result.merge(leftResult, rightResult);
return result;
}
int getSegmentTreeSize(int N) {
int size = 1;
for (; size < N; size <<= 1);
return size << 1;
}
void update(int stIndex, int lo, int hi, int index, T value) {
if (lo == hi) {
nodes[stIndex].assignLeaf(value);
return;
}
int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
if (index <= mid)
update(left, lo, mid, index, value);
else
update(right, mid+1, hi, index, value);
nodes[stIndex].merge(nodes[left], nodes[right]);
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment