Skip to content

Instantly share code, notes, and snippets.

@EvanMu96
Created November 29, 2020 08:09
Show Gist options
  • Save EvanMu96/eb5984e104fcddd3a1f6049d604b86c3 to your computer and use it in GitHub Desktop.
Save EvanMu96/eb5984e104fcddd3a1f6049d604b86c3 to your computer and use it in GitHub Desktop.
307
#include <algorithm>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
using namespace std;
// Segment tree: a specialised binary tree over array indices.
// Leaf nodes hold the actual inserted elements; every other node stores the
// sum of the index range [start, end] that it covers.
class SegmentTreeNode {
public:
    int start;               // first index covered by this node
    int end;                 // last index covered by this node (inclusive)
    int sum;                 // sum of the elements in [start, end]
    SegmentTreeNode* left;   // owned left subtree, or nullptr
    SegmentTreeNode* right;  // owned right subtree, or nullptr

    // Creates a node covering [start, end]. The node takes ownership of
    // `left` and `right` (either may be nullptr) and deletes them when it
    // is destroyed.
    SegmentTreeNode(int start, int end, int sum = 0,
                    SegmentTreeNode* left = nullptr,
                    SegmentTreeNode* right = nullptr)
        : start(start), end(end), sum(sum), left(left), right(right) {}

    // Children are owned through raw pointers, so a member-wise copy would
    // lead to a double delete; copying and copy-assignment are disabled.
    SegmentTreeNode(const SegmentTreeNode&) = delete;
    SegmentTreeNode& operator=(const SegmentTreeNode&) = delete;

    // Recursively frees the whole subtree below this node (left first,
    // then right) and leaves both child pointers null.
    ~SegmentTreeNode() {
        SegmentTreeNode* ownedLeft = left;
        SegmentTreeNode* ownedRight = right;
        left = nullptr;
        right = nullptr;
        delete ownedLeft;
        delete ownedRight;
    }
};
class NumArray {
public:
NumArray(vector<int>& nums) {
nums_.swap(nums);
if(!nums_.empty()) root_.reset(buildTree(0, nums_.size() - 1));
}
void update(int i, int val) {
updateTree(root_.get(), i, val);
}
int sumRange(int i, int j) {
return sumRange(root_.get(), i, j);
}
private:
vector<int> nums_;
std::unique_ptr<SegmentTreeNode> root_;
// initialize a SegmentTree
SegmentTreeNode* buildTree(int start, int end) {
if(start == end) // only 1 element, this is a leaf node
{
cout << "Build a leaf: " << nums_[start] << endl;
return new SegmentTreeNode(start, end, nums_[start]);
}
int mid = start + (end - start) / 2;
auto left = buildTree(start, mid);
auto right = buildTree(mid + 1, end);
cout << "Build a non-leaf, the sum is:" << left->sum + right->sum << endl;
auto node = new SegmentTreeNode(start, end, left->sum + right->sum, left, right);
return node;
}
void updateTree(SegmentTreeNode* root, int i, int val) {
if(root->start == i && root->end == i) {
root->sum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
if(i <= mid) {
updateTree(root->left, i, val);
}
else updateTree(root->right, i, val);
root->sum = root->left->sum + root->right->sum;
}
int sumRange(SegmentTreeNode* root, int i, int j) {
if(i == root->start && j == root->end) return root->sum;
//cout << root->sum << " ";
int mid = root->start + (root->end - root->start) / 2;
if(j <= mid) {
return sumRange(root->left, i, j);
} else if(i > mid) {
return sumRange(root->right, i, j);
}
else return sumRange(root->left, i, mid) + sumRange(root->right, mid + 1, j);
}
};
int main() {
vector<int> v = {1, 3, 5};
NumArray n(v);
cout << n.sumRange(0, 2) << endl;
n.update(1, 2);
cout << n.sumRange(0, 2) << endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment