Skip to content

Instantly share code, notes, and snippets.

@ridiculousfish
Created January 29, 2023 02:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ridiculousfish/42639d90821f35de383fa1647044fcbf to your computer and use it in GitHub Desktop.
Save ridiculousfish/42639d90821f35de383fa1647044fcbf to your computer and use it in GitHub Desktop.
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <chrono>
#include <limits>
#include <memory>
#include <random>
/*
* Perform a narrowing division: 128 / 64 -> 64, and 64 / 32 -> 32.
* The dividend's high and low words are given by \p numhi and \p numlo,
* respectively. The divisor is given by \p den. \return the quotient, and the
* remainder by reference in \p r, if not null. If the quotient would require
* more than 64 bits, or if \p den is 0, then return the max value for both
* quotient and remainder.
*
* These functions are released into the public domain, where applicable, or the
* CC0 license.
*/
// Variant of the narrowing 128 / 64 -> 64 division whose normalization step
// merges the bits shifted out of numlo into numhi behind an explicit `if`.
//
// numhi:numlo is the dividend, den the divisor. Returns the quotient and, when
// r is non-null, stores the remainder in *r. When numhi >= den (which includes
// den == 0) both the return value and *r are set to all ones.
__attribute__((noinline)) uint64_t divllu_branch(uint64_t numhi, uint64_t numlo,
                                                 uint64_t den, uint64_t *r) {
  // Digits are base 2**32: a uint32_t is one digit, a uint64_t is two.
  const uint64_t all_ones = ~0ull;

  // numhi >= den means the quotient needs more than 64 bits; since numhi is
  // unsigned this also catches den == 0.
  if (den <= numhi) {
    if (r != nullptr)
      *r = all_ones;
    return all_ones;
  }

  // Normalize: shift the divisor left until its top bit is set and shift the
  // dividend by the same amount. numhi < den, so no dividend bits are lost.
  const int lz = __builtin_clzll(den);
  const uint64_t divisor = den << lz;
  uint64_t top = numhi << lz;
  // Only pull bits across from numlo when there is an actual shift; with
  // lz == 0 the expression below would be a shift by 64.
  if (lz != 0)
    top |= numlo >> (64 - lz);
  const uint64_t bottom = numlo << lz;

  // Split the normalized divisor and the low dividend word into digits.
  const uint32_t div_hi = (uint32_t)(divisor >> 32);
  const uint32_t div_lo = (uint32_t)divisor;
  const uint32_t num_hi = (uint32_t)(bottom >> 32);
  const uint32_t num_lo = (uint32_t)bottom;

  // High quotient digit: estimate [n3 n2] / [d1], then correct the estimate
  // (by one or two) against the full divisor.
  uint64_t estimate = top / div_hi;
  uint64_t have = ((top % div_hi) << 32) + num_hi;
  uint64_t need = estimate * div_lo;
  uint64_t partial = have - need;
  if (need > have) {
    // 1 when the estimate is two too large, 0 when it is one too large.
    const uint32_t extra = (need - have > divisor) ? 1u : 0u;
    estimate -= (extra + 1u);
    partial += divisor << extra;
  }
  const uint32_t quot_hi = (uint32_t)estimate;

  // Low quotient digit: same estimate-and-correct step on the partial
  // remainder followed by the last dividend digit.
  estimate = partial / div_hi;
  have = ((partial % div_hi) << 32) + num_lo;
  need = estimate * div_lo;
  partial = have - need;
  if (need > have) {
    const uint32_t extra = (need - have > divisor) ? 1u : 0u;
    estimate -= (extra + 1u);
    partial += divisor << extra;
  }
  const uint32_t quot_lo = (uint32_t)estimate;

  // The remainder was computed on the normalized operands; undo the shift.
  if (r != nullptr)
    *r = partial >> lz;
  return ((uint64_t)quot_hi << 32) | quot_lo;
}
// Variant of the narrowing 128 / 64 -> 64 division whose normalization step is
// branch-free: the bits shifted out of numlo are merged into numhi through a
// mask instead of an `if`.
//
// numhi:numlo is the dividend, den the divisor. Returns the quotient and, when
// r is non-null, stores the remainder in *r. When numhi >= den (which includes
// den == 0) both the return value and *r are set to all ones.
__attribute__((noinline)) uint64_t divllu_mine(uint64_t numhi, uint64_t numlo,
                                               uint64_t den, uint64_t *r) {
  // Digits are base 2**32: a uint32_t is one digit, a uint64_t is two.
  const uint64_t all_ones = ~0ull;

  // numhi >= den means the quotient needs more than 64 bits; since numhi is
  // unsigned this also catches den == 0.
  if (den <= numhi) {
    if (r != nullptr)
      *r = all_ones;
    return all_ones;
  }

  // Normalize: shift the divisor left until its top bit is set and shift the
  // dividend by the same amount. numhi < den, so no dividend bits are lost.
  const int lz = __builtin_clzll(den);
  const uint64_t divisor = den << lz;
  // (-lz & 63) equals (64 - lz) for lz > 0 and is 0 for lz == 0, so the shift
  // count never reaches 64. The second operand is all ones for lz > 0 and zero
  // for lz == 0, which keeps numlo out of the high word when nothing was
  // shifted. The original source notes this sequence also sidesteps a clang 11
  // x86 codegen bug (LLVM bug 50118).
  const uint64_t top =
      (numhi << lz) | ((numlo >> (-lz & 63)) & (-(int64_t)lz >> 63));
  const uint64_t bottom = numlo << lz;

  // Split the normalized divisor and the low dividend word into digits.
  const uint32_t div_hi = (uint32_t)(divisor >> 32);
  const uint32_t div_lo = (uint32_t)divisor;
  const uint32_t num_hi = (uint32_t)(bottom >> 32);
  const uint32_t num_lo = (uint32_t)bottom;

  // High quotient digit: estimate [n3 n2] / [d1], then correct the estimate
  // (by one or two) against the full divisor.
  uint64_t estimate = top / div_hi;
  uint64_t have = ((top % div_hi) << 32) + num_hi;
  uint64_t need = estimate * div_lo;
  uint64_t partial = have - need;
  if (need > have) {
    // 1 when the estimate is two too large, 0 when it is one too large.
    const uint32_t extra = (need - have > divisor) ? 1u : 0u;
    estimate -= (extra + 1u);
    partial += divisor << extra;
  }
  const uint32_t quot_hi = (uint32_t)estimate;

  // Low quotient digit: same estimate-and-correct step on the partial
  // remainder followed by the last dividend digit.
  estimate = partial / div_hi;
  have = ((partial % div_hi) << 32) + num_lo;
  need = estimate * div_lo;
  partial = have - need;
  if (need > have) {
    const uint32_t extra = (need - have > divisor) ? 1u : 0u;
    estimate -= (extra + 1u);
    partial += divisor << extra;
  }
  const uint32_t quot_lo = (uint32_t)estimate;

  // The remainder was computed on the normalized operands; undo the shift.
  if (r != nullptr)
    *r = partial >> lz;
  return ((uint64_t)quot_hi << 32) | quot_lo;
}
// Reference 128 / 64 -> 64 bit division using goto-free refinement loops (the
// entry labelled "hackers" in the benchmark table).
//
// u1:u0 is the dividend (high word, low word), v the divisor. Returns the
// quotient and, when r is non-null, stores the remainder in *r. When u1 >= v
// (which includes v == 0) the quotient does not fit in 64 bits; both the return
// value and *r are then set to all ones.
__attribute__((noinline)) uint64_t divllu_orig(uint64_t u1, uint64_t u0,
                                               uint64_t v, uint64_t *r) {
  const uint64_t b = (1ULL << 32); // Number base (32 bits).
  uint64_t un1, un0,               // Norm. dividend LSD's.
      vn1, vn0,                    // Norm. divisor digits.
      q1, q0,                      // Quotient digits.
      un64, un21, un10,            // Dividend digit pairs.
      rhat;                        // A remainder.
  int s;                           // Shift amount for norm.

  if (u1 >= v) {                   // If overflow, set rem.
    if (r != nullptr)              // to an impossible value,
      *r = (uint64_t)(-1);         // and return the largest
    return (uint64_t)(-1);         // possible quotient.
  }

  /* count leading zeros */
  s = __builtin_clzll(v);          // 0 <= s <= 63.
  v = v << s;                      // Normalize divisor.
  vn1 = v >> 32;                   // Break divisor up into
  vn0 = v & 0xFFFFFFFF;            // two 32-bit digits.

  // Shift the dividend left by s. (-s & 63) equals (64 - s) for s > 0 and is 0
  // for s == 0; writing the count as (64 - s) would shift by 64 when s == 0,
  // which is undefined behaviour. The (-s >> 31) mask is all ones for s > 0
  // and zero for s == 0, discarding the u0 bits when nothing is shifted.
  un64 = (u1 << s) | ((u0 >> (-s & 63)) & (-s >> 31));
  un10 = u0 << s;
  un1 = un10 >> 32;                // Break right half of
  un0 = un10 & 0xFFFFFFFF;         // dividend into two digits.

  q1 = un64 / vn1;                 // Compute the first
  rhat = un64 - q1 * vn1;          // quotient digit, q1.
  // Lower the estimate while it is too large; stop once rhat no longer fits
  // in one digit.
  while (q1 >= b || q1 * vn0 > b * rhat + un1) {
    q1 = q1 - 1;
    rhat = rhat + vn1;
    if (rhat >= b)
      break;
  }

  un21 = un64 * b + un1 - q1 * v;  // Multiply and subtract.

  q0 = un21 / vn1;                 // Compute the second
  rhat = un21 - q0 * vn1;          // quotient digit, q0.
  while (q0 >= b || q0 * vn0 > b * rhat + un0) {
    q0 = q0 - 1;
    rhat = rhat + vn1;
    if (rhat >= b)
      break;
  }

  if (r != nullptr)                        // If remainder is wanted,
    *r = (un21 * b + un0 - q0 * v) >> s;   // return it.
  return q1 * b + q0;
}
#if defined(__x86_64__)
// 128 / 64 -> 64 bit division using the x86-64 `divq` instruction directly.
//
// numhi:numlo is the dividend, den the divisor. Returns the quotient and, when
// r is non-null, stores the remainder in *r.
//
// Unlike the software dividers in this file there is no numhi >= den guard
// here: the caller must guarantee numhi < den.
__attribute__((noinline)) uint64_t divllu_asm(uint64_t numhi, uint64_t numlo,
                                              uint64_t den, uint64_t *r) {
  uint64_t quotient;
  uint64_t remainder;
  // divq divides rdx:rax by the operand, leaving the quotient in rax and the
  // remainder in rdx. The remainder goes to a local so that a null r is never
  // dereferenced.
  __asm__("divq %[v]"
          : "=a"(quotient), "=d"(remainder)
          : [v] "r"(den), "a"(numlo), "d"(numhi)
          : "cc");
  if (r != nullptr)
    *r = remainder;
  return quotient;
}
#endif
// 128 / 64 -> 64 bit division using the compiler's __uint128_t arithmetic.
//
// numhi:numlo is the dividend, den the divisor. Returns the quotient and, when
// r is non-null, stores the remainder in *r. When numhi >= den (which includes
// den == 0) the quotient does not fit in 64 bits; both the return value and *r
// are then set to all ones, matching the other software dividers in this file.
__attribute__((noinline)) uint64_t divllu_nat(uint64_t numhi, uint64_t numlo,
                                              uint64_t den, uint64_t *r) {
  // Guard first: without it den == 0 divides by zero and an oversized quotient
  // is silently truncated by the cast below.
  if (numhi >= den) {
    if (r != nullptr)
      *r = ~0ull;
    return ~0ull;
  }
  const __uint128_t num = ((__uint128_t)numhi << 64) | numlo;
  const uint64_t quot = (uint64_t)(num / den);
  const uint64_t rem = (uint64_t)(num % den);
  if (r != nullptr)
    *r = rem;
  return quot;
}
// Number of entries in the case table (2^16).
constexpr size_t CASE_COUNT = static_cast<size_t>(1) << 16;
// Iteration count for the benchmark (2^11).
constexpr size_t ITER_COUNT = static_cast<size_t>(1) << 11;
// One division input: a 128-bit numerator split into two 64-bit words, and a
// 64-bit denominator.
struct case_t {
  uint64_t numerhi; // High word of the numerator.
  uint64_t numerlo; // Low word of the numerator.
  uint64_t denom;   // Denominator.

  // True when the denominator is nonzero and strictly greater than the
  // numerator's high word.
  bool valid() const {
    if (denom == 0)
      return false;
    return denom > numerhi;
  }
};
// Builds a table of CASE_COUNT pseudo-random division cases. The generator is
// seeded with its default seed, so every run produces the same table. Each
// slot is redrawn until case_t::valid() accepts it. The array is allocated
// with new[] and returned to the caller.
static const case_t *make_cases() {
  std::mt19937_64 rng(std::mt19937_64::default_seed);
  case_t *const table = new case_t[CASE_COUNT];
  for (case_t *slot = table; slot != table + CASE_COUNT; ++slot) {
    for (;;) {
      // Draw order (low word, high word, denominator) fixes which generator
      // outputs land in which field.
      slot->numerlo = rng();
      slot->numerhi = rng();
      slot->denom = rng();
      if (slot->valid())
        break;
    }
  }
  return table;
}
// Pointer to a 128 / 64 -> 64 bit divider: takes the numerator's high and low
// words and the denominator, returns the quotient, and receives a pointer
// through which the remainder is reported.
using divider_t = uint64_t (*)(uint64_t numhi, uint64_t numlo, uint64_t den,
uint64_t *r);
__attribute__((noinline)) static uint64_t time_function(const case_t *cases,
divider_t div) {
uint64_t sum = 0;
for (size_t i = 0; i < CASE_COUNT; i++) {
uint64_t quot;
uint64_t rem;
quot = div(cases[i].numerhi, cases[i].numerlo, cases[i].denom, &rem);
sum += quot;
sum += rem;
}
return sum;
}
int main(int argc, char *argv[]) {
int onlycase = 0;
if (argv[1])
onlycase = atoi(argv[1]);
const struct {
const char *name;
divider_t func;
} dividers[] = {
{"hackers", divllu_orig},
{"libdiv org", divllu_mine},
{"libdiv brn", divllu_branch},
{"libdiv nat", divllu_nat},
#if defined(__x86_64__)
{"divq", divllu_asm},
#endif
};
const case_t *cases = make_cases();
using namespace std::chrono;
const uint64_t exp = time_function(cases, divllu_orig);
int whichcase = 0;
for (const auto &d : dividers) {
whichcase++;
if (onlycase && whichcase != onlycase)
continue;
uint64_t best = std::numeric_limits<uint64_t>::max();
for (size_t i = 0; i < ITER_COUNT; i++) {
using namespace std::chrono;
auto t1 = high_resolution_clock::now();
uint64_t res = time_function(cases, d.func);
auto t2 = high_resolution_clock::now();
uint64_t nanos = duration_cast<std::chrono::nanoseconds>(t2 - t1).count();
best = std::min(best, nanos);
if (res != exp)
abort();
}
double nsec = (double)best / (double)CASE_COUNT;
printf("%18s\t%4.4f\n", d.name, nsec);
}
return 0;
}
# divllu benchmark of normalization step
# Lower is better.
# org: original branchfree
# brn: uses if statement
# nat: compiler-generated uint128 divide
# divq: hardware 128 bit divide
# M1 Max
> clang++ -std=c++14 -O3 divllu_benchmark.cpp ; ./a.out
hackers 11.6113
libdiv org 8.9404
libdiv brn 8.9823
libdiv nat 13.5206
# Ryzen 9 5900X, clang 14
> clang++ -std=c++14 -O3 divlu_benchmark.cpp; ./a.out
hackers 10.4827
libdiv org 8.8335
libdiv brn 9.0491
libdiv nat 3.6855
divq 2.4680
# Ryzen 9 5900X, g++ 11.3
> g++ -std=c++14 -O3 divlu_benchmark.cpp; ./a.out
hackers 14.1364
libdiv org 8.9577
libdiv brn 9.6478
libdiv nat 3.7431
divq 2.4856
# Intel i7-8700, g++ 12.2
> g++ -std=c++14 -O3 divlu_benchmark.cpp; ./a.out
hackers 22.8676
libdiv org 21.3373
libdiv brn 21.9065
libdiv nat 21.4691
divq 19.2605
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment