Created
November 30, 2023 16:04
-
-
Save roman-kashitsyn/8f6c767a63d1c681d87d61d383c180b8 to your computer and use it in GitHub Desktop.
Extensible segment trees
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
// Isolates the lowest set bit of N: the result has a single bit set, at the
// position of the least significant one bit of N. Yields 0 when N is 0.
uint64_t
Lsb(const uint64_t n)
{
    // n & (n - 1) is N with its lowest set bit cleared; whatever differs
    // between that and N is exactly the bit that was cleared.
    const uint64_t without_lowest = n & (n - 1);
    return n ^ without_lowest;
}
// Returns a mask with a single bit set at the position of the lowest zero
// bit of N. Adding one to N turns its trailing ones into zeros and sets that
// zero bit, so the bit becomes the lowest set bit of N + 1.
// When N has no zero bit, N + 1 wraps to 0 and the result is 0.
uint64_t
LastZeroBit(const uint64_t n)
{
    const uint64_t successor = n + 1;
    return Lsb(successor);
}
// Rounds N up to the nearest power of two; a power of two is returned
// unchanged. Because the arithmetic wraps, an input of 0 and any input
// above 2^63 produce 0.
uint64_t
RoundUpPowerOf2(uint64_t n)
{
    // See Bit Twiddling Hacks: smear the highest set bit of N - 1 into all
    // lower positions (shifts of 1, 2, 4, 8, 16, 32), then add one.
    n -= 1;
    for (unsigned shift = 1; shift < 64; shift *= 2) {
        n |= n >> shift;
    }
    return n + 1;
}
// Returns the index of the root of a perfect binary tree holding SIZE nodes.
// Requires: SIZE is one less than a power of two (1, 3, 7, 15, ...); the
// assert below enforces this.
uint64_t
PBT_Root(const uint64_t size)
{
    assert(size == 1 || size + 1 == RoundUpPowerOf2(size));
    // The root is the middle index.
    const uint64_t root = size / 2;
    return root;
}
// Returns the index of the parent of node I in a perfect binary tree.
uint64_t
PBT_Parent(const uint64_t i)
{
    const uint64_t zero_bit = LastZeroBit(i);
    // Set the lowest zero bit of I, then clear the bit directly above it.
    const uint64_t with_zero_set = i | zero_bit;
    return with_zero_set & ~(zero_bit << 1);
}
// Returns the index of the left child of node P in a perfect binary tree.
// Requires: P is not a leaf (asserted as P being odd).
uint64_t
PBT_LeftChild(const uint64_t p)
{
    assert(p & 1);
    // Clear the bit just below P's lowest zero bit, i.e. the topmost bit of
    // P's run of trailing ones.
    const uint64_t below_zero_bit = LastZeroBit(p) >> 1;
    return p & ~below_zero_bit;
}
// Returns the index of the right child of node P in a perfect binary tree.
// Requires: P is not a leaf (asserted as P being odd).
uint64_t
PBT_RightChild(const uint64_t p)
{
    assert(p & 1);
    const uint64_t zero_bit = LastZeroBit(p);
    // Set P's lowest zero bit and clear the bit directly below it.
    const uint64_t with_zero_set = p | zero_bit;
    return with_zero_set & ~(zero_bit >> 1);
}
// Returns the left child of node I in a flat in-order left-perfect binary
// tree. This simply delegates to the perfect-tree computation.
uint64_t
LPBT_LeftChild(const uint64_t i)
{
    const uint64_t left = PBT_LeftChild(i);
    return left;
}
// Returns the index of the root of a left-perfect binary tree with SIZE
// nodes: the root of the smallest perfect tree (2^k - 1 nodes) that has at
// least SIZE nodes.
uint64_t
LPBT_Root(const uint64_t size)
{
    const uint64_t perfect_size = RoundUpPowerOf2(size + 1) - 1;
    return PBT_Root(perfect_size);
}
// Returns I with its run of trailing one bits cleared. In the in-order
// numbering used by this file that is the leftmost leaf below node I
// (a leaf, having no trailing ones, maps to itself).
uint64_t
PBT_LeftMostLeaf(const uint64_t i)
{
    // I + 1 has zeros wherever I has trailing ones, so the AND drops them.
    const uint64_t successor = i + 1;
    return successor & i;
}
// Returns the parent of node I in a left-perfect binary tree of SIZE nodes.
// If the perfect-tree parent lies inside the tree it is used directly;
// otherwise the parent is the node immediately before the leftmost leaf
// under I.
// NOTE: when that leftmost leaf is 0 the subtraction wraps around, so this
// must not be called on the root.
uint64_t
LPBT_Parent(const uint64_t i, const uint64_t size)
{
    const uint64_t candidate = PBT_Parent(i);
    if (candidate < size) {
        return candidate;
    }
    return PBT_LeftMostLeaf(i) - 1;
}
// Returns the parent of node I in a left-perfect binary tree of the given
// SIZE, found by climbing perfect-tree parents until one falls below SIZE.
uint64_t
LPBT_Parent_Iterative(uint64_t i, const uint64_t size)
{
    uint64_t node = PBT_Parent(i);
    while (node >= size) {
        node = PBT_Parent(node);
    }
    return node;
}
// Returns the right child of node N in a left-perfect binary tree of SIZE
// nodes, computed iteratively: start at the perfect-tree right child and
// descend through left children until the index falls below SIZE.
// Requires: N is odd (asserted).
uint64_t
LPBT_RightChild_2(uint64_t n, const uint64_t size)
{
    assert(n & 1);
    uint64_t child = PBT_RightChild(n);
    while (child >= size) {
        child = PBT_LeftChild(child);
    }
    return child;
}
// Returns the right child of node N in a left-perfect binary tree of SIZE
// nodes. If the perfect-tree right child lies inside the tree it is used;
// otherwise the child is the root of the left-perfect tree formed by the
// SIZE - N - 1 nodes that follow N. The result is cross-checked against
// LPBT_RightChild_2 with an assert.
// Requires: N is odd (asserted).
uint64_t
LPBT_RightChild(const uint64_t n, const uint64_t size)
{
    assert(n & 1);
    const uint64_t perfect_child = PBT_RightChild(n);
    const uint64_t result = (perfect_child < size)
        ? perfect_child
        : n + 1 + LPBT_Root(size - n - 1);
    assert(result == LPBT_RightChild_2(n, size));
    return result;
}
// Segment tree implementation

// Type of the values stored in the tree nodes.
typedef int64_t value_t;

// Capacity of the node array. This is an enum constant rather than a
// `const size_t` object: in C a const-qualified variable is not an integer
// constant expression, so using it as the size of a file-scope array is not
// valid C (GCC rejects it as "variably modified at file scope").
enum { MAX_NODES = 10000 };

// Number of entries of G_Nodes currently in use.
// Invariant: G_NumNodes <= MAX_NODES.
size_t G_NumNodes;

// Flat storage for the tree nodes.
value_t G_Nodes[MAX_NODES];
// Merges the values of two nodes into the value stored at their parent.
// The combining operation is addition.
value_t
Combine(const value_t x, const value_t y)
{
    const value_t sum = x + y;
    return sum;
}
// Updates the sequence to contain the given ITEM at the specified POSITION. | |
void ST_Set(const size_t position, const value_t item) | |
{ | |
assert(G_NumNodes > 0); | |
assert(position <= G_NumNodes / 2); | |
size_t i = position * 2; | |
G_Nodes[i] = item; | |
if (G_NumNodes == 1) | |
return; | |
size_t root = LPBT_Root(G_NumNodes); | |
do { | |
i = LPBT_Parent(i, G_NumNodes); | |
G_Nodes[i] = Combine( | |
G_Nodes[LPBT_LeftChild(i)], | |
G_Nodes[LPBT_RightChild(i, G_NumNodes)]); | |
} while (i != root); | |
} | |
// Appends the given ITEM at the end of the sequence. | |
void ST_Append(const value_t item) | |
{ | |
assert(G_NumNodes + 2 <= MAX_NODES); | |
if (G_NumNodes == 0) { | |
G_Nodes[G_NumNodes++] = item; | |
} else { | |
G_NumNodes += 2; | |
ST_Set(G_NumNodes / 2, item); | |
} | |
} | |
// Returns a mask with a single bit set at the position of the highest set
// bit of N, or 0 when N is 0.
uint64_t
MostSignificantBit(uint64_t n)
{
    // Smear the highest set bit into every lower position
    // (shifts of 1, 2, 4, 8, 16, 32).
    uint64_t smeared = n;
    for (unsigned shift = 1; shift <= 32; shift <<= 1) {
        smeared |= smeared >> shift;
    }
    // Dropping the half-shifted copy leaves only the top bit.
    return smeared ^ (smeared >> 1);
}
// Returns the lowest common ancestor of leaves X and Y in a left-perfect
// binary tree. When X equals Y the leaf itself is returned.
// Requires: X and Y are even (asserted).
uint64_t
LPBT_LCALeaves(const uint64_t x, const uint64_t y)
{
    assert((x & 1) == 0);
    assert((y & 1) == 0);
    if (x != y) {
        // Highest bit in which the two indices differ.
        const uint64_t split = MostSignificantBit(x ^ y);
        // Keep X's bits above the split, clear the split bit and set every
        // bit below it.
        const uint64_t low_ones = split - 1;
        return (x & ~split) | low_ones;
    }
    return x;
}
// Computes the sum of the sequence items in the index interval [l, r]. | |
value_t | |
ST_Sum(size_t l, size_t r) | |
{ | |
assert(r * 2 <= G_NumNodes); | |
assert(l <= r); | |
uint64_t i = l * 2, j = r * 2; | |
if (i == j) | |
return G_Nodes[i]; | |
const uint64_t lca = LPBT_LCALeaves(i, j); | |
value_t acc = Combine(G_Nodes[i], G_Nodes[j]); | |
// Traverse the tree upwards from the left bound and sum up all | |
// the right subtrees on the way. | |
while (1) { | |
const uint64_t p = LPBT_Parent(i, G_NumNodes); | |
if (p == lca) | |
break; | |
const uint64_t rc = LPBT_RightChild(p, G_NumNodes); | |
if (rc != i) | |
acc = Combine(acc, G_Nodes[rc]); | |
i = p; | |
} | |
while (1) { | |
const uint64_t p = LPBT_Parent(j, G_NumNodes); | |
if (p == lca) | |
break; | |
const uint64_t lc = LPBT_LeftChild(p); | |
if (lc != j) | |
acc = Combine(acc, G_Nodes[lc]); | |
j = p; | |
} | |
return acc; | |
} | |
// Self-test: appends the squares 0, 1, 4, ... one at a time and, after each
// append, compares ST_Sum against a directly computed sum for every interval
// [i, j] of the items appended so far.
// Returns 0 (and prints "OK") when every interval matches, 1 otherwise.
int main(void)
{
    // Number of items appended during the test.
    enum { NUM_ITEMS = 300 };

    size_t failures = 0;
    for (int64_t bound = 0; bound < NUM_ITEMS; bound++) {
        ST_Append(bound * bound);
        // bound is non-negative here, so the conversion is lossless.
        const size_t last = (size_t)bound;
        for (size_t i = 0; i <= last; i++) {
            for (size_t j = i; j <= last; j++) {
                const int64_t actual = ST_Sum(i, j);
                int64_t expected = 0;
                for (int64_t x = (int64_t)i; x <= (int64_t)j; x++) {
                    expected += x * x;
                }
                if (actual != expected) {
                    failures++;
                    // The values are int64_t, so they are printed with
                    // PRId64; the previous "%llu" did not match the
                    // argument type.
                    printf("FAIL: sum(%zu, %zu): %" PRId64 " != %" PRId64 "\n",
                           i, j, actual, expected);
                }
            }
        }
    }
    if (!failures) {
        printf("OK\n");
    }
    return (failures != 0);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment