@mdickinson
Last active March 10, 2024 19:24
Fast square root of a 64-bit square number
// Exact square root of a known square integer
// ===========================================
// This snippet contains a function `isqrt64_exact` with signature
//
//     uint32_t isqrt64_exact(uint64_t n);
//
// `isqrt64_exact` computes the 32-bit square root of its unsigned 64-bit
// integer input, under the assumption that that input is a perfect square.
//
// Compile under gcc or clang with:
//
//     gcc -std=c99 -Wall -O3 -march=native sqrt_exact.c -o sqrt_exact
//
// The code includes a `main` function that performs exhaustive testing on all
// 2^32 square numbers smaller than 2^64. In the absence of bugs, executing
// `sqrt_exact` should run the exhaustive test (expect it to take a few tens of
// seconds to complete) and print a single line saying "Success!".
//
// For non-perfect-square inputs, the result is deterministic and well-defined,
// but probably not very useful for typical applications.
//
// The algorithm used is division-free and (aside from an initial early return
// for an input of 0) branch-free, using only simple arithmetic and bitwise
// operations. As a result, it's fast. The exhaustive testing loop completes in
// around 14.2 seconds on my Intel Core i7-8559U laptop, which works out at
// ~3.3 ns per square root call, or about 14.9 clock cycles per square root
// at 4.5 GHz.
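//
// The per-call figures above are plain arithmetic on the measured total. The
// exhaustive test makes 2^32 calls, and at 4.5 GHz one nanosecond is 4.5
// cycles. The hypothetical helpers below only reproduce that arithmetic from
// the quoted 14.2 s measurement. They are not a benchmark.
```c
// Hypothetical helpers reproducing the timing arithmetic above.
// The exhaustive test makes 2^32 calls, so time per call = total / 2^32.
static double ns_per_call(double total_seconds)
{
    return total_seconds / 4294967296.0 * 1e9;   // 4294967296 = 2^32
}

// At a clock rate of `ghz` GHz, one nanosecond is `ghz` cycles.
static double cycles_per_call(double total_seconds, double ghz)
{
    return ns_per_call(total_seconds) * ghz;
}
```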
//
// Portability notes:
//
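// - The code uses a GCC builtin (also available in Clang) to count trailing
// zeros. Because the argument is a 64-bit value, this needs to be
// __builtin_ctzll. __builtin_ctzl takes an unsigned long, which is only 32
// bits wide on some platforms, including Windows and 32-bit Linux. For MSVC,
// _BitScanForward64 (available on 64-bit targets) can be used instead. For
// more portable options, see for example
// https://graphics.stanford.edu/~seander/bithacks.html#ZerosOnRightLinear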
// - The code is technically not correct on platforms whose `int` type is
// wider than 32 bits. Such platforms are rare in practice. On them, however,
// `uint32_t` operands in arithmetic operations are promoted to (signed)
// `int`, so the operations are performed in type `int`, and signed overflow
// gives undefined behaviour. On such platforms, use `unsigned int` instead of
// `uint32_t` for intermediate results. Also reduce modulo 2^32 (for example,
// by masking with 0xFFFFFFFF) wherever the code relies on 32-bit wraparound,
// such as in the `b ^= -(b >> 31);` step.
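//
// One possible portable fallback for the trailing-zero count is sketched
// below. It uses a binary search over bit positions, in the spirit of the
// techniques on the page linked above, rather than the linear scan. It is an
// illustration, not the author's code, and the name `ctz64_portable` is
// hypothetical. Like the builtins, it should not be called with 0 (it returns
// a meaningless 63 there). isqrt64_exact returns early for that input anyway.
```c
#include <stdint.h>

// Count trailing zero bits of a nonzero 64-bit value by binary search.
// Each test checks whether the low half of the remaining window is all zero.
// If it is, the value is shifted past that half.
static int ctz64_portable(uint64_t v)   // precondition: v != 0
{
    int c = 0;
    if ((v & 0xFFFFFFFFu) == 0) { v >>= 32; c += 32; }
    if ((v & 0xFFFFu) == 0)     { v >>= 16; c += 16; }
    if ((v & 0xFFu) == 0)       { v >>= 8;  c += 8; }
    if ((v & 0xFu) == 0)        { v >>= 4;  c += 4; }
    if ((v & 0x3u) == 0)        { v >>= 2;  c += 2; }
    if ((v & 0x1u) == 0)        { c += 1; }
    return c;
}
```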
//
// For purists: it's not hard to adapt the isqrt64_exact code to be
// *completely* branch free, completely portable, and to eliminate the lookup
// table entirely. Not surprisingly, the result is slower, at least on my
// machine.
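//
// The sketch below illustrates only the lookup-table point. It is not the
// author's purist version. It drops the table by starting from x = 0, which
// the background section below shows is always a solution modulo 2 for odd
// squares. Five Newton steps then take that solution from modulo 2 to modulo
// 2^32. It still relies on __builtin_ctzll and on the early return for 0. The
// names `isqrt64_exact_nolut` and `is_perfect_square` are hypothetical. The
// second function shows one way to get a useful answer for arbitrary inputs:
// square the result and compare.
```c
#include <stdbool.h>
#include <stdint.h>

// Lookup-table-free sketch. It matches isqrt64_exact except that the initial
// guess is x = 0 (a solution modulo 2). Five Newton steps are then needed:
// modulo 2^2, 2^4, 2^8, 2^16 and finally 2^32.
uint32_t isqrt64_exact_nolut(uint64_t n)
{
    if (n == 0)
        return 0;
    int j = __builtin_ctzll(n);        // trailing zeros (even for a square)
    n >>= j;                           // odd part of n
    uint32_t m = (uint32_t)n;          // odd part modulo 2^32
    uint32_t k = (uint32_t)(n >> 2);   // (odd part - 1) / 4, modulo 2^32
    uint32_t x = 0;                    // solution modulo 2
    for (int i = 0; i < 5; i++)
        x += (m * x * ~x - k) * (x - ~x);
    uint32_t b = m * x + 2 * k;        // b^2 + b == k (mod 2^32)
    b ^= -(b >> 31);                   // pick the solution with b < 2^31
    return (b - ~b) << (j >> 1);       // a = 2b + 1, shifted back up
}

// Exact perfect-square test. A square's root comes back exactly, and for a
// non-square n no 32-bit value can square back to n.
bool is_perfect_square(uint64_t n)
{
    uint32_t r = isqrt64_exact_nolut(n);
    return (uint64_t)r * r == n;
}
```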
//
// Description of algorithm: mathematical background
// =================================================
// In a nutshell, we create successively better 2-adic approximations to the
// inverse square root of the odd part of the input, using a Newton–Raphson
// iteration. We then turn that into a sufficiently good 2-adic approximation
// to the square root of the odd part of the input.
//
// In more detail: for a fixed real number n, define a function f on all real
// x ≠ -1/2 by
//
// f(x) = n - (1 / (2x + 1)^2)
//
// If n is positive then f has two roots, namely (±1/√n - 1) / 2. The
// Newton–Raphson method applied to f gives us an iteration
//
// (1) x ↦ x - ((x^2 + x)n + (n - 1)/4) (2x + 1).
//
// The term (n - 1) / 4 will appear frequently in what follows, so for
// convenience we define k = (n - 1) / 4, so that the iteration (1) can be
// rewritten as
//
// (1) x ↦ x - ((x^2 + x)n + k) (2x + 1).
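//
// The sketch below shows iteration (1) in ordinary floating-point arithmetic.
// It applies the iteration with n = 2, starting from x = 0, after which
// 2x + 1 should approximate 1/√2. The function names, the choice n = 2 and
// the starting point are illustrative assumptions. As usual for
// Newton–Raphson over the reals, convergence depends on starting close
// enough to a root.
```c
// One step of iteration (1) over the reals, with k = (n - 1) / 4.
static double newton_step_real(double x, double n)
{
    double k = (n - 1.0) / 4.0;
    return x - ((x * x + x) * n + k) * (2.0 * x + 1.0);
}

// Apply the step `steps` times from x0. Near the root with 2x + 1 > 0,
// 2x + 1 approximates 1/sqrt(n).
static double inv_sqrt_via_newton(double n, double x0, int steps)
{
    double x = x0;
    for (int i = 0; i < steps; i++)
        x = newton_step_real(x, n);
    return 2.0 * x + 1.0;
}
```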
//
// As usual, we expect this Newton–Raphson iteration to give us quadratic
// convergence near the roots. While the usual proof of quadratic convergence
// relies on calculus, for this particular f the quadratic convergence can also
// be shown purely algebraically. First note that f(x) = 0 is equivalent to the
// condition (x^2 + x)n + k = 0, so x is close to a root of f when (x^2 + x)n +
// k is small. Put
//
// y = x - ((x^2 + x)n + k) (2x + 1).
//
// Then, writing n = 4k + 1 and expanding both sides as polynomials in x and
// k, it's easy to verify that:
//
// (2) (y^2 + y)n + k = ((x^2 + x)n + k)^2 ((2x+1)^2 n - 4)
//
// So if (x^2 + x)n + k is sufficiently small, then (y^2 + y)n + k is smaller
// still, thanks to the squared term. (The extra term ((2x+1)^2 n - 4)
// converges to the constant value -3 as the iteration progresses.)
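//
// Identity (2) is also easy to spot-check by machine. The sketch below
// evaluates both sides exactly in 64-bit integer arithmetic, with n = 4k + 1,
// over a small grid of integers x and k. The grid is kept small enough that
// nothing overflows. The helper name is hypothetical.
```c
#include <stdbool.h>

// Evaluate both sides of identity (2) for integers x and k, with n = 4k + 1.
// Intended for small |x| and |k| (say at most 20), so that no product
// overflows a long long.
static bool identity2_holds(long long x, long long k)
{
    long long n = 4 * k + 1;
    long long e = (x * x + x) * n + k;                  // (x^2 + x)n + k
    long long y = x - e * (2 * x + 1);                  // iteration (1)
    long long lhs = (y * y + y) * n + k;                // (y^2 + y)n + k
    long long rhs = e * e * ((2 * x + 1) * (2 * x + 1) * n - 4);
    return lhs == rhs;
}
```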
//
// Now we switch domains. The key observation is that the iteration (1), and
// the quadratic convergence arising from (2), are just as valid in the 2-adic
// integers as in the real numbers. Suppose that n is an integer congruent to 1
// modulo 4, that k = (n - 1) / 4 as before, and that x is an integer
// satisfying
//
// (3) (x^2 + x)n + k ≡ 0 (mod 2^j)
//
// for some j >= 0. Set y = x - ((x^2 + x)n + k) (2x + 1) as before. Then
// from (2) we have
//
// (4) (y^2 + y)n + k ≡ 0 (mod 2^(2j)).
//
// So if we have a solution to the congruence (x^2 + x)n + k ≡ 0 (mod 2^8)
// (for example), a single application of the iteration (1) gets us a solution
// modulo 2^16, and a second application gives us a solution modulo 2^32. To
// start the iteration, note that if n ≡ 1 (mod 8) then k is even. Since
// x^2 + x is always even, _every_ integer x is then a solution to
// (x^2 + x)n + k ≡ 0 (mod 2). In particular, we can always use x = 0 as a
// starting guess, and three iterations then give a solution modulo 2^8. For
// a practical implementation, though, it will typically be faster to replace
// those initial iterations with a lookup table, and that's what we do in the
// code below. Since n = 4k + 1, whether x solves (x^2 + x)n + k ≡ 0
// (mod 2^8) depends only on the least significant eight bits of k. Because k
// is even, only bits 1 to 7 of k vary, so the table has 128 entries, indexed
// by `k >> 1 & 127`.
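//
// The doubling can be observed directly in wrapping 32-bit arithmetic. The
// sketch below, whose helper names are hypothetical, starts from x = 0 for an
// odd square n. As in the code at the end of this file, m and k hold n and
// (n - 1) / 4 reduced modulo 2^32. The sketch checks that after i steps the
// residual (x^2 + x)n + k has at least 2^i trailing zero bits.
```c
#include <stdbool.h>
#include <stdint.h>

// (x^2 + x)n + k modulo 2^32, where m = n mod 2^32 and k = (n - 1)/4 mod 2^32.
static uint32_t residual32(uint32_t m, uint32_t k, uint32_t x)
{
    return m * x * x + m * x + k;
}

// One step of iteration (1) in wrapping 32-bit arithmetic.
static uint32_t newton_step32(uint32_t m, uint32_t k, uint32_t x)
{
    return x - residual32(m, k, x) * (2 * x + 1);
}

// For an odd square n, start from x = 0 (a solution modulo 2) and confirm
// that after i steps (i = 1..5) the residual is a multiple of 2^(2^i).
static bool precision_doubles(uint64_t n)
{
    uint32_t m = (uint32_t)n, k = (uint32_t)(n >> 2), x = 0;
    if (residual32(m, k, x) & 1u)
        return false;
    for (int i = 1; i <= 5; i++) {
        x = newton_step32(m, k, x);
        uint64_t mask = (UINT64_C(1) << (1 << i)) - 1;   // low 2^i bits
        if (residual32(m, k, x) & mask)
            return false;
    }
    return true;
}
```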
//
// Let's take stock. At this point we have an efficient division-free algorithm
// for finding solutions to (x^2 + x)n + k ≡ 0 (mod 2^32), for any integer n
// congruent to 1 modulo 8. This includes every odd perfect square, since
// (2t + 1)^2 = 4t(t + 1) + 1 and t(t + 1) is even. We still need to explain
// how to use such a solution x to find a square root of n.
//
// Once we have an integer x satisfying
//
// (5) (x^2 + x)n + k ≡ 0 (mod 2^32),
//
// put b = nx + 2k. Then, since n = 4k + 1, b^2 + b - k = ((x^2 + x)n + k)n,
// and so (5) implies that
//
// (6) b^2 + b ≡ k (mod 2^32).
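//
// Given n = 4k + 1, the relation b^2 + b - k = ((x^2 + x)n + k)n is a
// polynomial identity. It therefore also holds in wrapping 32-bit arithmetic
// for every x and k, which is what lets the code compute b modulo 2^32. The
// sketch below checks this for a spread of values. The names `b_from_x` and
// `b_identity_holds` are hypothetical.
```c
#include <stdbool.h>
#include <stdint.h>

// b = nx + 2k in wrapping 32-bit arithmetic (m is n modulo 2^32).
static uint32_t b_from_x(uint32_t m, uint32_t k, uint32_t x)
{
    return m * x + 2 * k;
}

// Compare b^2 + b - k with ((x^2 + x)n + k)n modulo 2^32, where n = 4k + 1.
static bool b_identity_holds(uint32_t k, uint32_t x)
{
    uint32_t m = 4 * k + 1;
    uint32_t b = b_from_x(m, k, x);
    return b * b + b - k == (m * x * x + m * x + k) * m;
}
```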
//
// There are exactly two solutions b modulo 2^32 to the congruence (6), the
// other being -1-b. So by replacing b with -1-b if necessary, and reducing
// modulo 2^32, we can also assume that
//
// (7) 0 <= b < 2^31
//
// Now set a = 2b + 1. Since a^2 = 4(b^2 + b) + 1 and n = 4k + 1, (6) and (7)
// imply
//
// (8) a^2 ≡ n (mod 2^34)
//
// and
//
// (9) 0 < a < 2^32
//
// respectively. But if we know that n is an odd perfect square smaller than
// 2^64, (8) and (9) together are enough to guarantee that a is the square root
// of n. To see this, note that both a and √n are odd integers in the interval
// (0, 2^32), and that condition (8) implies that (a - √n)(a + √n) is divisible
// by 2^34. Both factors are even, since a and √n are both odd. If both were
// divisible by 4, then their sum 2a would also be divisible by 4, so a would
// be divisible by 2, contradicting a being odd. So one of the two factors is
// divisible by 2 but not by 4, and the other must then be divisible by 2^33.
// But 0 < a + √n < 2^33, so a + √n cannot be divisible by 2^33. So a - √n is
// divisible by 2^33, and since |a - √n| < 2^32 it follows that a = √n.
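//
// The uniqueness argument does not depend on the particular exponents.
// Replace 2^32 with 2^T and 2^34 with 2^(T+2). The same reasoning then says
// that, for odd r with 0 < r < 2^T, exactly one odd a in (0, 2^T) satisfies
// a^2 ≡ r^2 (mod 2^(T+2)). The sketch below brute-forces this for the
// illustrative choice T = 8. `SCALE_BITS` and `count_candidates` are
// hypothetical names.
```c
#include <stdint.h>

enum { SCALE_BITS = 8 };   // illustrative stand-in for the exponent 32

// Count odd a in (0, 2^SCALE_BITS) with a^2 congruent to n modulo
// 2^(SCALE_BITS + 2). The argument above predicts exactly one such a
// (namely r) whenever n = r^2 for an odd r in that range.
static int count_candidates(uint32_t n)
{
    const uint32_t mod = UINT32_C(1) << (SCALE_BITS + 2);
    int count = 0;
    for (uint32_t a = 1; a < (UINT32_C(1) << SCALE_BITS); a += 2)
        if ((a * a) % mod == n % mod)
            count++;
    return count;
}
```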
//
// Implementation notes
// ====================
//
// * The above theory applies for odd n. For the general case, we first handle
// the special case of an input of 0. Otherwise, if n is a perfect square, it
// has an even number `j` of trailing zeros in its binary representation.
// Shifting those trailing zeros out gives an odd perfect square. Applying the
// theory above gives us the square root of that odd perfect square, and
// shifting that square root left by `j/2` bits gives the square root of the
// original `n`.
// * We rely on C's rules for unsigned arithmetic: intermediate results of
// type uint32_t are reduced modulo 2^32. These reductions of course do not
// affect the validity of congruences modulo 2^32.
// * We make significant use of the fact that, working modulo 2^32, `~x` is
// `-1-x` for a value `x` of type `uint32_t`. So `m * x * ~x - k` is
// -((x^2 + x)n + k) and `x - ~x` is 2x + 1, where `m` holds n reduced
// modulo 2^32. Our iteration (1) can therefore be written in C as
//
//     x += (m * x * ~x - k) * (x - ~x);
//
// * The line `b ^= -(b >> 31);` near the end of the function ensures that
// b < 2^31. `b >> 31` is 1 exactly when b >= 2^31, and negating it gives an
// all-ones mask. So the line leaves b unchanged if it's smaller than 2^31,
// but replaces it with 2^32 - 1 - b (the other solution to the congruence
// (6)) if not. The final `b - ~b` then computes a = 2b + 1.
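//
// The last two notes can be checked mechanically. In the sketch below,
// `step_plain` writes iteration (1) out term by term, `step_compact` is the
// `~x` form used in the code, and `root_from_b` is the final flip-and-double
// step. All three names are hypothetical. Both solutions of (6), b and ~b,
// lead to the same root a = 2b + 1.
```c
#include <stdint.h>

// Iteration (1) written out: x - ((x^2 + x)n + k)(2x + 1), modulo 2^32.
static uint32_t step_plain(uint32_t m, uint32_t k, uint32_t x)
{
    return x - (m * x * x + m * x + k) * (2 * x + 1);
}

// The same step using ~x == -1 - x: m*x*~x - k is -((x^2 + x)n + k),
// and x - ~x is 2x + 1.
static uint32_t step_compact(uint32_t m, uint32_t k, uint32_t x)
{
    return x + (m * x * ~x - k) * (x - ~x);
}

// Keep b if b < 2^31, otherwise use ~b = 2^32 - 1 - b. Then return 2b + 1.
static uint32_t root_from_b(uint32_t b)
{
    b ^= -(b >> 31);
    return b - ~b;
}
```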
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

// lut[i] solves (x^2 + x)n + k ≡ 0 (mod 2^8) for n = 4k + 1 with
// k ≡ 2i (mod 2^8).
static const uint8_t lut[128] = {
    0, 85, 83, 102, 71, 2, 36, 126, 15, 37, 28, 22, 87, 50, 107, 46,
    31, 10, 115, 57, 103, 98, 4, 33, 47, 58, 3, 118, 119, 109, 116, 113,
    63, 106, 108, 38, 120, 61, 27, 62, 79, 101, 35, 41, 104, 13, 84, 17,
    95, 53, 76, 121, 88, 34, 59, 97, 111, 5, 67, 54, 72, 82, 52, 78,
    127, 42, 44, 25, 56, 125, 91, 1, 112, 90, 99, 105, 40, 77, 20, 81,
    96, 117, 12, 70, 24, 29, 123, 94, 80, 69, 124, 9, 8, 18, 11, 14,
    64, 21, 19, 89, 7, 66, 100, 65, 48, 26, 92, 86, 23, 114, 43, 110,
    32, 74, 51, 6, 39, 93, 68, 30, 16, 122, 60, 73, 55, 45, 75, 49,
};

uint32_t isqrt64_exact(uint64_t n)
{
    uint32_t m, k, x, b;

    if (n == 0)
        return 0;

    // Shift out the trailing zeros (an even number for a perfect square).
    // __builtin_ctzll, not __builtin_ctzl, since n is 64 bits wide.
    int j = __builtin_ctzll(n);
    n >>= j;

    m = (uint32_t)n;           // odd part of n, modulo 2^32
    k = (uint32_t)(n >> 2);    // (n - 1) / 4, modulo 2^32

    // Solution modulo 2^8 from the table, then two Newton steps take it to
    // a solution modulo 2^16 and then modulo 2^32.
    x = lut[k >> 1 & 127];
    x += (m * x * ~x - k) * (x - ~x);
    x += (m * x * ~x - k) * (x - ~x);

    // b^2 + b ≡ k (mod 2^32); choose the solution with b < 2^31.
    b = m * x + 2 * k;
    b ^= -(b >> 31);

    // a = 2b + 1 is the square root of the odd part; shift it back up.
    return (b - ~b) << (j >> 1);
}

// Do exhaustive testing on all square numbers smaller than 2^64.
int main(void)
{
    uint32_t a = 0;
    do
    {
        uint64_t n = (uint64_t)a * a;
        uint32_t b = isqrt64_exact(n);
        if (a != b)
        {
            printf("Failure for a=%" PRIu32 ", b=%" PRIu32 ", n=%" PRIu64 "\n",
                   a, b, n);
            return 1;
        }
        a++;
    } while (a);
    printf("Success!\n");
    return 0;
}