Skip to content

Instantly share code, notes, and snippets.

@notnullnotvoid
Last active June 17, 2024 07:32
Show Gist options
  • Save notnullnotvoid/5d60028f92ad1ced5742ab1c41e15d39 to your computer and use it in GitHub Desktop.
Save notnullnotvoid/5d60028f92ad1ced5742ab1c41e15d39 to your computer and use it in GitHub Desktop.
Benchmarking digit reversal algorithms
//sample output:
//[benchmark_func] modulo branchless: 1000000 iters in 0.045876s min, 0.048180s max, 0.046400s avg
//[benchmark_func] modulo basic: 1000000 iters in 0.025224s min, 0.025544s max, 0.025335s avg
//[benchmark_func] manual print: 1000000 iters in 0.043660s min, 0.044033s max, 0.043822s avg
//[benchmark_func] modulo lookup: 1000000 iters in 0.109280s min, 0.110357s max, 0.109678s avg
//[benchmark_func] modulo multiply: 1000000 iters in 0.109537s min, 0.110182s max, 0.109707s avg
//[benchmark_func] stack print: 1000000 iters in 0.276515s min, 0.280833s max, 0.278486s avg
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <chrono>
//Element count of a true array `x`: its total byte size divided by the byte size of one element.
//Only meaningful for real arrays; given a pointer it would divide the pointer's size instead.
#define ARRAY_LEN(x) (sizeof(x) / sizeof((x)[0]))
//printf-style logging helper: prefixes the message with "[<enclosing function name>] " (via __func__)
//and flushes stdout after every call. `fmt` must be a string literal because it is concatenated with the
//prefix at compile time; `##__VA_ARGS__` is the GNU extension that drops the comma when no arguments follow.
#define print_log(fmt, ...) do { printf("[%s] " fmt, __func__, ##__VA_ARGS__); fflush(stdout); } while (false)
//Returns the bit length of `i`, i.e. the 1-based position of its highest set bit (1 -> 1, 8 -> 4).
//Precondition: i != 0, because __builtin_clz(0) is undefined.
static __attribute__((always_inline)) int bit_log(int i) { return 32 - __builtin_clz(i); }

//Reverses the decimal digits of `value` using straight-line arithmetic only (no loops).
//The sign is preserved; 0 is returned when the reversed magnitude does not fit in an int32_t.
static __attribute__((always_inline)) int32_t modulo_branchless(int32_t value) {
    int64_t sign = value < 0? -1 : 1;
    //the multiply is done in 64 bits, so INT32_MIN becomes 2147483648 instead of overflowing
    uint64_t absolute = value * sign;

    //http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10
    static const int powersOf10[] = { 1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000 };

    //Measure the length of `absolute | (absolute == 0)` instead of `absolute`: this maps 0 to 1 and leaves
    //every other value untouched. Without it, an input of 0 hands 0 to __builtin_clz (undefined) and then
    //produces a digit index of -1, i.e. an out-of-bounds read of powersOf10. The digit sum below still uses
    //`absolute`, so 0 keeps reversing to 0.
    uint64_t measured = absolute | (absolute == 0);
    int t = (bit_log((int) measured) + 1) * 1233 >> 12;
    //index of the most significant decimal digit (digit count minus one)
    int topDigitIndex = t - (measured < (uint64_t) powersOf10[t]);

    //Reverse as if the number were zero-padded to 10 digits: the digit at 10^k moves to 10^(9-k).
    uint64_t sum = 0;
    // sum += 1 * (absolute / 1000000000 % 10);
    sum += absolute / 1000000000; //simplification actually matters here because clang doesn't realize the % 10 for this line is a no-op
    sum += 10 * (absolute / 100000000 % 10);
    sum += 100 * (absolute / 10000000 % 10);
    sum += 1000 * (absolute / 1000000 % 10);
    sum += 10000 * (absolute / 100000 % 10);
    sum += 100000 * (absolute / 10000 % 10);
    sum += 1000000 * (absolute / 1000 % 10);
    sum += 10000000 * (absolute / 100 % 10);
    sum += 100000000 * (absolute / 10 % 10);
    sum += 1000000000 * (absolute / 1 % 10); //no simplification - clang knows this / 1 is a no-op

    //The padded reversal carries (9 - topDigitIndex) surplus trailing zeros; strip them by scaling up
    //to a fixed 10^9 surplus and dividing that constant back out.
    sum *= powersOf10[topDigitIndex];
    //multiplication followed by division-by-constant is better than division by non-constant
    //because int mul is fast-ish and int div is really slow!
    sum /= 1000000000;

    //sum is unsigned, so anything above INT32_MAX is an out-of-range reversal
    return sum > INT32_MAX? 0 : (int32_t) ((int64_t) sum * sign);
}
//Textbook digit reversal: repeatedly move the lowest decimal digit of the magnitude onto the end of the result.
//The sign is preserved; 0 is returned when the reversed magnitude does not fit in an int32_t.
static __attribute__((always_inline)) int32_t modulo_basic(int32_t value) {
    const int64_t signFactor = (value < 0) ? -1 : 1;

    //NOTE: the intrinsic is deliberate. It was observed to make this function (but not the others) faster in
    //optimized builds, even though a plain multiply should behave the same under `-fwrapv`.
    //Only the wrapped product is used; the returned overflow flag is ignored.
    //For INT32_MIN the product wraps back to INT32_MIN, the loop below never runs, and the result is 0.
    int32_t remaining;
    __builtin_smul_overflow(value, signFactor, &remaining);

    //accumulate in 64 bits so an oversized reversal can be detected instead of overflowing
    int64_t reversed = 0;
    for (; remaining > 0; remaining /= 10) {
        reversed = reversed * 10 + remaining % 10;
    }

    if (reversed > INT32_MAX) {
        return 0;
    }
    return reversed * signFactor;
}
//Two-pass digit reversal: first "print" the decimal digits (least significant first) into a small scratch
//buffer, then read them back in that order to build the reversed number.
//The sign is preserved; 0 is returned when the reversed magnitude does not fit in an int32_t.
static __attribute__((always_inline)) int32_t manual_print(int32_t value) {
    const int32_t signFactor = (value < 0) ? -1 : 1;

    //Only the wrapped product is used; the overflow flag is ignored.
    //For INT32_MIN the product wraps back to INT32_MIN, so the bytes stored below are wrapped negative
    //remainders rather than digits; the accumulated total still exceeds INT32_MAX, so 0 is returned.
    int32_t remaining;
    __builtin_smul_overflow(value, signFactor, &remaining);

    //pass 1: peel digits off the low end (a 32-bit value has at most 10 of them)
    uint8_t scratch[16];
    int count = 0;
    for (; remaining != 0; remaining /= 10) {
        scratch[count] = remaining % 10;
        ++count;
    }

    //pass 2: the first digit stored was the lowest one, so it becomes the most significant one
    int64_t reversed = 0;
    for (int pos = 0; pos < count; ++pos) {
        reversed = 10 * reversed + scratch[pos];
    }

    if (reversed > INT32_MAX) {
        return 0;
    }
    return reversed * signFactor;
}
//Digit reversal that swaps decimal digits pairwise from the outside in, using a power-of-ten lookup table
//to address each digit position.
//The sign is preserved; 0 is returned when the reversed magnitude does not fit in an int32_t.
static __attribute__((always_inline)) int32_t reverseDigits_ModuloLookup(int32_t value) {
    constexpr uint64_t pow10[] = {
        1, 10, 100, 1'000, 10'000, 100'000, 1'000'000, 10'000'000, 100'000'000, 1'000'000'000
    };
    constexpr size_t pow10Count = ARRAY_LEN(pow10);

    //single-digit values are their own reversal
    if (value > -10 && value < 10) {
        return value;
    }

    const bool isNegative = value < 0;
    //widen before negating so INT32_MIN has a representable magnitude
    const int64_t widened = value;
    const uint64_t magnitude = static_cast<uint64_t>(isNegative ? -widened : widened);

    //index of the most significant digit: the largest table entry that does not exceed the magnitude
    size_t top = 0;
    while (top + 1 < pow10Count && magnitude >= pow10[top + 1]) {
        ++top;
    }

    //an exact power of ten reverses to 1 (all other digits are zero)
    if (magnitude == pow10[top]) {
        return isNegative ? -1 : 1;
    }

    //swap the digit at position `low` with the digit at the mirrored position `high`
    uint64_t reversed = 0;
    const size_t pairCount = (top + 1) / 2;
    for (size_t low = 0, high = top; low < pairCount; ++low, --high) {
        const uint64_t lowDigit = (magnitude / pow10[low]) % 10;
        const uint64_t highDigit = (magnitude / pow10[high]) % 10;
        reversed += lowDigit * pow10[high] + highDigit * pow10[low];
    }

    //an even top index means an odd digit count: the middle digit has no partner and stays in place
    if (top % 2 == 0) {
        const uint64_t middle = pow10[pairCount];
        reversed += ((magnitude / middle) % 10) * middle;
    }

    if (reversed > INT32_MAX) {
        return 0;
    }
    const int32_t narrowed = static_cast<int32_t>(reversed);
    return isNegative ? -narrowed : narrowed;
}
//Digit reversal that swaps decimal digits pairwise from the outside in, tracking the two digit positions
//as running powers of ten (multiplying one, dividing the other) instead of using a lookup table.
//The sign is preserved; 0 is returned when the reversed magnitude does not fit in an int32_t.
static __attribute__((always_inline)) int32_t reverseDigits_ModuloMultiply(int32_t value) {
    //single-digit values are their own reversal
    if (value > -10 && value < 10) {
        return value;
    }

    const bool isNegative = value < 0;
    //widen before negating so INT32_MIN has a representable magnitude
    const int64_t widened = value;
    const uint64_t magnitude = static_cast<uint64_t>(isNegative ? -widened : widened);

    //largest power of ten that does not exceed the magnitude (the place value of its leading digit)
    uint64_t high = 1;
    while (magnitude >= high * 10) {
        high *= 10;
    }

    //an exact power of ten reverses to 1 (all other digits are zero)
    if (magnitude == high) {
        return isNegative ? -1 : 1;
    }

    //walk the two place values towards each other, exchanging the digits found there
    uint64_t reversed = 0;
    uint64_t low = 1;
    while (low < high) {
        const uint64_t lowDigit = (magnitude / low) % 10;
        const uint64_t highDigit = (magnitude / high) % 10;
        reversed += lowDigit * high + highDigit * low;
        low *= 10;
        high /= 10;
    }

    //meeting on the same place value means an odd digit count: copy the middle digit once, unswapped
    if (low == high) {
        reversed += ((magnitude / low) % 10) * low;
    }

    if (reversed > INT32_MAX) {
        return 0;
    }
    const int32_t narrowed = static_cast<int32_t>(reversed);
    return isNegative ? -narrowed : narrowed;
}
//snprintf and atoll seem to be slower than the new C++ equivalents but I'm not on a C++20 compiler at the moment -m
//Digit reversal via text: print the magnitude into a stack buffer, reverse the characters in place,
//and parse the result back into a number.
//The sign is preserved; 0 is returned when the reversed magnitude does not fit in an int32_t.
static __attribute__((always_inline)) int reverseDigits_CharArrayStack(int32_t value) noexcept {
    //single-digit values are their own reversal
    if (value > -10 && value < 10) {
        return value;
    }

    const int signFactor = (value < 0) ? -1 : 1;

    //Only the wrapped product is used; the overflow flag is ignored.
    //NOTE: for INT32_MIN the product wraps back to INT32_MIN, so the text below starts with '-'.
    //      After the reversal the '-' is the last character, atoll stops parsing there, and the digits
    //      in front of it form a value above INT32_MAX - so the function returns 0, as it should.
    int magnitude;
    __builtin_smul_overflow(value, signFactor, &magnitude);

    //16 bytes comfortably hold any "%d" rendering of a 32-bit int plus the terminating NUL
    char text[16];
    const int length = snprintf(text, sizeof(text), "%d", magnitude);

    //reverse only the `length` printed characters; the NUL written by snprintf stays put for atoll
    for (int left = 0, right = length - 1; left < right; ++left, --right) {
        const char held = text[left];
        text[left] = text[right];
        text[right] = held;
    }

    const int64_t reversed = atoll(text);
    //NOTE: the result can never be INT32_MIN because its digit-reversed form is far out of range,
    //      so the off-by-one between `INT32_MIN` and `-INT32_MAX` needs no special handling
    if (reversed > INT32_MAX) {
        return 0;
    }
    return reversed * signFactor;
}
static void validate_different_outputs(int32_t value) {
const int32_t branchlessResult = modulo_branchless(value);
const int32_t basicResult = modulo_basic(value);
const int32_t manualPrintResult = manual_print(value);
const int32_t moduloLookupResult = reverseDigits_ModuloLookup(value);
const int32_t moduloMultiplyResult = reverseDigits_ModuloMultiply(value);
const int32_t charStackResult = reverseDigits_CharArrayStack(value);
print_log("[branchless ] Inverting %d = %d\n", value, branchlessResult);
print_log("[basic ] Inverting %d = %d\n", value, basicResult);
print_log("[manual print ] Inverting %d = %d\n", value, manualPrintResult);
print_log("[Modulo Lookup ] Inverting %d = %d\n", value, moduloLookupResult);
print_log("[Modulo Multiply] Inverting %d = %d\n", value, moduloMultiplyResult);
print_log("[Char Stack ] Inverting %d = %d\n", value, charStackResult);
//since we don't have an otherwise known-good reference value, just pick one of the functions under test arbitrarily
int32_t ref = basicResult;
assert(ref == branchlessResult);
assert(ref == manualPrintResult);
assert(ref == moduloLookupResult);
assert(ref == moduloMultiplyResult);
assert(ref == charStackResult);
}
//Returns the current reading of a monotonic clock, in nanoseconds since that clock's epoch.
//Only differences between two readings are meaningful. steady_clock is used rather than
//high_resolution_clock because the latter is not guaranteed to be monotonic, which could skew
//or even negate a measured interval if the system time is adjusted mid-benchmark.
static uint64_t get_nanos() {
    using namespace std::chrono;
    const steady_clock::duration sinceEpoch = steady_clock::now().time_since_epoch();
    return duration_cast<nanoseconds>(sinceEpoch).count();
}
//Returns the get_nanos() timestamp converted to seconds as a double.
static double get_time() {
    //multiply by a precomputed reciprocal instead of dividing by one billion
    const double secondsPerNano = 1 / 1'000'000'000.0;
    const uint64_t nanos = get_nanos();
    return nanos * secondsPerNano;
}
//Times the digit-reversal function `func` and logs the min, max and average duration of a run.
//Each of the 10 runs performs 1'000'000 iterations, and each iteration calls `func` three times.
//  func: the function under test, supplied as a template argument so the call can be inlined
//  name: label printed in the log line
template <int32_t (* func) (int32_t)>
static void benchmark_func(const char * name) {
    //We want to avoid linearly scanning the range, because that will provide unrealistically good branch prediction
    //but we don't want a full-on PRNG running in the core of the loop because that's not what we're benchmarking,
    //so we just keep adding a large prime in hopes that it will provide a good enough approximation of randomness.
    //TODO: A non-uniform distribution (something like logarithmic, where shorter digits strings are as common as longer ones) would be another thing to test.
    static const int32_t stride = 923489569;
    int32_t value = 0;
    int runs = 10;
    int iters = 1'000'000;
    double minDuration = 1'000'000;
    double maxDuration = 0;
    double totalDuration = 0;
    for (int run = 0; run < runs; ++run) {
        double start = get_time();
        for (int i = 0; i < iters; ++i) {
            // Reversing digits may result in a value that doesn't reverse back to the original (namely on values with trailing zeros)
            // Unless you reverse at least once before-hand (i.e. 120 reverses to 21 reverses to 12 and back to 21)
            // We use this property to both validate the function results AND provide a means to avoid optimizing away the function calls.
            int32_t result = func(value);
            int32_t doubleResult = func(result);
            int32_t tripleResult = func(doubleResult);
            // This has to be here to make use of the values and ensure they're not optimized out.
            // if (result != tripleResult) print_log("%d -> %d -> %d -> %d\n", value, result, doubleResult, tripleResult);
            assert(result == tripleResult);
            //wrapping add: only the wrapped sum is used, the overflow flag is ignored
            __builtin_sadd_overflow(value, stride, &value); //`sadd` in this case is short for "signed add", not "saturated add"
        }
        double duration = get_time() - start;
        //use the double-precision fmin/fmax: the float variants (fminf/fmaxf) would round the
        //measured durations to single precision before they are reported
        minDuration = fmin(minDuration, duration);
        maxDuration = fmax(maxDuration, duration);
        totalDuration += duration;
    }
    print_log("%20s: %d iters in %fs min, %fs max, %fs avg\n", name, iters, minDuration, maxDuration, totalDuration / runs);
}
//Entry point: cross-checks all implementations on a set of edge cases, then benchmarks each one.
//Declared as `int main()` - the previous `void main()` combined with `return 0;` is ill-formed.
int main() {
    // These serve as both validation and process warmup
    validate_different_outputs(-1'987'654'321);
    validate_different_outputs(256);
    validate_different_outputs(-256);
    validate_different_outputs(12'345);
    validate_different_outputs(25);
    validate_different_outputs(-25);
    validate_different_outputs(2);
    validate_different_outputs(-2);
    validate_different_outputs(1);
    validate_different_outputs(-1);
    validate_different_outputs(0);
    validate_different_outputs(10);
    validate_different_outputs(9);
    validate_different_outputs(1'000'000'003);
    validate_different_outputs(-1'000'000'003);
    validate_different_outputs(INT32_MIN);
    validate_different_outputs(INT32_MIN + 1);
    validate_different_outputs(INT32_MAX);
    validate_different_outputs(INT32_MAX - 1);
    validate_different_outputs(2'000'000'008);
    validate_different_outputs(-2'000'000'008);
    validate_different_outputs(1'463'847'412);
    validate_different_outputs(-1'463'847'412);
    validate_different_outputs(-1524498589); //a case that caught an assert in a previous version of branchless modulo
    print_log("DONE WITH TESTS\n");

    benchmark_func<modulo_branchless>("modulo branchless");
    benchmark_func<modulo_basic>("modulo basic");
    benchmark_func<manual_print>("manual print");
    benchmark_func<reverseDigits_ModuloLookup>("modulo lookup");
    benchmark_func<reverseDigits_ModuloMultiply>("modulo multiply");
    benchmark_func<reverseDigits_CharArrayStack>("stack print");
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment