Skip to content

Instantly share code, notes, and snippets.

@siddhantkushwaha
Last active October 19, 2019 20:03
Show Gist options
  • Save siddhantkushwaha/6b9eb9eb523d83daf977877832dc4712 to your computer and use it in GitHub Desktop.
Save siddhantkushwaha/6b9eb9eb523d83daf977877832dc4712 to your computer and use it in GitHub Desktop.
Segment Tree - range query and point update; range update (lazy propagation)
/*
Bitmasks are cool. A bitmask is a string of binary bits (0s and 1s). For example: "01010" is a bitmask.
Kuldeep is a naughty but brilliant computer scientist. In his free time, he writes some random programs
to play with bitmasks. He has a PhD student under him and to test him (and entertain himself),
he has given him the following task: Given a number N, write a bitmask of length N containing all 0s.
Now, you are given Q operations. Each operation contains two numbers (l, r) as input.
An operation can be one of the following:
Update operation: XOR every bit in the bitmask from index l to index r (both inclusive) with 1, i.e. flip those bits.
Query operation: Count the number of set bits in the bitmask between indices l and r (both inclusive).
*/
#include "bits/stdc++.h"
using namespace std;
// Node of a pointer-based segment tree with lazy propagation.
//   val   : number of set bits in the segment this node covers.
//   lazy  : toggles still pending for this node's segment that have not yet
//           been applied to `val` (only their parity matters).
//   left / right : children; both nullptr for a leaf.
// All members have default initializers so a freshly allocated node never
// holds an indeterminate `val` (the old constructor left it uninitialized).
struct node {
    int val = 0;
    int lazy = 0;
    node *left = nullptr;
    node *right = nullptr;
};
// Allocates a segment tree covering indices [s, e] and stores it in `root`.
// Every bit starts cleared, so leaves hold 0 and each internal node holds the
// sum of its two children (also 0).
void build(node *&root, int s, int e) {
    root = new node();
    if (s != e) {
        const int half = (s + e) / 2;
        build(root->left, s, half);
        build(root->right, half + 1, e);
        root->val = root->left->val + root->right->val;
        return;
    }
    root->val = 0;
}
// Applies `v` pending toggles to the segment [s, e] owned by `root`.
// Only the parity of `v` matters: an odd count flips every bit of the segment,
// so the set-bit count becomes (segment length - current count), and the same
// flip is deferred to both children by toggling their lazy tag. An even count
// changes nothing.
// Flipping the children's tag with XOR (instead of adding `v`) keeps the tag
// bounded, so it can never overflow a signed int and turn negative.
void change(node *root, int s, int e, int v) {
    if ((v & 1) == 0)
        return;  // even number of flips is a no-op
    root->val = (e - s + 1) - root->val;
    if (s != e) {
        root->left->lazy ^= 1;
        root->right->lazy ^= 1;
    }
}
// Flips every bit whose index lies in [l, r] inside the subtree `root`, which
// covers [s, e]. Any toggle still pending on `root` is applied first, so
// `root->val` is current before it is returned to or recomputed.
void update_range(node *root, int s, int e, int l, int r) {
    if (root->lazy > 0) {
        const int pending = root->lazy;
        root->lazy = 0;
        change(root, s, e, pending);
    }

    const bool disjoint = r < s || l > e;
    if (disjoint)
        return;

    const bool covered = l <= s && e <= r;
    if (covered) {
        // Flip this whole segment now; children receive it lazily.
        change(root, s, e, 1);
        return;
    }

    const int split = (s + e) / 2;
    update_range(root->left, s, split, l, r);
    update_range(root->right, split + 1, e, l, r);
    root->val = root->left->val + root->right->val;
}
// Returns how many bits are set at indices in [l, r] within the subtree
// `root`, which covers [s, e]. Pending toggles on each visited node are
// applied on the way down so the value read is up to date.
int query(node *root, int s, int e, int l, int r) {
    if (root->lazy > 0) {
        const int pending = root->lazy;
        root->lazy = 0;
        change(root, s, e, pending);
    }

    if (r < s || l > e)
        return 0;  // no overlap with the requested range
    if (l <= s && e <= r)
        return root->val;  // segment fully inside the requested range

    const int split = (s + e) / 2;
    const int fromLeft = query(root->left, s, split, l, r);
    const int fromRight = query(root->right, split + 1, e, l, r);
    return fromLeft + fromRight;
}
// Small demo on a 5-bit mask that starts as 00000.
int main() {
    const int n = 5;
    const int last = n - 1;
    node *root = nullptr;
    build(root, 0, last);

    // Flip bits 1..3 -> 01110; bits 1..2 hold two set bits.
    update_range(root, 0, last, 1, 3);
    cout << query(root, 0, last, 1, 2) << '\n';

    // Flip every bit -> 10001; bits 3..4 hold one set bit.
    update_range(root, 0, last, 0, 4);
    cout << query(root, 0, last, 3, 4) << '\n';
}
// Path-copying point assignment: returns the root of a new tree version in
// which position `idx` holds `val`. Only the nodes on the path from the root
// down to `idx` are freshly allocated; every other subtree is shared with the
// version rooted at `root`, which is never written to.
node *persistent_update(node *root, int s, int e, int idx, int val) {
    node *fresh = new node();
    if (s == e) {
        fresh->val = val;
        return fresh;
    }

    const int mid = (s + e) / 2;
    const bool goLeft = idx <= mid;
    fresh->left = goLeft ? persistent_update(root->left, s, mid, idx, val)
                         : root->left;
    fresh->right = goLeft ? root->right
                          : persistent_update(root->right, mid + 1, e, idx, val);
    fresh->val = fresh->left->val + fresh->right->val;
    return fresh;
}
#include "bits/stdc++.h"
using namespace std;
// Node of a pointer-based sum segment tree.
//   val   : for a leaf, the array element it stores; otherwise the sum of its
//           children's values.
//   left / right : children; both nullptr for a leaf.
// Members have default initializers so a freshly allocated node never holds
// an indeterminate `val` (the old constructor left it uninitialized).
struct node {
    int val = 0;
    node *left = nullptr;
    node *right = nullptr;
};
void build(node *&root, int arr[], int s, int e) {
root = new node();
if (s == e)
root->val = arr[s];
else {
int mid = (s + e) / 2;
build(root->left, arr, s, mid);
build(root->right, arr, mid + 1, e);
root->val = root->left->val + root->right->val;
}
}
/* range query O(log(n)) */
// Sums the values at indices [low, high] within the subtree `root`, which
// covers [s, e]. Segments entirely outside the range contribute 0.
int query(node *root, int s, int e, int low, int high) {
    const bool outside = high < s || low > e;
    if (outside)
        return 0;

    const bool inside = low <= s && e <= high;
    if (inside)
        return root->val;

    const int half = (s + e) / 2;
    const int leftSum = query(root->left, s, half, low, high);
    const int rightSum = query(root->right, half + 1, e, low, high);
    return leftSum + rightSum;
}
/* point update O(log(n)) */
// Sets position `idx` to `val` in the subtree `root`, which covers [s, e],
// then refreshes the sum of every node on the path back up to `root`.
void update(node *root, int s, int e, int idx, int val) {
    if (s == e) {
        root->val = val;
        return;
    }
    const int half = (s + e) / 2;
    if (idx > half)
        update(root->right, half + 1, e, idx, val);
    else
        update(root->left, s, half, idx, val);
    root->val = root->left->val + root->right->val;
}
// Small demo: builds a sum tree over {1, 2, 3, 4, 5} and prints the sum of
// elements 1..2 (2 + 3 = 5).
int main() {
    int arr[] = {1, 2, 3, 4, 5};
    // Element count computed from the element size rather than assuming
    // sizeof(int) == 4.
    const int n = static_cast<int>(sizeof(arr) / sizeof(arr[0]));
    node *root = nullptr;
    build(root, arr, 0, n - 1);
    // Query over the same [0, n - 1] range the tree was built on, instead of a
    // hard-coded upper bound.
    cout << query(root, 0, n - 1, 1, 2) << "\n";
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment