Skip to content

Instantly share code, notes, and snippets.

Last active December 16, 2022 15:27
Show Gist options
  • Save poseidon4o/6b2a1d077cc67ef5a6090497470e3c38 to your computer and use it in GitHub Desktop.
Save poseidon4o/6b2a1d077cc67ef5a6090497470e3c38 to your computer and use it in GitHub Desktop.
// MSVC reports the plain C stdio calls used below (fopen) as deprecated (C4996).
// The opt-out macro has to be defined before the first CRT header is included,
// and the guard must be closed here so the includes that follow are seen on every platform.
#if _WIN64
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif
#endif
#include <cstdlib>
#include <algorithm>
#include <random>
#include <cassert>
#include <cstdio>
#include <cstring>
/// Assertion macro used by the helpers in this file.
/// On 64-bit Windows a failed check prints the line number and the expression text and
/// then breaks into the debugger; on every other platform it is the standard assert().
#if _WIN64
#define bassert(test) ( !!(test) ? (void)0 : ((void)printf("-- Assertion failed at line %d: %s\n", __LINE__, #test), __debugbreak()) )
#else
#define bassert assert
#endif
/// Index value stored for a needle that does not occur in the hay stack
const int NOT_FOUND = -1;
/// NOTE(review): presumably marks a result slot that no search has written yet - confirm against its users
const int NOT_SEARCHED = -2;
// Utility functions not needed for solution
namespace {
#include <stdint.h>
#if __linux__ != 0
#include <time.h>

/// Read a monotonic clock
/// @return the current reading in nanoseconds; only differences between two readings are meaningful
static uint64_t timer_nsec() {
    // Prefer the raw clock where the headers provide it
#if defined(CLOCK_MONOTONIC_RAW)
    const clockid_t clockid = CLOCK_MONOTONIC_RAW;
#else
    const clockid_t clockid = CLOCK_MONOTONIC;
#endif
    timespec t;
    clock_gettime(clockid, &t);
    // Widen before multiplying so the result does not depend on the width of long/time_t
    return uint64_t(t.tv_sec) * 1000000000ULL + uint64_t(t.tv_nsec);
}

#elif _WIN64 != 0
#define NOMINMAX
#include <Windows.h>

/// Caches the performance counter frequency (ticks per second), queried once at start-up
static struct TimerBase {
    LARGE_INTEGER freq;
    TimerBase() {
        QueryPerformanceFrequency(&freq);
    }
} timerBase;

// the order of global initialisations is non-deterministic, do
// not use this routine in the ctors of globally-scoped objects
/// Read the performance counter
/// @return the current reading in nanoseconds; only differences between two readings are meaningful
static uint64_t timer_nsec() {
    LARGE_INTEGER t;
    QueryPerformanceCounter(&t);
    const uint64_t ticks = uint64_t(t.QuadPart);
    const uint64_t freq = uint64_t(timerBase.freq.QuadPart);
    // Convert whole seconds and the sub-second remainder separately: multiplying the raw
    // tick count by 1e9 first wraps 64 bits after roughly 30 minutes of uptime at 10 MHz.
    return (ticks / freq) * 1000000000ULL + (ticks % freq) * 1000000000ULL / freq;
}

#elif __APPLE__ != 0
#include <mach/mach_time.h>

/// Caches the ratio that converts mach_absolute_time() ticks to nanoseconds, queried once at start-up
static struct TimerBase {
    mach_timebase_info_data_t tb;
    TimerBase() {
        mach_timebase_info(&tb);
    }
} timerBase;

// the order of global initialisations is non-deterministic, do
// not use this routine in the ctors of globally-scoped objects
/// Read the mach absolute time counter
/// @return the current reading in nanoseconds; only differences between two readings are meaningful
static uint64_t timer_nsec() {
    const uint64_t t = mach_absolute_time();
    return t * timerBase.tb.numer / timerBase.tb.denom;
}

#else
#error "timer_nsec is not implemented for this platform"
#endif
/// Allocate 64-byte aligned storage for @count objects of type T, does not perform initialization
/// @param count - the number of objects
/// @param unaligned [out] - receives the raw pointer obtained from malloc; this pointer (not the
///                          returned one) must be passed to free. Set to nullptr on failure
/// @return pointer to the aligned memory, or nullptr if the size does not fit in size_t or malloc fails
template <typename T>
T *alignedAlloc(size_t count, void *& unaligned) {
    const size_t padding = 63; // worst-case distance from a malloc result to the next multiple of 64
    unaligned = nullptr;
    // count * sizeof(T) + padding must fit in size_t, otherwise the request wraps around
    // and malloc would hand back a buffer that is far too small for the caller.
    if (count > (SIZE_MAX - padding) / sizeof(T)) {
        return nullptr;
    }
    const size_t bytes = count * sizeof(T);
    unaligned = std::malloc(bytes + padding);
    if (!unaligned) {
        return nullptr;
    }
    // Round the address up to the next multiple of 64
    const uintptr_t rounded = (reinterpret_cast<uintptr_t>(unaligned) + padding) & ~uintptr_t(padding);
    return reinterpret_cast<T*>(rounded);
}
template <typename T>
struct AlignedArrayPtr {
void *allocated = nullptr;
T *aligned = nullptr;
int64_t count = -1;
AlignedArrayPtr() = default;
AlignedArrayPtr(int64_t count) {
void init(int64_t newCount) {
bassert(newCount > 0);
aligned = alignedAlloc<T>(newCount, allocated);
count = newCount;
void memset(int value) {
::memset(aligned, value, sizeof(T) * count);
~AlignedArrayPtr() {
T *get() {
return aligned;
const T *get() const {
return aligned;
operator T *() {
return aligned;
operator const T *() const {
return aligned;
int64_t getCount() const {
return count;
const T *begin() const {
return aligned;
const T *end() const {
return aligned + count;
int operator[](int index) const {
return aligned[index];
int &operator[](int index) {
return aligned[index];
AlignedArrayPtr(const AlignedArrayPtr &) = delete;
AlignedArrayPtr &operator=(const AlignedArrayPtr &) = delete;
/// Aligned array of int, the type taken by the file helpers, verify and the search routines
typedef AlignedArrayPtr<int> AlignedIntArray;
/// Signature that loadFromFile expects at the very start of a data file
const char magic[] = ".BSEARCH";
/// Length of the signature without the terminating NUL
const int magicSize = sizeof(magic) - 1;
bool storeToFile(const AlignedIntArray &hayStack, const AlignedIntArray &needles, const char *name) {
FILE *file = fopen(name, "wb+");
if (!file) {
return false;
const char magic[] = ".BSEARCH";
const int64_t sizes[2] = {hayStack.getCount(), needles.getCount()};
fwrite(magic, 1, magicSize, file);
fwrite(sizes, 1, sizeof(sizes), file);
fwrite(hayStack.get(), sizeof(int), hayStack.getCount(), file);
fwrite(needles.get(), sizeof(int), needles.getCount(), file);
return true;
/// Read a test case from a binary file
/// Expected layout: the signature @magic, two int64 element counts (hay stack, needles),
/// the hay stack values and then the needle values
/// @param hayStack [out] - resized to the stored count and filled with the stored hay stack
/// @param needles [out] - resized to the stored count and filled with the stored needles
/// @param name - path of the file to read
/// @return true only if the file exists, starts with the signature and every part was read completely
bool loadFromFile(AlignedIntArray &hayStack, AlignedIntArray &needles, const char *name) {
    FILE *file = fopen(name, "rb");
    if (!file) {
        return false;
    }

    char test[magicSize] = {0, };
    const bool magicRead = size_t(magicSize) == fread(test, 1, magicSize, file);
    if (!magicRead || strncmp(magic, test, magicSize) != 0) {
        printf("Bad magic constant in file [%s]\n", name);
        fclose(file);
        return false;
    }

    int64_t sizes[2] = {0, 0};
    bool allOk = sizeof(sizes) == fread(sizes, 1, sizeof(sizes), file);
    // The counts come from the file, reject anything that can not be a real array size
    allOk = allOk && sizes[0] > 0 && sizes[1] > 0;

    if (allOk) {
        // The arrays have to be sized before anything can be read into them
        hayStack.init(sizes[0]);
        needles.init(sizes[1]);
        allOk = hayStack.get() != nullptr && needles.get() != nullptr;
    }

    if (allOk) {
        const size_t hayCount = size_t(sizes[0]);
        const size_t needleCount = size_t(sizes[1]);
        allOk = hayCount == fread(hayStack.get(), sizeof(int), hayCount, file);
        allOk = allOk && needleCount == fread(needles.get(), sizeof(int), needleCount, file);
    }

    fclose(file);
    return allOk;
}
/// Verify if previous search produced correct results
/// The reference answer for every needle is computed with std::lower_bound, so @hayStack must be sorted ascending
/// @param hayStack - the input data that will be searched in
/// @param needles - the values that will be searched
/// @param indices - the indices of the needles (or -1 if the needle is not found)
/// @return the first index @c where find(@hayStack, @needles[@c]) != @indices[@c], or -1 if all indices are correct
int verify(const AlignedIntArray &hayStack, const AlignedIntArray & needles, const AlignedIntArray &indices) {
    const int *const first = hayStack.begin();
    const int *const last = hayStack.end();
    for (int c = 0; c < needles.getCount(); c++) {
        const int value = needles[c];
        const int *pos = std::lower_bound(first, last, value);
        if (pos == last || *pos != value) {
            // The needle is absent, the search must have reported that
            bassert(indices[c] == NOT_FOUND);
            if (indices[c] != NOT_FOUND) {
                return c;
            }
        } else {
            // The needle is present, only the index of its first occurrence is accepted
            const int idx = int(pos - first);
            bassert(indices[c] == idx);
            if (indices[c] != idx) {
                return c;
            }
        }
    }
    return -1;
}
/// Stack allocator with predefined max size
/// Allocations are carved one after another from a buffer supplied by the caller, the allocator never owns it
/// Only the first allocation starts at the beginning of the buffer, later ones are not guaranteed to be aligned
/// Can only free all the allocations at once
struct StackAllocator {
    /// @param ptr - the buffer to allocate from, it must stay valid while the allocator is used
    /// @param bytes - the size of the buffer in bytes
    StackAllocator(uint8_t *ptr, int bytes)
        : totalBytes(bytes)
        , data(ptr) {}

    /// Allocate memory for @count T objects
    /// Does *NOT* call constructors
    /// @param count - the number of objects needed
    /// @return pointer to the allocated memory or nullptr if @count is negative or does not fit in the free space
    template <typename T>
    T *alloc(int count) {
        // A negative count would move the cursor backwards into memory that was already handed out
        if (count < 0) {
            return nullptr;
        }
        const int available = totalBytes - idx;
        // Compare element counts instead of byte counts, count * sizeof(T) may not fit in an int
        if (available < 0 || size_t(count) > size_t(available) / sizeof(T)) {
            return nullptr;
        }
        const int size = int(size_t(count) * sizeof(T));
        uint8_t *start = data + idx;
        idx += size;
        return reinterpret_cast<T*>(start);
    }

    /// De-allocate all the memory previously allocated with @alloc
    void freeAll() {
        idx = 0;
    }

    /// Get the max number of bytes that can be allocated by the allocator
    int maxBytes() const {
        return totalBytes;
    }

    /// Get the free space that can still be allocated, same as maxBytes before any allocations
    int freeBytes() const {
        return totalBytes - idx;
    }

    StackAllocator(const StackAllocator &) = delete;
    StackAllocator& operator=(const StackAllocator &) = delete;

    const int totalBytes;    ///< size of the buffer in bytes
    int idx = 0;             ///< offset of the first free byte
    uint8_t *data = nullptr; ///< start of the buffer
};
/// Binary search implemented to return same result as std::lower_bound
/// When there are multiple values of the searched, it will return index of the first one
/// When the searched value is not found, it will return -1
/// @param hayStack - the input data that will be searched in
/// @param needles - the values that will be searched
/// @param indices - the indices of the needles (or -1 if the needle is not found)
static void binarySearch(const AlignedIntArray &hayStack, const AlignedIntArray & needles, AlignedIntArray &indices) {
for (int c = 0; c < needles.getCount(); c++) {
const int value = needles[c];
int left = 0;
int count = hayStack.getCount();
while (count > 0) {
const int half = count / 2;
if (hayStack[left + half] < value) {
left = left + half + 1;
count -= half + 1;
} else {
count = half;
if (hayStack[left] == value) {
indices[c] = left;
} else {
indices[c] = -1;
/// Implement some search algorithm and optimize it so it is faster than @binarySearch above
/// When the searched value is not found, it will return -1
/// @param hayStack - the input data that will be searched in
/// @param needles - the values that will be searched
/// @param indices - the indices of the needles (or -1 if the needle is not found)
/// @param allocator - all allocations must be done trough this allocator, it can be queried for max allowed allocation size
static void betterSearch(const AlignedIntArray &hayStack, const AlignedIntArray & needles, AlignedIntArray &indices, StackAllocator &/*allocator*/) {
// TODO: Implement your solution here
for (int q = 0; q < needles.getCount(); q++) {
const int value = needles[q];
bool found = false;
for (int c = 0; c < hayStack.getCount(); c++) {
if (hayStack[c] == value) {
indices[q] = c;
found = true;
if (!found) {
indices[q] = -1;
int main() {
printf("+ Correctness tests ... \n");
const int heapSize = 1 << 13;
const int64_t searches = 400ll * (1 << 26);
// enumerate and run correctness test
int testCaseCount = 0;
for (int r = 0; /*no-op*/; r++) {
AlignedArrayPtr<int> hayStack;
AlignedArrayPtr<int> needles;
char fname[64] = {0,};
snprintf(fname, sizeof(fname), "%d.bsearch", r);
if (!loadFromFile(hayStack, needles, fname)) {
printf("Checking %s... ", fname);
AlignedArrayPtr<int> indices(needles.getCount());
AlignedArrayPtr<uint8_t> heap(heapSize);
StackAllocator allocator(heap, heapSize);
betterSearch(hayStack, needles, indices, allocator);
if (verify(hayStack, needles, indices) != -1) {
printf("Failed to verify base betterSearch!\n");
return -1;
binarySearch(hayStack, needles, indices);
if (verify(hayStack, needles, indices) != -1) {
printf("Failed to verify base binarySearch!\n");
return -1;
printf("+ Speed tests ... \n");
for (int r = 0; r < testCaseCount; r++) {
AlignedArrayPtr<int> hayStack;
AlignedArrayPtr<int> needles;
char fname[64] = {0,};
snprintf(fname, sizeof(fname), "%d.bsearch", r);
if (!loadFromFile(hayStack, needles, fname)) {
printf("Failed to load %s for speed test, continuing\n", fname);
const int testRepeat = std::min<int64_t>(1000ll, searches / hayStack.getCount());
printf("Running speed test for %s, %d repeats \n", fname, testRepeat);
AlignedArrayPtr<int> indices(needles.getCount());
AlignedArrayPtr<uint8_t> heap(heapSize);
StackAllocator allocator(heap, heapSize);
uint64_t t0;
uint64_t t1;
// Time the binary search and take average of the runs
t0 = timer_nsec();
for (int test = 0; test < testRepeat; ++test) {
binarySearch(hayStack, needles, indices);
t1 = timer_nsec();
const double totalBinary = (double(t1 - t0) * 1e-9) / testRepeat;
printf("\tbinarySearch time %f\n", totalBinary);
// Time the better search and take average of the runs
t0 = timer_nsec();
for (int test = 0; test < testRepeat; ++test) {
betterSearch(hayStack, needles, indices, allocator);
t1 = timer_nsec();
const double totalBetter = (double(t1 - t0) * 1e-9) / testRepeat;
printf("\tbetterSearch time %f\n", totalBetter);
if (totalBetter < totalBinary) {
printf("Great success!\n");
return 0;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment