Skip to content

Instantly share code, notes, and snippets.

@roman-kashitsyn
Created November 30, 2023 16:04
Show Gist options
  • Save roman-kashitsyn/8f6c767a63d1c681d87d61d383c180b8 to your computer and use it in GitHub Desktop.
Save roman-kashitsyn/8f6c767a63d1c681d87d61d383c180b8 to your computer and use it in GitHub Desktop.
Extensible segment trees
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
// Isolates the least significant set bit of N; returns 0 when N == 0.
// Unsigned negation (~n + 1) wraps modulo 2^64, and AND-ing it with N keeps
// only the lowest one bit.
uint64_t
Lsb(const uint64_t n)
{
    const uint64_t negated = ~n + 1;
    return n & negated;
}
// Returns a mask with exactly one bit set: the position of the lowest zero
// bit of N. Adding one to N carries through its trailing ones and stops on
// that zero, so it is the only bit that is set in N + 1 but clear in N.
// For N == UINT64_MAX there is no zero bit and the result is 0.
uint64_t
LastZeroBit(const uint64_t n)
{
    const uint64_t carried = n + 1;
    return carried & ~n;
}
// Rounds N up to the nearest power of two (N itself if it already is one).
// Follows the Bit Twiddling Hacks recipe: subtract one, copy the highest set
// bit into every lower position, add one back. As in that recipe, the
// result is 0 for N == 0 and for N > 2^63 (no representable answer).
uint64_t
RoundUpPowerOf2(uint64_t n)
{
    uint64_t smeared = n - 1;
    for (unsigned shift = 1; shift < 64; shift *= 2) {
        smeared |= smeared >> shift;
    }
    return smeared + 1;
}
// Returns the in-order index of the root of a perfect binary tree with SIZE
// nodes. Requires: SIZE == 2^k - 1 for some k >= 1.
uint64_t
PBT_Root(const uint64_t size)
{
    // SIZE + 1 must be a power of two; SIZE == 1 is listed separately because
    // RoundUpPowerOf2(1) == 1, not 2.
    assert(size == 1 || RoundUpPowerOf2(size) == size + 1);
    // The root is the middle node: (2^k - 1) / 2 == 2^(k-1) - 1.
    return size / 2;
}
// Computes the parent of node I in a perfect binary tree stored in order.
// A node's level is its number of trailing one bits. The parent is one level
// up: it gains a one at I's lowest zero bit, and the bit just above that is
// cleared.
uint64_t
PBT_Parent(const uint64_t i)
{
    const uint64_t zero_bit = LastZeroBit(i);
    return (i | zero_bit) & ~(zero_bit << 1);
}
// Computes the left child of node P in a perfect binary tree stored in order.
// Requires: P is not a leaf (leaves are the even indices).
// Going one level down to the left clears the highest of P's trailing ones,
// which is the bit just below P's lowest zero bit.
uint64_t
PBT_LeftChild(const uint64_t p)
{
    assert(p & 1);
    const uint64_t top_trailing_one = LastZeroBit(p) >> 1;
    return p & ~top_trailing_one;
}
// Computes the right child of node P in a perfect binary tree stored in order.
// Requires: P is not a leaf (leaves are the even indices).
// Going one level down to the right sets P's lowest zero bit and clears the
// highest of its trailing ones.
uint64_t
PBT_RightChild(const uint64_t p)
{
    assert(p & 1);
    const uint64_t zero_bit = LastZeroBit(p);
    const uint64_t with_zero_set = p | zero_bit;
    return with_zero_set & ~(zero_bit >> 1);
}
// Computes the left child of internal node I in a flat, in-order
// left-perfect binary tree. I's left subtree in the enclosing perfect tree
// uses only indices below I, so all of them exist. The answer is therefore
// the perfect-tree left child.
uint64_t
LPBT_LeftChild(const uint64_t i)
{
    const uint64_t left = PBT_LeftChild(i);
    return left;
}
// Computes the root of a left-perfect binary tree with SIZE nodes. This is the
// root of the smallest perfect tree (2^k - 1 nodes) with at least SIZE nodes.
// Requires: SIZE >= 1 (for SIZE == 0 the PBT_Root assertion fails).
uint64_t
LPBT_Root(const uint64_t size)
{
    const uint64_t enclosing_size = RoundUpPowerOf2(size + 1) - 1;
    return PBT_Root(enclosing_size);
}
// Returns the leftmost leaf of the subtree rooted at node I in an in-order
// perfect binary tree. Walking down the left spine clears I's trailing one
// bits (its level).
uint64_t PBT_LeftMostLeaf(const uint64_t i)
{
    // I ^ (I + 1) covers I's trailing ones plus the zero bit right above them.
    const uint64_t low_mask = i ^ (i + 1);
    return i & ~low_mask;
}
// Returns the parent of node I in a left-perfect binary tree with SIZE nodes.
// Requires: I is not the root (for the root the result wraps to UINT64_MAX).
// If I's perfect-tree parent exists (index < SIZE), that is the answer.
// Otherwise I is the left child of a missing node. I's real parent is then
// the node whose right subtree starts at I's leftmost leaf, i.e. the in-order
// predecessor of that leaf.
uint64_t LPBT_Parent(const uint64_t i, const uint64_t size)
{
    const uint64_t candidate = PBT_Parent(i);
    if (candidate < size) {
        return candidate;
    }
    return PBT_LeftMostLeaf(i) - 1;
}
// Computes the parent of node I in a left-perfect binary tree with SIZE
// nodes. It climbs the enclosing perfect tree until it reaches an ancestor
// that exists (index < SIZE). This is a loop-based alternative to the
// constant-time LPBT_Parent.
uint64_t
LPBT_Parent_Iterative(uint64_t i, const uint64_t size)
{
    for (;;) {
        i = PBT_Parent(i);
        if (i < size) {
            return i;
        }
    }
}
// Computes the right child of internal node N in a left-perfect binary tree
// with SIZE nodes. It starts at N's perfect-tree right child and follows
// left children until it lands on an existing node (index < SIZE).
// Requires: N is an internal node (odd index).
uint64_t
LPBT_RightChild_2(uint64_t n, const uint64_t size)
{
    assert(n & 1);
    uint64_t child = PBT_RightChild(n);
    while (child >= size) {
        child = PBT_LeftChild(child);
    }
    return child;
}
// Computes the right child of internal node N in a left-perfect binary tree
// with SIZE nodes, without a descent loop.
// Requires: N is an internal node (odd index).
uint64_t
LPBT_RightChild(const uint64_t n, const uint64_t size)
{
    assert(n & 1);
    const uint64_t pbt_right = PBT_RightChild(n);
    // If the perfect-tree right child is missing, N's right subtree is the
    // suffix of nodes [N + 1, SIZE). That suffix is itself a left-perfect tree
    // of SIZE - N - 1 nodes whose indices are shifted by N + 1.
    const uint64_t result = (pbt_right < size)
        ? pbt_right
        : n + 1 + LPBT_Root(size - n - 1);
    // Debug-only cross-check against the loop-based formulation.
    assert(result == LPBT_RightChild_2(n, size));
    return result;
}
// Segment tree implementation
typedef int64_t value_t;
// Capacity of the node array. This is an enumeration constant because it must
// be an integer constant expression to size a file-scope array. A
// `const size_t` object is not one in C: GCC rejects such an array as
// "variably modified at file scope", and Clang accepts it only as an
// extension.
enum { MAX_NODES = 10000 };
// Number of nodes in G_Nodes, stored in order. Even indices hold the sequence
// items (leaves) and odd indices hold combined aggregates (internal nodes).
// Invariant: G_NumNodes <= MAX_NODES.
size_t G_NumNodes;
value_t G_Nodes[MAX_NODES];
// Merges two aggregates (left operand first). With this definition the tree
// maintains range sums.
value_t Combine(const value_t x, const value_t y)
{
    const value_t total = x + y;
    return total;
}
// Stores ITEM as the sequence element at index POSITION, then recomputes
// every aggregate on the path from that leaf up to the root.
// Requires: the tree is non-empty and POSITION is a valid item index.
void ST_Set(const size_t position, const value_t item)
{
    assert(G_NumNodes > 0);
    assert(position <= G_NumNodes / 2);
    // Items live at even in-order indices.
    size_t node = position * 2;
    G_Nodes[node] = item;
    // A one-node tree has no ancestors to refresh.
    if (G_NumNodes == 1) {
        return;
    }
    const size_t root = LPBT_Root(G_NumNodes);
    for (;;) {
        node = LPBT_Parent(node, G_NumNodes);
        const value_t left = G_Nodes[LPBT_LeftChild(node)];
        const value_t right = G_Nodes[LPBT_RightChild(node, G_NumNodes)];
        G_Nodes[node] = Combine(left, right);
        if (node == root) {
            break;
        }
    }
}
// Appends ITEM to the end of the sequence.
// The first item becomes the single node of the tree. Every later item adds
// two slots: an internal node at the odd index and the new leaf after it.
// The leaf is written through ST_Set so that its ancestors are recomputed.
void ST_Append(const value_t item)
{
    assert(G_NumNodes + 2 <= MAX_NODES);
    if (G_NumNodes != 0) {
        G_NumNodes += 2;
        ST_Set(G_NumNodes / 2, item);
        return;
    }
    G_Nodes[0] = item;
    G_NumNodes = 1;
}
// Returns a mask containing only the highest set bit of N; 0 when N == 0.
uint64_t
MostSignificantBit(uint64_t n)
{
    // Copy the top set bit into every lower position...
    for (unsigned shift = 1; shift <= 32; shift <<= 1) {
        n |= n >> shift;
    }
    // ...then keep only the top one.
    return n & ~(n >> 1);
}
// Computes the lowest common ancestor of leaves X and Y (even in-order
// indices) in a left-perfect binary tree; returns X itself when X == Y.
// The highest bit where X and Y differ fixes the ancestor's level. The
// ancestor keeps their common higher bits, has a zero at that bit, and has
// ones in every position below it.
uint64_t
LPBT_LCALeaves(const uint64_t x, const uint64_t y)
{
    assert(!(x & 1));
    assert(!(y & 1));
    if (x != y) {
        const uint64_t split = MostSignificantBit(x ^ y);
        const uint64_t below = split - 1;
        return (x & ~split) | below;
    }
    return x;
}
// Returns the Combine-aggregate of the sequence items at indices l..r
// inclusive. With the current Combine this is their sum.
// Requires: the tree is non-empty and l <= r <= index of the last item.
//
// The bounds are leaves 2*l and 2*r. Both climb toward their lowest common
// ancestor:
//  - the left bound picks up the right-sibling subtrees it passes;
//  - the right bound picks up the left-sibling subtrees it passes.
// The two partial results are kept separate and merged in sequence order (as
// ST_Set does for children). The result is therefore correct for any
// associative Combine, not only commutative ones.
value_t
ST_Sum(size_t l, size_t r)
{
    // Same bounds check as ST_Set. An empty tree has no items, so even
    // index 0 must be rejected.
    assert(G_NumNodes > 0);
    assert(r <= G_NumNodes / 2);
    assert(l <= r);
    uint64_t i = l * 2, j = r * 2;
    if (i == j) {
        return G_Nodes[i];
    }
    const uint64_t lca = LPBT_LCALeaves(i, j);
    // left_acc covers items from l up to the end of i's current subtree.
    // right_acc covers items from the start of j's current subtree up to r.
    value_t left_acc = G_Nodes[i];
    value_t right_acc = G_Nodes[j];
    // Climb from the left bound. Arriving from a left child means the right
    // sibling subtree lies entirely inside the range, just after left_acc.
    for (;;) {
        const uint64_t p = LPBT_Parent(i, G_NumNodes);
        if (p == lca) {
            break;
        }
        const uint64_t rc = LPBT_RightChild(p, G_NumNodes);
        if (rc != i) {
            left_acc = Combine(left_acc, G_Nodes[rc]);
        }
        i = p;
    }
    // Climb from the right bound. Arriving from a right child means the left
    // sibling subtree lies entirely inside the range, just before right_acc.
    for (;;) {
        const uint64_t p = LPBT_Parent(j, G_NumNodes);
        if (p == lca) {
            break;
        }
        const uint64_t lc = LPBT_LeftChild(p);
        if (lc != j) {
            right_acc = Combine(G_Nodes[lc], right_acc);
        }
        j = p;
    }
    return Combine(left_acc, right_acc);
}
int main(void)
{
size_t failures = 0;
for (int64_t bound = 0; bound < 300; bound++) {
ST_Append(bound * bound);
for (size_t i = 0; i <= bound; i++) {
for (size_t j = i; j <= bound; j++) {
const int64_t actual = ST_Sum(i, j);
int64_t expected = 0;
for (int64_t x = i; x <= j; x++)
expected += x * x;
if (actual != expected) {
failures++;
printf("FAIL: sum(%zu, %zu): %llu != %llu\n", i, j, actual, expected);
}
}
}
}
if (!failures)
printf("OK\n");
return (failures != 0);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment