/**
 * Sum segment tree over a fixed-length sequence of ints.
 *
 * Supports adding a delta to one element and summing an inclusive index range.
 * Nodes are stored in a flat array: node i has children 2*i+1 and 2*i+2, and a
 * node covering [lo, hi] is split at mid = lo + (hi - lo) / 2.
 * Sums are kept in int, so totals must fit in int.
 */
class SegmentTree {
public:
    /** Creates an empty tree: every query returns 0 and updates are ignored. */
    SegmentTree() : n(0) {}

    /**
     * Builds the tree from the values in nums (the values are copied).
     * An empty nums yields an empty tree.
     */
    SegmentTree(const std::vector<int>& nums) : n(static_cast<int>(nums.size())) {
        if (n == 0) return;
        // With midpoint splitting the deepest nodes can sit beyond index 3*n
        // (for n = 36 a leaf lands on node 112), so reserve 4*n slots, which
        // is always enough for this layout.
        m_st.assign(4 * nums.size(), 0);
        createTree(0, n - 1, 0, nums);
    }

    /**
     * Adds `add` to the element at index i.
     * Indices outside [0, n) are ignored.
     */
    void update(int i, int add) {
        if (i < 0 || i >= n) return;
        updateTree(0, n - 1, i, 0, add);
    }

    /**
     * Returns the sum of the elements in the inclusive range [lo, hi].
     * Parts of the range that fall outside the tree contribute 0.
     */
    int query(int lo, int hi) const {
        if (n == 0) return 0;
        return queryTree(lo, hi, 0, n - 1, 0);
    }

private:
    std::vector<int> m_st;  // node sums; empty when the tree has no elements
    int n;                  // number of elements covered by the tree

    /** Fills node i, which covers [lo, hi], from nums and returns nothing. */
    void createTree(int lo, int hi, int i, const std::vector<int>& nums) {
        if (lo == hi) {
            m_st[i] = nums[lo];
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        createTree(lo, mid, 2 * i + 1, nums);
        createTree(mid + 1, hi, 2 * i + 2, nums);
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }

    /** Adds `add` to leaf idx below node i ([lo, hi]) and refreshes sums on the way up. */
    void updateTree(int lo, int hi, int idx, int i, int add) {
        if (lo == hi) {
            m_st[i] += add;
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        if (idx <= mid) {
            updateTree(lo, mid, idx, 2 * i + 1, add);
        } else {
            updateTree(mid + 1, hi, idx, 2 * i + 2, add);
        }
        m_st[i] = m_st[2 * i + 1] + m_st[2 * i + 2];
    }

    /** Sum of the overlap between [qlo, qhi] and node i, which covers [lo, hi]. */
    int queryTree(int qlo, int qhi, int lo, int hi, int i) const {
        if (qlo > hi || qhi < lo) return 0;            // no overlap
        if (qlo <= lo && qhi >= hi) return m_st[i];    // node fully covered
        const int mid = lo + (hi - lo) / 2;
        return queryTree(qlo, qhi, lo, mid, 2 * i + 1) +
               queryTree(qlo, qhi, mid + 1, hi, 2 * i + 2);
    }
};

/**
 * Mutable 2D range-sum structure: a segment tree over rows in which every node
 * holds a SegmentTree over the column sums of the rows it covers.
 *
 * Rows are expected to share one length (rectangular matrix); the emptiness
 * check looks at the first row only. A matrix with no rows or no columns
 * yields an empty structure: sumRegion returns 0 and update does nothing.
 */
class NumMatrix {
public:
    /** Builds the structure from matrix (taken by value and kept internally). */
    NumMatrix(std::vector<std::vector<int>> matrix) {
        nums.swap(matrix);  // take over the by-value argument instead of copying it again
        m = static_cast<int>(nums.size());
        if (m == 0 || nums[0].empty()) return;
        // Same 4*m sizing as SegmentTree: 3*m slots overflow for some row counts.
        m_st.assign(4 * nums.size(), SegmentTree());
        createTree(0, m - 1, 0, nums);
    }

    /**
     * Sets the cell (row, col) to val.
     * Coordinates outside the matrix are ignored.
     */
    void update(int row, int col, int val) {
        if (m_st.empty() || row < 0 || row >= m) return;
        if (col < 0 || col >= static_cast<int>(nums[row].size())) return;
        const int diff = val - nums[row][col];
        updateRows(0, m - 1, row, col, 0, diff);
        nums[row][col] = val;
    }

    /**
     * Returns the sum of the cells in the inclusive rectangle with corners
     * (row1, col1) and (row2, col2). Parts outside the matrix contribute 0.
     */
    int sumRegion(int row1, int col1, int row2, int col2) const {
        if (m_st.empty()) return 0;
        return queryRows(row1, row2, col1, col2, 0, m - 1, 0);
    }

private:
    std::vector<SegmentTree> m_st;       // row-tree nodes, each a column tree; stored by value
    std::vector<std::vector<int>> nums;  // current cell values, used to turn "set" into a delta
    int m;                               // number of rows

    /**
     * Builds row-tree node i covering rows [lo, hi] and returns the per-column
     * sums of those rows, so the parent can combine its two children.
     */
    std::vector<int> createTree(int lo, int hi, int i,
                                const std::vector<std::vector<int>>& matrix) {
        if (lo == hi) {
            m_st[i] = SegmentTree(matrix[lo]);
            return matrix[lo];
        }
        const int mid = lo + (hi - lo) / 2;
        std::vector<int> sums = createTree(lo, mid, 2 * i + 1, matrix);
        const std::vector<int> right = createTree(mid + 1, hi, 2 * i + 2, matrix);
        // Rows should have equal length; zero-pad if they do not, rather than
        // reading past the end of the shorter one.
        if (right.size() > sums.size()) sums.resize(right.size(), 0);
        for (std::vector<int>::size_type c = 0; c < right.size(); ++c) {
            sums[c] += right[c];
        }
        m_st[i] = SegmentTree(sums);
        return sums;
    }

    /** Applies delta `add` at column j to every row-tree node whose range contains row i. */
    void updateRows(int lo, int hi, int i, int j, int idx, int add) {
        m_st[idx].update(j, add);
        if (lo == hi) return;
        const int mid = lo + (hi - lo) / 2;
        if (i <= mid) {
            updateRows(lo, mid, i, j, 2 * idx + 1, add);
        } else {
            updateRows(mid + 1, hi, i, j, 2 * idx + 2, add);
        }
    }

    /**
     * Sum over rows [xlo, xhi] and columns [ylo, yhi], restricted to row-tree
     * node i, which covers rows [lo, hi].
     */
    int queryRows(int xlo, int xhi, int ylo, int yhi, int lo, int hi, int i) const {
        if (xhi < lo || xlo > hi) return 0;                              // no row overlap
        if (xlo <= lo && xhi >= hi) return m_st[i].query(ylo, yhi);      // rows fully covered
        const int mid = lo + (hi - lo) / 2;
        return queryRows(xlo, xhi, ylo, yhi, lo, mid, 2 * i + 1) +
               queryRows(xlo, xhi, ylo, yhi, mid + 1, hi, 2 * i + 2);
    }
};

/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix* obj = new NumMatrix(matrix);
 * obj->update(row,col,val);
 * int param_2 = obj->sumRegion(row1,col1,row2,col2);
 */