Skip to content

Instantly share code, notes, and snippets.

@krayfaus
Last active February 21, 2022 03:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krayfaus/671c9432777b1d8bfca83f1ca850b7fe to your computer and use it in GitHub Desktop.
Save krayfaus/671c9432777b1d8bfca83f1ca850b7fe to your computer and use it in GitHub Desktop.
S+ Tree
// ----------------------------------------------------------------
// S+ Tree
//
// Created by: Sergey Slotin
// Documented at:
// https://en.algorithmica.org/hpc/data-structures/s-tree
// https://twitter.com/sergey_slotin/status/1494349254730690561
//
// ----------------------------------------------------------------
//
// Rewritten by: Ítalo Cadeu (@krayfaus)
//
// This is an initial attempt to transform the amazing S+ Tree data structure,
// created by Sergey Slotin into a templated header-only C++ library.
//
// I hope the code below is usable. I've tried my best to understand the original algorithm,
// but as I've never worked in production it may contain beginner mistakes.
//
// https://gist.github.com/krayfaus/671c9432777b1d8bfca83f1ca850b7fe
// ----------------------------------------------------------------
// As Sergey has acknowledged, online compilers are not the best way
// to properly benchmark memory-intensive algorithms,
// so please take the output with a grain of salt.
// If possible, compile and run the code on your own machine.
// The code compiles on all major compilers: Clang, GCC and MSVC.
#if defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#define NOMINMAX
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC-compatible compiler, targeting x86/x86-64 */
#include <sys/mman.h>
#include <x86intrin.h>
#else
#error "Unknown or Unsupported Platform."
#endif
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <random>
#include <type_traits>
#include <time.h>
#include <fmt/core.h>
// ----------------------------------------------------------------
// Aliases shared by the rest of the file.
using s32 = int;      // Signed key/index type (plain int; named s32 by this file).
using u32 = unsigned; // Unsigned counterpart, used for SIMD masks and ranks.

// Alias for a built-in C array of Size elements of T (e.g. c_array<int, 4> is int[4]).
template <typename T, std::size_t Size>
using c_array = T[Size];

// Padding value meaning "greater than every real key".
constexpr int k_Infinity = std::numeric_limits<int>::max();
// ----------------------------------------------------------------
namespace
{
/**
 * Allocate a raw buffer for the tree.
 *
 * @param alignment Requested alignment in bytes (a power of two). Used on the
 *                  Linux path only; on MSVC, VirtualAlloc chooses the placement
 *                  itself and the argument is ignored.
 * @param size      Minimum number of bytes to allocate. On Linux it is rounded
 *                  up to a multiple of `alignment`, as std::aligned_alloc requires.
 * @return Pointer to the new block, or nullptr on failure. Release it with
 *         VirtualFree(ptr, 0, MEM_RELEASE) on MSVC, or std::free on Linux.
 */
[[nodiscard]] auto allocate_memory(size_t alignment, size_t size) -> void *
{
    void *data = nullptr;
#if defined(_MSC_VER)
    (void)alignment; // Not needed: VirtualAlloc picks an aligned base address itself.
    data = VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
#elif defined(__linux__)
    // std::aligned_alloc requires the size to be a whole multiple of the alignment.
    if (alignment != 0)
    {
        size = (size + alignment - 1) / alignment * alignment;
    }
    data = std::aligned_alloc(alignment, size);
    if (data != nullptr)
    {
        // Best-effort hint to back the buffer with transparent huge pages.
        // Failure only loses the optimisation, so the result is ignored.
        (void)madvise(data, size, MADV_HUGEPAGE);
    }
#else
#error "Unknown or Unsupported Platform."
#endif
    return data;
}
} // namespace
// ----------------------------------------------------------------
namespace
{
/** Number of BlockSize-key blocks needed to hold n keys, i.e. n / BlockSize rounded up. */
template <s32 BlockSize>
[[nodiscard]] constexpr auto calculate_block_count(s32 n) -> s32
{
    s32 const padded = n + (BlockSize - 1);
    return padded / BlockSize;
}

/**
 * Number of keys stored in the layer directly above a layer holding n keys.
 *
 * Each node above holds BlockSize keys and has BlockSize + 1 children. The
 * layer above therefore needs ceil(blocks / (BlockSize + 1)) nodes, each
 * contributing BlockSize key slots.
 */
template <s32 BlockSize>
[[nodiscard]] constexpr auto calculate_previous_layer_key_count(s32 n) -> s32
{
    s32 const child_blocks = calculate_block_count<BlockSize>(n);
    s32 const parent_blocks = (child_blocks + BlockSize) / (BlockSize + 1);
    return parent_blocks * BlockSize;
}

/** Number of layers (leaf layer included) of a balanced n-key B+ tree. */
template <s32 BlockSize>
[[nodiscard]] constexpr auto calculate_height(s32 n) -> s32
{
    // Climb one layer at a time until the remaining keys fit in a single node.
    s32 layers = 1;
    for (s32 keys = n; keys > BlockSize; keys = calculate_previous_layer_key_count<BlockSize>(keys))
    {
        ++layers;
    }
    return layers;
}

/**
 * Index of the first slot of layer h in the flat tree buffer.
 * Layer 0 is the leaf layer (the largest); layers above follow it in order.
 * Precondition (unchecked): h >= 0.
 */
template <s32 BlockSize, s32 KeyCount>
[[nodiscard]] constexpr auto calculate_offset(s32 h) -> s32
{
    s32 offset = 0;
    s32 layer_keys = KeyCount;
    for (s32 remaining = h; remaining != 0; --remaining)
    {
        offset += calculate_block_count<BlockSize>(layer_keys) * BlockSize;
        layer_keys = calculate_previous_layer_key_count<BlockSize>(layer_keys);
    }
    return offset;
}
} // namespace
// ----------------------------------------------------------------
namespace
{
using reg = __m256i;
void permute(s32 *node)
{
reg const mask = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
reg *middle = (reg *)(node + 4);
reg x = _mm256_loadu_si256(middle);
x = _mm256_permutevar8x32_epi32(x, mask);
_mm256_storeu_si256(middle, x);
}
auto direct_rank(reg x, s32 *y) -> u32
{
reg a = _mm256_load_si256((reg *)y);
reg b = _mm256_load_si256((reg *)(y + 8));
reg ca = _mm256_cmpgt_epi32(a, x);
reg cb = _mm256_cmpgt_epi32(b, x);
#if defined(_MSC_VER)
auto tca = _mm256_castsi256_ps(ca);
auto tcb = _mm256_castsi256_ps(cb);
s32 mb = _mm256_movemask_ps(tcb);
s32 ma = _mm256_movemask_ps(tca);
#else
s32 mb = _mm256_movemask_ps((__m256)cb);
s32 ma = _mm256_movemask_ps((__m256)ca);
#endif
u32 mask = (1 << 16);
mask |= static_cast<u32>(mb << 8);
mask |= static_cast<u32>(ma);
#if defined(_MSC_VER)
return _tzcnt_u32(mask);
#else
return __tzcnt_u32(mask);
#endif
}
auto permuted_rank(reg x, s32 *y) -> u32
{
reg a = _mm256_load_si256((reg *)y);
reg b = _mm256_load_si256((reg *)(y + 8));
reg ca = _mm256_cmpgt_epi32(a, x);
reg cb = _mm256_cmpgt_epi32(b, x);
reg c = _mm256_packs_epi32(ca, cb);
u32 mask = static_cast<u32>(_mm256_movemask_epi8(c));
#if defined(_MSC_VER)
return _tzcnt_u32(mask);
#else
return __tzcnt_u32(mask);
#endif
}
template <s32 BlockSize, s32 TreeHeight, s32 ElementCount>
auto lower_bound(s32 value, s32 *tree) -> s32
{
reg x = _mm256_set1_epi32(value - 1);
u32 k = 0;
for (s32 h = TreeHeight - 1; h > 0; h--)
{
u32 const i = permuted_rank(x, tree + calculate_offset<BlockSize, ElementCount>(h) + k);
k = k * (BlockSize + 1) + (i << 3);
}
u32 i = direct_rank(x, tree + k);
return tree[k + i];
}
} // namespace
// ----------------------------------------------------------------
namespace aethelwerka
{
template <typename ElementType, s32 ElementCount, s32 BlockSize>
class static_tree
{
public:
// Aliases:
using value_type = ElementType;
using pointer_type = value_type *;
// Constants:
static constexpr auto value_size = sizeof(value_type); // Size in bytes of value_type.
static constexpr s32 block_size = BlockSize; // Cache line.
static constexpr s32 element_count = ElementCount; // Element count on input array, key count on btree.
// Height of a balanced n-key B+ tree.
static constexpr s32 tree_height = calculate_height<block_size>(element_count);
// Tree size is the offset of the (non-existent) layer H ("height").
static constexpr s32 tree_size = calculate_offset<block_size, element_count>(tree_height);
public:
// Empty constructor.
static_tree() noexcept
: tree_data(nullptr)
, is_initialized(false)
{
}
// Initialize tree.
[[nodiscard]] bool initialize(c_array<value_type, element_count> input_data) noexcept
{
// expect(!is_initialized);
// expect(input_data);
// expect(element_count > 0);
s32 const page_alignment = 1 << 21; // Page size in bytes (2MB).
// We can only allocate a whole number of pages.
s32 const page_size = (value_size * tree_size + page_alignment - 1) / page_alignment * page_alignment;
// Allocate memory.
tree_data = (s32*) allocate_memory(page_alignment, page_size);
if (!tree_data)
{
// Couldn't allocate memory.
return false;
}
// Pad the tree with infinities.
for (s32 i = element_count; i < tree_size; ++i)
{
tree_data[i] = k_Infinity;
}
// Copy data from input_data array to tree_data.
memcpy(tree_data, input_data, value_size * element_count);
// Build the internal nodes, layer by layer.
for (s32 h = 1; h < tree_height; ++h)
{
for (s32 i = 0; i < calculate_offset<block_size, element_count>(h + 1) - calculate_offset<block_size, element_count>(h); ++i)
{
s32 k = i / block_size;
s32 const j = i - k * block_size;
k = k * (block_size + 1) + j + 1; // Compare to the right of the key.
for (s32 l = 0; l < h - 1; ++l) // And then always to the left.
{
k *= block_size + 1;
}
// Pad the rest with infinities if the key doesn't exist:
tree_data[calculate_offset<block_size, element_count>(h) + i] =
k * block_size < element_count ? tree_data[k * block_size] : k_Infinity;
}
}
// Permute every tree node for faster query time (trick to avoid permuting avx2 later).
for (s32 i = calculate_offset<block_size, element_count>(1); i < tree_size; i += block_size)
{
permute(tree_data + i);
}
// Tree is properly initialized now, and ready to use.
is_initialized = true;
return true;
}
[[nodiscard]] s32 search(value_type value)
{
return lower_bound<block_size, tree_height, element_count>(value, tree_data);
}
private:
// Pointer to tree data.
pointer_type tree_data;
// Status of the tree data.
bool is_initialized;
};
} // namespace aethelwerka
// ----------------------------------------------------------------
/**
 * Reference search for the benchmark: the first element of a sorted array
 * that is >= value, found with std::lower_bound.
 *
 * @tparam ArrayLength Number of elements in `array`. The array parameter
 *                     decays to a pointer, so this must be passed explicitly.
 * @param value Query key.
 * @param array Sorted array of ArrayLength elements.
 * @return The lower bound of `value`, or k_Infinity when every element is
 *         smaller. This avoids dereferencing the end iterator.
 */
template <s32 ArrayLength>
auto baseline(s32 value, s32 array[ArrayLength]) -> s32
{
    s32 const *const array_first = array;
    s32 const *const array_last = array + ArrayLength;
    s32 const *const found = std::lower_bound(array_first, array_last, value);
    return found != array_last ? *found : k_Infinity;
}
// ----------------------------------------------------------------
int main(int, char **)
{
    using fmt::print;
    using aethelwerka::static_tree;
    // ----------------------------------------------------------------
    using ElementType = int;
    constexpr auto ElementSize = sizeof(ElementType);
    constexpr auto ElementCount = (1 << 16);
    constexpr auto IterationCount = (1 << 22);
    // ----------------------------------------------------------------
    // Data arrays. They are too big for the stack, so they get static
    // storage duration instead.
    static c_array<ElementType, ElementCount> input_data;
    static c_array<int, IterationCount> check_indices;
    // ----------------------------------------------------------------
    print("Lenght = {}, Iterations = {}\n", ElementCount, IterationCount);
    std::mt19937 rng(0);
    // Fill input data with random keys. Slot 0 is k_Infinity, so after sorting
    // the array ends with a key >= every query. std::lower_bound then never
    // returns the end iterator, and every tree query has an answer.
    input_data[0] = k_Infinity;
    for (s32 i = 1; i < ElementCount; ++i)
    {
        input_data[i] = rng() % (1 << 30);
    }
    // Both std::lower_bound and the S+ tree require the keys in sorted order.
    std::sort(input_data, input_data + ElementCount);
    // Random query values.
    for (s32 i = 0; i < IterationCount; ++i)
    {
        check_indices[i] = rng() % (1 << 30);
    }
    // ----------------------------------------------------------------
    // TreeBlockSize: 16 elements; sizeof(s32) * 16 = 64 bytes (cache line is typically 64 bytes).
    constexpr auto TreeBlockSize = 64 / ElementSize;
    auto tree = static_tree<ElementType, ElementCount, TreeBlockSize>{};
    if (!tree.initialize(input_data))
    {
        // Couldn't initialize tree.
        return -1;
    }
    // ----------------------------------------------------------------
    // The measurement code below is rough and may not give precise metrics.
    // Both runs answer the same queries, so their checksums should match.
    // ----------------------------------------------------------------
    double x = 0.0;
    {
        clock_t start = clock();
        s32 checksum = 0;
        for (s32 i = 0; i < IterationCount; ++i)
        {
            checksum ^= baseline<ElementCount>(check_indices[i], input_data);
        }
        double seconds = double(clock() - start) / CLOCKS_PER_SEC;
        print("Checksum: {}\n", checksum);
        x = 1e9 * seconds / IterationCount;
    }
    // ----------------------------------------------------------------
    double y = 0.0;
    {
        clock_t start = clock();
        s32 checksum = 0;
        for (s32 i = 0; i < IterationCount; ++i)
        {
            checksum ^= tree.search(check_indices[i]);
        }
        double seconds = double(clock() - start) / CLOCKS_PER_SEC;
        print("Checksum: {}\n", checksum);
        y = 1e9 * seconds / IterationCount;
    }
    print("std::lower_bound: {:.2f}\n", x);
    print("S+ tree: {:.2f}\n", y);
    print("Speedup: {:.2f}\n", x / y);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment