Last active
March 10, 2024 19:24
Fast square root of a 64-bit square number
// Exact square root of a known square integer
// ===========================================
// This snippet contains a function `isqrt64_exact` with signature
//
//     uint32_t isqrt64_exact(uint64_t n);
//
// `isqrt64_exact` computes the 32-bit square root of its unsigned 64-bit
// integer input, under the assumption that the input is a perfect square.
//
// Compile under gcc or clang with:
//
//     gcc -std=c99 -Wall -O3 -march=native sqrt_exact.c -o sqrt_exact
//
// The code includes a `main` function that performs exhaustive testing on all
// 2^32 square numbers smaller than 2^64. In the absence of bugs, executing
// `sqrt_exact` should run the exhaustive test (expect it to take a few tens of
// seconds to complete) and print a single line saying "Success!".
//
// For non-perfect-square inputs, the result is deterministic and well-defined,
// but probably not very useful for typical applications.
//
// The algorithm used is division-free and (aside from an initial early return
// for an input of 0) branch-free, using only simple arithmetic and bitwise
// operations. As a result, it's fast. The exhaustive testing loop completes in
// around 14.2 seconds on my Intel Core i7-8559U laptop, which works out at
// ~3.3ns per square root call, or about 14.9 clock cycles per square root
// at 4.5 GHz.
//
// Portability notes:
//
// - The code uses GCC's __builtin_ctzl to count trailing zeros. That builtin
//   takes an `unsigned long`, so the code as written assumes that `long` is
//   64 bits wide; where `long` is 32 bits (for example on 64-bit Windows),
//   use __builtin_ctzll instead. For MSVC, _BitScanForward64 can be used.
//   For more portable options, see for example
//   https://graphics.stanford.edu/~seander/bithacks.html#ZerosOnRightLinear
// - The code is technically not correct on platforms whose `int` type has
//   width greater than 32 bits. Such platforms are rare in practice, but on
//   those platforms, `uint32_t` operands in arithmetic operations will be
//   promoted to `int`, and then operations will be performed using type `int`,
//   giving undefined behaviour from signed overflow. On such platforms, use
//   `unsigned int` instead of `uint32_t` for intermediate results.
//
// For purists: it's not hard to adapt the isqrt64_exact code to be
// *completely* branch free, completely portable, and to eliminate the lookup
// table entirely. Not surprisingly, the result is slower, at least on my
// machine.
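The portability note on counting trailing zeros can be illustrated with a minimal sketch of a loop-based fallback, in the spirit of the linear method at the bithacks link. The name `ctz64_portable` is hypothetical and not part of the original snippet, and this version is expected to be slower than a compiler builtin.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical portable stand-in for __builtin_ctzl (not part of the
   original snippet): counts the trailing zero bits of a nonzero 64-bit
   value by shifting out one bit at a time. */
static int ctz64_portable(uint64_t n)
{
    /* Like the builtin, the count is not meaningful for an input of 0. */
    assert(n != 0);

    int count = 0;
    while ((n & 1) == 0) {
        n >>= 1;
        count++;
    }
    return count;
}
```

Because `isqrt64_exact` returns early for an input of 0, a replacement like this would only ever be called with a nonzero argument.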
// Description of algorithm: mathematical background
// =================================================
// In a nutshell, we create successively better 2-adic approximations to the
// inverse square root of the odd part of the input, using a Newton–Raphson
// iteration. We then turn that into a sufficiently good 2-adic approximation
// to the square root of the odd part of the input.
//
// In more detail: for a fixed real number n define a function f, valid for all
// x != -1/2, by
//
//     f(x) = n - (1 / (2x + 1)^2)
//
// If n is positive then f has two roots, namely (±1/√n - 1) / 2. The
// Newton–Raphson method applied to f gives us an iteration
//
//   (1)  x ↦ x - ((x^2 + x)n + (n - 1)/4) (2x + 1).
//
// The term (n - 1) / 4 will appear frequently in what follows, so for
// convenience we define k = (n - 1) / 4, so that the iteration (1) can be
// rewritten as
//
//   (1)  x ↦ x - ((x^2 + x)n + k) (2x + 1).
//
// As usual, we expect this Newton–Raphson iteration to give us quadratic
// convergence near the roots. While the usual proof of quadratic convergence
// relies on calculus, for this particular f the quadratic convergence can also
// be shown purely algebraically. First note that f(x) = 0 is equivalent to the
// condition (x^2 + x)n + k = 0, so x is close to a root of f when (x^2 + x)n +
// k is small. Put
//
//     y = x - ((x^2 + x)n + k) (2x + 1).
//
// Then by expanding both sides in terms of x and k, it's easy to verify that:
//
//   (2)  (y^2 + y)n + k = ((x^2 + x)n + k)^2 ((2x+1)^2 n - 4)
//
// So if (x^2 + x)n + k is sufficiently small, then (y^2 + y)n + k is smaller
// still, thanks to the squared term. (The extra term ((2x+1)^2 n - 4)
// converges to the constant value -3 as the iteration progresses.)
//
// Now we switch domains. The key observation is that the iteration (1), and
// the quadratic convergence arising from (2), are just as valid in the 2-adic
// integers as in the real numbers. Suppose that n is an integer congruent to 1
// modulo 4, that k = (n - 1) / 4 as before, and that x is an integer
// satisfying
//
//   (3)  (x^2 + x)n + k ≡ 0 (mod 2^j)
//
// for some j >= 0. Set y = x - ((x^2 + x)n + k) (2x + 1) as before. Then
// from (2) we have
//
//   (4)  (y^2 + y)n + k ≡ 0 (mod 2^(2j)).
//
// So if we have a solution to the congruence (x^2 + x)n + k ≡ 0 (mod 2^8)
// (for example), a single application of the iteration (1) gets us a solution
// modulo 2^16, and a second application gives us a solution modulo 2^32. To
// start the iteration off it's enough to note that if n ≡ 1 (mod 8), as holds
// for every odd perfect square, then k is even and so _every_ integer x is a
// solution to (x^2 + x)n + k ≡ 0 (mod 2), so we can always use x = 0 as a
// starting guess. But for practical implementation it will typically be
// faster to replace the initial iterations with a lookup table, and that's
// what we do in the code below, using a lookup table based on the least
// significant eight bits of k to provide a solution to
// (x^2 + x)n + k ≡ 0 (mod 2^8).
//
// Let's take stock. At this point we have an efficient division-free algorithm
// for finding solutions to (x^2 + x)n + k ≡ 0 (mod 2^32), for any integer n
// congruent to 1 modulo 8. We still need to explain how to use such a solution
// x to find a square root of n.
//
// Once we have an integer x satisfying
//
//   (5)  (x^2 + x)n + k ≡ 0 (mod 2^32),
//
// put b = nx + 2k. Then b^2 + b - k = ((x^2 + x)n + k)n, and so (5)
// implies that
//
//   (6)  b^2 + b ≡ k (mod 2^32).
//
// There are exactly two solutions b modulo 2^32 to the congruence (6), the
// other being -1-b. So by replacing b with -1-b if necessary, and reducing
// modulo 2^32, we can also assume that
//
//   (7)  0 <= b < 2^31
//
// Now set a = 2b + 1. Then (6) and (7) imply
//
//   (8)  a^2 ≡ n (mod 2^34)
//
// and
//
//   (9)  0 < a < 2^32
//
// respectively. But if we know that n is an odd perfect square smaller than
// 2^64, (8) and (9) together are enough to guarantee that a is the square root
// of n. To see this, note that both a and √n are odd integers in the interval
// (0, 2^32), and that condition (8) implies that (a - √n)(a + √n) is divisible
// by 2^34. If both a - √n and a + √n are divisible by 4, then their sum 2a is
// also divisible by 4, so a is divisible by 2, in contradiction to a being
// odd. So one of the two factors is not divisible by 4, and the other must
// then be divisible by 2^33. But 0 < a + √n < 2^33, and so a + √n cannot be
// divisible by 2^33. So a - √n is divisible by 2^33, and from the bounds on a
// and √n it follows that a = √n.
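The argument above can be tried out with a small sketch that starts the iteration from x = 0 instead of a lookup table, then recovers the odd square root through steps (5) to (9). This is an illustration under stated assumptions rather than the snippet's own method: the helper names (`residual32`, `newton_step32`, `solve_from_zero`, `odd_isqrt_no_lut`) are hypothetical, the input is assumed to satisfy n ≡ 1 (mod 8), and `uint32_t` is assumed to be no wider than `unsigned int` so that arithmetic wraps modulo 2^32.

```c
#include <assert.h>
#include <stdint.h>

/* Residual (x^2 + x)n + k from congruence (3), reduced modulo 2^32.
   Here m is n mod 2^32 and k is (n - 1) / 4 mod 2^32. */
static uint32_t residual32(uint32_t m, uint32_t k, uint32_t x)
{
    return (x * x + x) * m + k;
}

/* One application of iteration (1), with all arithmetic modulo 2^32. */
static uint32_t newton_step32(uint32_t m, uint32_t k, uint32_t x)
{
    return x - residual32(m, k, x) * (2 * x + 1);
}

/* Solve (x^2 + x)n + k = 0 (mod 2^32) starting from x = 0.
   Each step doubles the exponent of the modulus: 2^1, 2^2, 2^4, 2^8,
   2^16, 2^32, so five steps are enough. */
static uint32_t solve_from_zero(uint64_t n)
{
    assert(n % 8 == 1); /* guarantees that x = 0 is a solution modulo 2 */

    uint32_t m = (uint32_t)n;
    uint32_t k = (uint32_t)(n >> 2);
    uint32_t x = 0;
    for (int i = 0; i < 5; i++)
        x = newton_step32(m, k, x);
    return x;
}

/* Square root of an odd perfect square n < 2^64, following (5)-(9). */
static uint32_t odd_isqrt_no_lut(uint64_t n)
{
    uint32_t m = (uint32_t)n;
    uint32_t k = (uint32_t)(n >> 2);
    uint32_t x = solve_from_zero(n);

    uint32_t b = m * x + 2 * k; /* b^2 + b = k (mod 2^32), as in (6) */
    if (b >> 31)
        b = ~b;                 /* pick the solution below 2^31, as in (7) */
    return 2 * b + 1;           /* a = 2b + 1 */
}
```

For example, `odd_isqrt_no_lut(25)` should return 5. The explicit `if` is used here for readability; it is not branch-free.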
// Implementation notes:
//
// * The above theory applies for odd n. For the general case, we first handle
//   the special case of an input of 0. Now if n is a perfect square, it has an
//   even number `j` of trailing zeros in its binary representation. By
//   shifting the trailing zeros out we get an odd perfect square. Applying the
//   theory above gives us a square root of that odd perfect square, and
//   shifting that square root left by `j/2` bits gives the square root of the
//   original `n`.
// * We rely on C's rules for uint32_t arithmetic (namely reduction modulo
//   2^32) of intermediate results. These reductions of course do not affect
//   the validity of the congruences modulo 2^32.
// * We make significant use of the fact that, working modulo 2^32, for a
//   value `x` with type `uint32_t`, `~x` is `-1-x`. So `x * ~x` is
//   `-(x^2 + x)` and `x - ~x` is `2x + 1`. Thus our iteration (1) can be
//   written in C as
//
//       x += (n * x * ~x - k) * (x - ~x);
//
//   (In the function itself, `m` holds the value of `n` reduced to 32 bits.)
// * The line `b ^= -(b >> 31);` near the end of the function ensures that b <
//   2^31: it leaves b unchanged if it's smaller than 2^31, but replaces it
//   with 2^32 - 1 - b (the other solution to the congruence (6)) if not.
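The two bit-level identities in these notes are easy to check in isolation. The following sketch uses hypothetical helper names (`twice_plus_one`, `fold_below_2_31`, `check_complement_identity`) that do not appear in the original snippet, and assumes `uint32_t` is not promoted to a wider signed `int`.

```c
#include <assert.h>
#include <stdint.h>

/* Modulo 2^32, ~x equals -1 - x, so x - ~x equals 2x + 1. */
static uint32_t twice_plus_one(uint32_t x)
{
    return x - ~x;
}

/* Branch-free fold into [0, 2^31): b >> 31 is 0 or 1, so -(b >> 31) is
   either all zero bits or all one bits. XOR with it leaves b unchanged
   or replaces it with ~b, which is 2^32 - 1 - b. */
static uint32_t fold_below_2_31(uint32_t b)
{
    b ^= -(b >> 31);
    return b;
}

/* Sanity check of the complement identity for one value. */
static void check_complement_identity(uint32_t x)
{
    assert(~x == (uint32_t)0 - 1 - x);
}
```

For instance, `fold_below_2_31(0x80000000u)` gives `0x7FFFFFFFu`, while any input already below 2^31 comes back unchanged.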
#include <stdint.h>
#include <stdio.h>

// Lookup table indexed by bits 1-7 of k, giving a solution x to
// (x^2 + x)n + k ≡ 0 (mod 2^8).
static const uint8_t lut[128] = {
    0, 85, 83, 102, 71, 2, 36, 126, 15, 37, 28, 22, 87, 50, 107, 46,
    31, 10, 115, 57, 103, 98, 4, 33, 47, 58, 3, 118, 119, 109, 116, 113,
    63, 106, 108, 38, 120, 61, 27, 62, 79, 101, 35, 41, 104, 13, 84, 17,
    95, 53, 76, 121, 88, 34, 59, 97, 111, 5, 67, 54, 72, 82, 52, 78,
    127, 42, 44, 25, 56, 125, 91, 1, 112, 90, 99, 105, 40, 77, 20, 81,
    96, 117, 12, 70, 24, 29, 123, 94, 80, 69, 124, 9, 8, 18, 11, 14,
    64, 21, 19, 89, 7, 66, 100, 65, 48, 26, 92, 86, 23, 114, 43, 110,
    32, 74, 51, 6, 39, 93, 68, 30, 16, 122, 60, 73, 55, 45, 75, 49,
};

uint32_t isqrt64_exact(uint64_t n)
{
    uint32_t m, k, x, b;

    if (n == 0)
        return 0;

    // Shift out the (even number of) trailing zeros to get an odd square.
    int j = __builtin_ctzl(n);
    n >>= j;

    // m is n modulo 2^32; k is (n - 1) / 4 modulo 2^32.
    m = (uint32_t)n;
    k = (uint32_t)(n >> 2);

    // Solution modulo 2^8 from the table, then two Newton steps: 2^16, 2^32.
    x = lut[(k >> 1) & 127];
    x += (m * x * ~x - k) * (x - ~x);
    x += (m * x * ~x - k) * (x - ~x);

    // b satisfies (6); fold it below 2^31, then return (2b + 1) << (j / 2).
    b = m * x + 2 * k;
    b ^= -(b >> 31);
    return (b - ~b) << (j >> 1);
}
// Do exhaustive testing on all square numbers smaller than 2**64.
int main(void)
{
    uint32_t a = 0;
    do
    {
        uint64_t n = (uint64_t)a * a;
        uint32_t b = isqrt64_exact(n);
        if (a != b)
        {
            // Cast so that the argument matches %llu on every platform.
            printf("Failure for a=%u, b=%u, n=%llu\n", a, b,
                   (unsigned long long)n);
            return 1;
        }
        a++;
    } while (a);
    printf("Success!\n");
    return 0;
}