Skip to content

Instantly share code, notes, and snippets.

@Jokeren
Created October 27, 2019 23:29
Show Gist options
  • Save Jokeren/d57d28db4497a28c9e48b43183eb6418 to your computer and use it in GitHub Desktop.
Save Jokeren/d57d28db4497a28c9e48b43183eb6418 to your computer and use it in GitHub Desktop.
segment tree 2d
// Range-sum queries over a mutable 2D integer matrix, backed by a 2D
// bottom-up (iterative) segment tree. Both update and sumRegion visit
// O(log R) row nodes and O(log C) column nodes per row node, for
// O(log R * log C) each; memory is 4 * R * C ints.
class NumMatrix {
private:
    // `tree` has 2 * num_rows row nodes; each row node is itself a 1D segment
    // tree of 2 * num_cols entries. In both dimensions the leaves live at
    // indices [n, 2n), node i >= 1 has children 2i and 2i + 1, and index 0 is
    // unused. Row node i holds the column tree of the element-wise sum of the
    // matrix rows it covers.
    std::vector<std::vector<int> > tree;
    size_t num_rows = 0;
    size_t num_cols = 0;

    // Copies one matrix row into the column leaves of `tree_column`
    // (size 2 * num_cols) and builds its internal column nodes.
    // Requires num_cols >= 1 (guaranteed by the constructor) and a row of at
    // least num_cols elements.
    void columnInit(std::vector<int> &tree_column, const std::vector<int> &matrix_column) {
        for (size_t c = 0; c < num_cols; ++c) {
            tree_column[num_cols + c] = matrix_column[c];
        }
        // num_cols >= 1, so num_cols - 1 cannot wrap around here.
        for (size_t i = num_cols - 1; i > 0; --i) {
            tree_column[i] = tree_column[i * 2] + tree_column[i * 2 + 1];
        }
    }

    // Adds `diff` to column leaf `col` of `tree_column` and recomputes every
    // ancestor column node from its two children.
    void columnUpdate(std::vector<int> &tree_column, int col, int diff) {
        size_t pos = col + num_cols;
        tree_column[pos] += diff;
        while (pos > 1) {
            pos /= 2;
            tree_column[pos] = tree_column[pos * 2] + tree_column[pos * 2 + 1];
        }
    }

    // Returns the sum of columns [col1, col2] (inclusive) stored in one
    // column tree.
    int columnSumRegion(const std::vector<int> &tree_column, int col1, int col2) const {
        size_t lpos = col1 + num_cols;
        size_t rpos = col2 + num_cols;
        int sum = 0;
        while (lpos <= rpos) {
            // An odd left bound is a right child: its parent would reach past
            // the range on the left, so take this node and step right.
            if (lpos % 2 == 1) {
                sum += tree_column[lpos];
                ++lpos;
            }
            // An even right bound is a left child: its parent would reach
            // past the range on the right, so take this node and step left.
            if (rpos % 2 == 0) {
                sum += tree_column[rpos];
                --rpos;
            }
            lpos /= 2;
            rpos /= 2;
        }
        return sum;
    }

public:
    // Builds the tree from `matrix`. The matrix is assumed rectangular: every
    // row has matrix[0].size() elements. An empty matrix, or one with zero
    // columns, yields an empty structure. On it, sumRegion returns 0 and
    // update does nothing.
    NumMatrix(const std::vector<std::vector<int> > &matrix) {
        if (matrix.empty() || matrix[0].empty()) {
            return;
        }
        num_rows = matrix.size();
        num_cols = matrix[0].size();
        tree.assign(num_rows * 2, std::vector<int>(num_cols * 2, 0));
        for (size_t r = 0; r < num_rows; ++r) {
            columnInit(tree[num_rows + r], matrix[r]);
        }
        // Internal row nodes: element-wise sum of the two child column trees.
        // num_rows >= 1, so num_rows - 1 cannot wrap around here.
        for (size_t i = num_rows - 1; i > 0; --i) {
            for (size_t j = 1; j < num_cols * 2; ++j) {
                tree[i][j] = tree[i * 2][j] + tree[i * 2 + 1][j];
            }
        }
    }

    // Sets matrix[row][col] = val. Expects 0 <= row < R and 0 <= col < C.
    void update(int row, int col, int val) {
        if (tree.empty()) {
            return;
        }
        size_t pos = row + num_rows;
        const int diff = val - tree[pos][col + num_cols];
        // Every row node on the leaf-to-root path covers this cell, so each
        // of their column trees changes by the same `diff`.
        columnUpdate(tree[pos], col, diff);
        while (pos > 1) {
            pos /= 2;
            columnUpdate(tree[pos], col, diff);
        }
    }

    // Returns the sum of the sub-rectangle with corners (row1, col1) and
    // (row2, col2), inclusive. Expects row1 <= row2 and col1 <= col2, all in
    // range. Returns 0 on an empty structure.
    int sumRegion(int row1, int col1, int row2, int col2) const {
        if (tree.empty()) {
            return 0;
        }
        size_t lpos = row1 + num_rows;
        size_t rpos = row2 + num_rows;
        int sum = 0;
        // Same inclusive bottom-up decomposition as columnSumRegion, applied
        // to rows. Each selected row node answers the column range.
        while (lpos <= rpos) {
            if (lpos % 2 == 1) {
                sum += columnSumRegion(tree[lpos], col1, col2);
                ++lpos;
            }
            if (rpos % 2 == 0) {
                sum += columnSumRegion(tree[rpos], col1, col2);
                --rpos;
            }
            lpos /= 2;
            rpos /= 2;
        }
        return sum;
    }
};
/**
* Your NumMatrix object will be instantiated and called as such:
* NumMatrix* obj = new NumMatrix(matrix);
* obj->update(row,col,val);
* int param_2 = obj->sumRegion(row1,col1,row2,col2);
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment