Skip to content

Instantly share code, notes, and snippets.

@ridiculousfish
Last active September 15, 2021 21:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ridiculousfish/4b0bd18abce4347ce2b32c2921f9ab22 to your computer and use it in GitHub Desktop.
Save ridiculousfish/4b0bd18abce4347ce2b32c2921f9ab22 to your computer and use it in GitHub Desktop.
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <chrono>
#include <limits>
#include <random>
// Divides the 128-bit value {numhi, numlo} by den without data-dependent branches in the
// quotient-correction steps.
//
// Returns the 64-bit quotient. When r is non-null the remainder is stored through it. If
// numhi >= den (the quotient would not fit in 64 bits, or den is 0) the function returns all ones
// and, when r is non-null, stores all ones as the remainder.
//
// The computation is long division in base 2**32: the normalized numerator is treated as four
// 32-bit digits and the normalized denominator as two, producing two 32-bit quotient digits.
__attribute__((noinline)) uint64_t divllu_v2_nobranch(
    uint64_t numhi, uint64_t numlo, uint64_t den, uint64_t *r) {
    // Overflow and divide-by-zero share one sentinel result.
    if (den <= numhi) {
        if (r) *r = ~0ull;
        return ~0ull;
    }

    // One digit is 32 bits wide.
    const uint64_t base = 1ull << 32;

    // Normalize: shift the denominator left until its top bit is set, and shift the numerator by
    // the same amount. Because numhi < den, nothing is shifted out of numhi.
    // (-lz & 63) stands in for (64 - lz) without ever shifting by 64, and the arithmetic-shift
    // mask is zero exactly when lz is zero so that no bits of numlo leak into numhi in that case.
    // This particular sequence also sidesteps the clang 11 x86 codegen bug (LLVM bug 50118).
    const int lz = __builtin_clzll(den);
    den <<= lz;
    numhi <<= lz;
    numhi |= (numlo >> (-lz & 63)) & (-(int64_t)lz >> 63);
    numlo <<= lz;

    // Split the denominator and the low half of the numerator into 32-bit digits.
    const uint32_t d_hi = (uint32_t)(den >> 32);
    const uint32_t d_lo = (uint32_t)den;
    const uint32_t n_1 = (uint32_t)(numlo >> 32);
    const uint32_t n_0 = (uint32_t)numlo;

    // ---- High quotient digit: [n3 n2 n1] / [d_hi d_lo] ----
    // Estimate with [n3 n2] / d_hi. The estimate may need two digits; the corrected value needs
    // only one.
    uint64_t est = numhi / d_hi;
    uint64_t have = (numhi % d_hi) * base + n_1;
    uint64_t need = est * d_lo;
    uint64_t part = have - need;
    // All ones when the subtraction wrapped, i.e. need > have and the estimate is too large.
    const uint64_t over_hi = 0 - (uint64_t)(part > have);
    // 1 when the overshoot (need - have) exceeds den, so the estimate is too large by two.
    // This is kept as 0/1 rather than a mask because it doubles as a shift count below.
    const uint64_t twice_hi = over_hi & (uint64_t)(0 - part > den);
    est += over_hi;   // subtracts 1 when over_hi is all ones
    est -= twice_hi;  // subtracts one more when required
    part += (den & over_hi) << twice_hi;  // adds back den or 2*den (mod 2**64)
    const uint32_t q_hi = (uint32_t)est;

    // ---- Low quotient digit: [part1 part0 n0] / [d_hi d_lo] ----
    est = part / d_hi;
    have = (part % d_hi) * base + n_0;
    need = est * d_lo;
    part = have - need;
    const uint64_t over_lo = 0 - (uint64_t)(part > have);
    // Here the second correction is a full mask (0 or all ones), not 0/1.
    const uint64_t twice_lo = over_lo & (0 - (uint64_t)(0 - part > den));
    est += over_lo + twice_lo;  // subtracts 0, 1 or 2
    const uint32_t q_lo = (uint32_t)est;

    // The remainder is only corrected and denormalized when the caller asked for it.
    if (r) {
        part += (den & over_lo) + (den & twice_lo);
        *r = part >> lz;
    }
    return ((uint64_t)q_hi << 32) | q_lo;
}
// Divides the 128-bit value {numhi, numlo} by den, correcting each estimated quotient digit with
// an ordinary branch and fixing the partial remainder up by addition rather than recomputing it.
//
// Returns the 64-bit quotient. When r is non-null the remainder is stored through it. If
// numhi >= den (the quotient would not fit in 64 bits, or den is 0) the function returns all ones
// and, when r is non-null, stores all ones as the remainder.
//
// The computation is long division in base 2**32: the normalized numerator is treated as four
// 32-bit digits and the normalized denominator as two, producing two 32-bit quotient digits.
__attribute__((noinline)) uint64_t divllu_v2_branch(
    uint64_t numhi, uint64_t numlo, uint64_t den, uint64_t *r) {
    // Overflow and divide-by-zero share one sentinel result.
    if (den <= numhi) {
        if (r) *r = ~0ull;
        return ~0ull;
    }

    // One digit is 32 bits wide.
    const uint64_t base = 1ull << 32;

    // Normalize: shift the denominator left until its top bit is set, and shift the numerator by
    // the same amount. Because numhi < den, nothing is shifted out of numhi.
    // (-lz & 63) stands in for (64 - lz) without ever shifting by 64, and the arithmetic-shift
    // mask is zero exactly when lz is zero so that no bits of numlo leak into numhi in that case.
    // This particular sequence also sidesteps the clang 11 x86 codegen bug (LLVM bug 50118).
    const int lz = __builtin_clzll(den);
    den <<= lz;
    numhi <<= lz;
    numhi |= (numlo >> (-lz & 63)) & (-(int64_t)lz >> 63);
    numlo <<= lz;

    // Split the denominator and the low half of the numerator into 32-bit digits.
    const uint32_t d_hi = (uint32_t)(den >> 32);
    const uint32_t d_lo = (uint32_t)den;
    const uint32_t n_1 = (uint32_t)(numlo >> 32);
    const uint32_t n_0 = (uint32_t)numlo;

    // ---- High quotient digit: [n3 n2 n1] / [d_hi d_lo] ----
    // Estimate with [n3 n2] / d_hi. The estimate may need two digits; the corrected value needs
    // only one.
    uint64_t est = numhi / d_hi;
    uint64_t have = (numhi % d_hi) * base + n_1;
    uint64_t need = est * d_lo;
    uint64_t part = have - need;
    if (need > have) {
        // The estimate is too large by one, or by two when the overshoot exceeds den.
        const uint32_t extra = (need - have > den) ? 1u : 0u;
        est -= extra + 1;
        part += den << extra;  // adds back den or 2*den (mod 2**64)
    }
    const uint32_t q_hi = (uint32_t)est;

    // ---- Low quotient digit: [part1 part0 n0] / [d_hi d_lo] ----
    est = part / d_hi;
    have = (part % d_hi) * base + n_0;
    need = est * d_lo;
    part = have - need;
    if (need > have) {
        const uint32_t extra = (need - have > den) ? 1u : 0u;
        est -= extra + 1;
        part += den << extra;
    }
    const uint32_t q_lo = (uint32_t)est;

    // Undo the normalization shift to obtain the remainder of the original operands.
    if (r) *r = part >> lz;
    return ((uint64_t)q_hi << 32) | q_lo;
}
// Divides the 128-bit value {numhi, numlo} by den, recomputing each partial remainder with a
// multiply-and-subtract once the quotient digit has been corrected.
//
// Returns the 64-bit quotient. When r is non-null the remainder is stored through it. If
// numhi >= den (the quotient would not fit in 64 bits, or den is 0) the function returns all ones
// and, when r is non-null, stores all ones as the remainder.
//
// The computation is long division in base 2**32: the normalized numerator is treated as four
// 32-bit digits and the normalized denominator as two, producing two 32-bit quotient digits.
__attribute__((noinline)) uint64_t divllu_mul(
    uint64_t numhi, uint64_t numlo, uint64_t den, uint64_t *r) {
    // Overflow and divide-by-zero share one sentinel result.
    if (den <= numhi) {
        if (r) *r = ~0ull;
        return ~0ull;
    }

    // One digit is 32 bits wide.
    const uint64_t base = 1ull << 32;

    // Normalize: shift the denominator left until its top bit is set, and shift the numerator by
    // the same amount. Because numhi < den, nothing is shifted out of numhi.
    // (-lz & 63) stands in for (64 - lz) without ever shifting by 64, and the arithmetic-shift
    // mask is zero exactly when lz is zero so that no bits of numlo leak into numhi in that case.
    // This particular sequence also sidesteps the clang 11 x86 codegen bug (LLVM bug 50118).
    const int lz = __builtin_clzll(den);
    den <<= lz;
    numhi <<= lz;
    numhi |= (numlo >> (-lz & 63)) & (-(int64_t)lz >> 63);
    numlo <<= lz;

    // Split the denominator and the low half of the numerator into 32-bit digits.
    const uint32_t d_hi = (uint32_t)(den >> 32);
    const uint32_t d_lo = (uint32_t)den;
    const uint32_t n_1 = (uint32_t)(numlo >> 32);
    const uint32_t n_0 = (uint32_t)numlo;

    // ---- High quotient digit: [n3 n2 n1] / [d_hi d_lo] ----
    // Estimate with [n3 n2] / d_hi. The estimate may need two digits; the corrected value needs
    // only one.
    uint64_t est = numhi / d_hi;
    uint64_t need = est * d_lo;
    uint64_t have = (numhi % d_hi) * base + n_1;
    if (need > have) {
        // Too large by one, or by two when the overshoot exceeds den.
        est -= 1;
        if (need - have > den) est -= 1;
    }
    const uint32_t q_hi = (uint32_t)est;

    // Exact partial remainder after the high digit (arithmetic is mod 2**64).
    const uint64_t part = numhi * base + n_1 - q_hi * den;

    // ---- Low quotient digit: [part1 part0 n0] / [d_hi d_lo] ----
    est = part / d_hi;
    need = est * d_lo;
    have = (part % d_hi) * base + n_0;
    if (need > have) {
        est -= 1;
        if (need - have > den) est -= 1;
    }
    const uint32_t q_lo = (uint32_t)est;

    // Final multiply-and-subtract, then undo the normalization shift.
    if (r) *r = (part * base + n_0 - q_lo * den) >> lz;
    return ((uint64_t)q_hi << 32) | q_lo;
}
// Reference 128-by-64-bit unsigned division; the benchmark uses its results as the expected
// checksum and lists it as "hackers" (presumably after the divlu routine in Hacker's Delight).
//
// Divides {u1, u0} by v and returns the 64-bit quotient. When r is non-null the remainder is
// stored through it. If u1 >= v (the quotient would not fit in 64 bits, or v is 0) the function
// returns all ones and, when r is non-null, stores all ones as an impossible remainder.
__attribute__((noinline)) uint64_t divlu_orig(uint64_t u1, uint64_t u0, uint64_t v, uint64_t *r) {
    const uint64_t b = (1ULL << 32);  // Number base: each digit is 32 bits.

    if (u1 >= v) {                               // Overflow or divide by zero:
        if (r != nullptr) *r = (uint64_t)(-1);   // set rem to an impossible value
        return (uint64_t)(-1);                   // and return the largest possible quotient.
    }

    // Count leading zeros; v is nonzero here because u1 < v, so 0 <= s <= 63.
    const int s = __builtin_clzll(v);
    v = v << s;                           // Normalize divisor.
    const uint64_t vn1 = v >> 32;         // Break divisor up into
    const uint64_t vn0 = v & 0xFFFFFFFF;  // two 32-bit digits.

    // Shift the dividend left by s. The bits moving from u0 into the high word are u0 >> (64 - s),
    // but when s == 0 that would be a shift by the full width, which is undefined behaviour.
    // (-s & 63) equals 64 - s for s in 1..63 and 0 for s == 0, and the mask (all ones unless
    // s == 0) discards the unwanted bits in the s == 0 case.
    const uint64_t un64 = (u1 << s) | ((u0 >> (-s & 63)) & (uint64_t)(-(int64_t)s >> 63));
    const uint64_t un10 = u0 << s;
    const uint64_t un1 = un10 >> 32;         // Break right half of
    const uint64_t un0 = un10 & 0xFFFFFFFF;  // dividend into two digits.

    // First quotient digit: estimate from the top digits, then decrement while it is provably
    // too large. The loop stops as soon as rhat no longer fits in one digit, at which point the
    // comparison can no longer succeed.
    uint64_t q1 = un64 / vn1;
    uint64_t rhat = un64 - q1 * vn1;
    while (q1 >= b || q1 * vn0 > b * rhat + un1) {
        q1 = q1 - 1;
        rhat = rhat + vn1;
        if (rhat >= b) break;
    }

    const uint64_t un21 = un64 * b + un1 - q1 * v;  // Multiply and subtract.

    // Second quotient digit, same procedure.
    uint64_t q0 = un21 / vn1;
    rhat = un21 - q0 * vn1;
    while (q0 >= b || q0 * vn0 > b * rhat + un0) {
        q0 = q0 - 1;
        rhat = rhat + vn1;
        if (rhat >= b) break;
    }

    if (r != nullptr)                         // If remainder is wanted,
        *r = (un21 * b + un0 - q0 * v) >> s;  // denormalize and return it.
    return q1 * b + q0;
}
#if defined(__x86_64__)
// Hardware baseline: divides {numhi, numlo} by den with a single x86-64 DIVQ instruction.
//
// Returns the quotient. When r is non-null the remainder is stored through it. Like the software
// dividers in this file, numhi >= den (quotient overflow or den == 0) yields all ones for both
// the quotient and the remainder; without this guard DIVQ would raise a divide-error exception.
__attribute__((noinline)) uint64_t divllu_asm(
    uint64_t numhi, uint64_t numlo, uint64_t den, uint64_t *r) {
    if (numhi >= den) {
        if (r != nullptr) *r = ~0ull;
        return ~0ull;
    }
    uint64_t quot;
    // The remainder lands in a local so that a null r is never dereferenced.
    uint64_t rem;
    // DIVQ takes the dividend in RDX:RAX and leaves the quotient in RAX and the remainder in RDX.
    __asm__("divq %[v]"
            : "=a"(quot), "=d"(rem)
            : [v] "r"(den), "a"(numlo), "d"(numhi)
            : "cc");
    if (r != nullptr) *r = rem;
    return quot;
}
#endif
// Number of division problems in the generated test set (2**14); also the number of divisions
// performed by one timed pass.
constexpr size_t CASE_COUNT = size_t{1} << 14;
// Number of timed passes run per divider (2**10); the fastest pass is the one reported.
constexpr size_t ITER_COUNT = size_t{1} << 10;
// One division problem: the 128-bit numerator {numerhi, numerlo} and the 64-bit denominator.
struct case_t {
    uint64_t numerhi;  // high 64 bits of the numerator
    uint64_t numerlo;  // low 64 bits of the numerator
    uint64_t denom;    // denominator

    // True when the problem can be handed to a divider without hitting its overflow path:
    // the denominator is nonzero and strictly greater than the high numerator word.
    bool valid() const {
        if (denom == 0) return false;
        return denom > numerhi;
    }
};
// Builds the benchmark input: CASE_COUNT valid division problems drawn from a Mersenne Twister
// seeded with its default seed, so every run sees the same sequence.
//
// The array is allocated with new[]; the caller owns it.
static const case_t *make_cases() {
    std::mt19937_64 rng(std::mt19937_64::default_seed);
    case_t *const table = new case_t[CASE_COUNT];
    for (case_t *slot = table; slot != table + CASE_COUNT; ++slot) {
        // Rejection sampling: redraw all three words until the problem is valid.
        for (;;) {
            slot->numerlo = rng();
            slot->numerhi = rng();
            slot->denom = rng();
            if (slot->valid()) break;
        }
    }
    return table;
}
// Common signature of every divider under test: divide {numhi, numlo} by den, return the
// quotient and write the remainder through r.
using divider_t = auto (*)(uint64_t numhi, uint64_t numlo, uint64_t den, uint64_t *r) -> uint64_t;
__attribute__((noinline)) static uint64_t time_function(const case_t *cases, divider_t div) {
uint64_t sum = 0;
for (size_t i = 0; i < CASE_COUNT; i++) {
uint64_t quot;
uint64_t rem;
quot = div(cases[i].numerhi, cases[i].numerlo, cases[i].denom, &rem);
sum += quot;
sum += rem;
}
return sum;
}
int main(int argc, char *argv[]) {
int onlycase = 0;
if (argv[1]) onlycase = atoi(argv[1]);
const struct {
const char *name;
divider_t func;
} dividers[] = {
{"hackers", divlu_orig},
{"libdiv mul", divllu_mul},
{"libdiv sub", divllu_v2_branch},
{"libdiv sub bf", divllu_v2_branch},
#if defined(__x86_64__)
{"divq", divllu_asm},
#endif
};
const case_t *cases = make_cases();
using namespace std::chrono;
const uint64_t exp = time_function(cases, divlu_orig);
int whichcase = 0;
for (const auto &d : dividers) {
whichcase++;
if (onlycase && whichcase != onlycase) continue;
uint64_t best = std::numeric_limits<uint64_t>::max();
for (size_t i = 0; i < ITER_COUNT; i++) {
using namespace std::chrono;
auto t1 = high_resolution_clock::now();
uint64_t res = time_function(cases, d.func);
auto t2 = high_resolution_clock::now();
uint64_t nanos = duration_cast<std::chrono::nanoseconds>(t2 - t1).count();
best = std::min(best, nanos);
if (res != exp) abort();
}
double nsec = (double)best / (double)CASE_COUNT;
printf("%18s\t%4.1f\n", d.name, nsec);
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment