Fast inverse square root explanation
#include <cmath>
#include <stdint.h>
#include <stdio.h>
// Mean absolute error (mae) of each quick_invsqrt() variant against
// 1.0f/std::sqrt(x) over the 1e-7..1e8 sample grid of the benchmark main()
// (values presumably recorded from that benchmark):
// | idx | mae      | description                          |
// |-----|----------|--------------------------------------|
// | 0   | 1.008427 | Naive with k0=0                      |
// | 1   | 0.144398 | Original Quake                       |
// | 2   | 0.099314 | Our grad descent (magic number only) |
// | 3   | 0.060105 | Jan Kadlec (wiki)                    |
// | 4   | 0.039234 | Our grad descent (three parameters)  |
//
// IMPL_VERSION selects the variant; it can be overridden at build time,
// e.g. -DIMPL_VERSION=1.
#ifndef IMPL_VERSION
#define IMPL_VERSION 4
#endif

#if IMPL_VERSION == 0 || IMPL_VERSION == 1 || IMPL_VERSION == 2
// Approximates 1/sqrt(x) for x > 0: a bit-level initial guess (derived below)
// refined by TOTAL_ITERATIONS Newton-Raphson steps. Variants 0..2 differ only
// in the magic constant LOG2_CONST.
//
// Type punning note (applies to every variant): reading the union member that
// was not written last is documented as supported by GCC and Clang, but it is
// undefined behaviour in ISO C++; std::memcpy (or C++20 std::bit_cast) is the
// portable equivalent.
float quick_invsqrt(float x) {
    union {
        int32_t log_val;
        float f32_val;
    } x_pun;
    x_pun.f32_val = x;
    // IEEE-754 single precision packs a float into 32 bits:
    // - field widths in bits: K (sign): 1, N (biased exponent): 8, M (mantissa): 23
    // - interpreted as f32 (normal numbers)
    //   x = (-1)^K * 2^(N-127) * (1 + M*2^-23)
    // Reinterpreting the same bits as an integer (the "quantisation") gives
    // - Q = K*2^31 + N*2^23 + M
    // y = 1/sqrt(x) requires x > 0, so K = 0:
    // - Q = N*2^23 + M
    // Expand log2(x) in terms of the IEEE-754 fields
    //   x = 2^(N-127) * (1+M*2^-23)
    // - let n = N-127
    //       m = M*2^-23
    //   x = 2^n * (1+m)
    //   log2(x) = log2[2^n * (1+m)]
    //   log2(x) = log2(2^n) + log2(1+m)
    //   log2(x) = n + log2(1+m)
    // Since m = M*2^-23
    // - 0 <= M <= 2^23-1
    // - 0 <= m < 1
    // - 1 <= 1+m < 2
    // - 0 <= log2(1+m) < 1
    // On this interval log2(1+m) is close to a straight line (a plot shows it
    // is a good linear approximation), so approximate
    //   log2(1+m) ~ m + k0
    // - k0 is a free parameter
    // - k0 = 0 fits the endpoints log2(1)=0, log2(2)=1 exactly
    // - a non-zero k0 can reduce the total error across the interval
    // Substituting the approximation for log2(1+m)
    //   log2(x) = n + log2(1+m)
    //   log2(x) ~ n + m + k0
    // - substitute n and m
    //   log2(x) ~ N-127 + M*2^-23 + k0
    //   log2(x)*2^23 ~ N*2^23 - 127*2^23 + k0*2^23 + M
    // - substitute the quantisation Q = N*2^23 + M
    //   log2(x)*2^23 ~ Q - 127*2^23 + k0*2^23
    //   Q ~ 2^23 * [ log2(x) + 127 - k0 ]
    // Express y = 1/sqrt(x) in terms of log2
    //   log2(y) = log2(x^-0.5)
    //   log2(y) = -0.5*log2(x)
    // Apply the approximation of Q to both y and x
    //   Qy ~ 2^23 * [ log2(y) + 127 - k0 ]
    //   Qy ~ 2^23 * [ -0.5*log2(x) + 127 - k0 ]
    //   Qx ~ 2^23 * [ log2(x) + 127 - k0 ]
    //   -0.5*Qx ~ 2^23 * [ -0.5*log2(x) - 63.5 + 0.5*k0 ]
    //   -0.5*Qx ~ 2^23 * [ -0.5*log2(x) + 127 - k0 - 190.5 + 1.5*k0 ]
    //   -0.5*Qx ~ 2^23 * [ -0.5*log2(x) + 127 - k0 ] - 2^23 * [ 190.5 - 1.5*k0 ]
    // Substituting Qy
    //   -0.5*Qx ~ Qy - 2^23 * [ 190.5 - 1.5*k0 ]
    //   Qy ~ 2^23*(190.5 - 1.5*k0) - 0.5*Qx
    // Let k1 = 2^23*(190.5 - 1.5*k0) be the new free parameter (LOG2_CONST)
    //   Qy ~ k1 - 0.5*Qx
    // If k0 = 0 then k1 = 2^23 * 190.5 = 2^22 * 381
#if IMPL_VERSION == 0
    constexpr static int32_t LOG2_CONST = 381 << 22; // Naive version
#elif IMPL_VERSION == 1
    constexpr static int32_t LOG2_CONST = 0x5F37'59DF; // Original Quake implementation
#elif IMPL_VERSION == 2
    constexpr static int32_t LOG2_CONST = 0x5F35093D; // Our gradient descent version
#endif
    // Qy ~ k1 - 0.5*Qx; for x > 0 Qx is non-negative, so >> 1 halves it.
    x_pun.log_val = LOG2_CONST - (x_pun.log_val >> 1);
    float y_approx = x_pun.f32_val;
    // Refine with Newton's method on f(y) = 1/y^2 - x' (x' = the input x):
    //   f(y) = (1 - x'*y^2)/y^2, whose positive root is y = 1/sqrt(x')
    //   f'(y) = -2*y^-3 = -2/y^3
    //   y1 = y0 - f(y0)/f'(y0)          | y1 = next approximation
    //   f(y)/f'(y) = -0.5*y*[1 - x'*y^2]
    //   y1 = y0 + 0.5*y0*[1 - x'*y0^2]
    //   y1 = y0 + 0.5*y0 - 0.5*x'*y0^3
    //   y1 = 1.5*y0 - 0.5*x'*y0^3
    //   y1 = y0*[1.5 - 0.5*x'*y0^2]
    const float x_half = x*0.5f;
    constexpr static size_t TOTAL_ITERATIONS = 1;
    for (size_t i = 0; i < TOTAL_ITERATIONS; i++) {
        y_approx = y_approx*(1.5f - x_half*y_approx*y_approx);
    }
    return y_approx;
}
#elif IMPL_VERSION == 3
// Source:
// Jan Kadlec's version (as listed on the Wikipedia "Fast inverse square root"
// page): a retuned magic number followed by one Newton-like step with
// retuned coefficients, y *= a*(b - x*y*y).
float quick_invsqrt(float x) {
    union {
        int32_t log_val;
        float f32_val;
    } x_pun;
    x_pun.f32_val = x;
    x_pun.log_val = 0x5F1F'FFF9 - (x_pun.log_val >> 1);
    float y = x_pun.f32_val;
    y *= 0.703952253f * (2.38924456f - x*y*y);
    return y;
}
#elif IMPL_VERSION == 4
// Our version with three free parameters based on Jan Kadlec's reformulation:
//   Qx = quant(x)
//   Qy ~ k0 - 0.5*Qx
//   y0 = dequant(Qy)
//   y1 = y0*(k1*x*y0^2 + k2)
// k0, k1, k2 presumably come from the three-parameter gradient-descent fit.
float quick_invsqrt(float x) {
    constexpr int32_t k0 = 0x5EDA97E8;
    constexpr float k1 = -2.13202330;
    constexpr float k2 = 2.43318741;
    union {
        int32_t log_val;
        float f32_val;
    } x_pun;
    x_pun.f32_val = x;
    x_pun.log_val = k0 - (x_pun.log_val >> 1);
    float y = x_pun.f32_val;
    y = y*(k1*x*y*y + k2);
    return y;
}
#else
#error "IMPL_VERSION must be one of 0, 1, 2, 3, 4"
#endif
// clang++ main.cpp -o main -O3
// Benchmark: prints the mean absolute error of quick_invsqrt(x) against
// 1.0f/std::sqrt(x) over x in [1e-7, 1e8], sampled with a fixed linear step
// inside each decade, followed by the number of samples taken.
int main(int argc, char** argv) {
    // Sum in double: with on the order of 10^4 samples, a float accumulator
    // drops the low digits of small per-sample errors once the sum grows.
    double sum_absolute_error = 0.0;
    size_t total_samples = 0;
    const auto push_sample = [&](float x) {
        const float y_target = 1.0f/std::sqrt(x);
        const float y_pred = quick_invsqrt(x);
        const float error = std::abs(y_target-y_pred);
        sum_absolute_error += error;
        total_samples++;
    };
    // Linear steps within each decade. x is accumulated in float, so the exact
    // sample points (and the sample count) follow float rounding; the fitting
    // programs use the same grid.
    for (float x = 1e-7; x <= 1e-6; x += 1e-9) push_sample(x);
    for (float x = 1e-6; x <= 1e-5; x += 1e-8) push_sample(x);
    for (float x = 1e-5; x <= 1e-4; x += 1e-7) push_sample(x);
    for (float x = 1e-4; x <= 1e-3; x += 1e-6) push_sample(x);
    for (float x = 1e-3; x <= 1e-2; x += 1e-5) push_sample(x);
    for (float x = 1e-2; x <= 1e-1; x += 1e-4) push_sample(x);
    for (float x = 1e-1; x <= 1e0; x += 1e-3) push_sample(x);
    for (float x = 1e0; x <= 1e1; x += 1e-2) push_sample(x);
    for (float x = 1e1; x <= 1e2; x += 1e-1) push_sample(x);
    for (float x = 1e2; x <= 1e3; x += 1e0) push_sample(x);
    for (float x = 1e3; x <= 1e4; x += 1e1) push_sample(x);
    for (float x = 1e4; x <= 1e5; x += 1e2) push_sample(x);
    for (float x = 1e5; x <= 1e6; x += 1e3) push_sample(x);
    for (float x = 1e6; x <= 1e7; x += 1e4) push_sample(x);
    for (float x = 1e7; x <= 1e8; x += 1e5) push_sample(x);
    const double mean_absolute_error = sum_absolute_error / double(total_samples);
    printf("mae=%.6f\n", mean_absolute_error);
    printf("total_samples=%zu\n", total_samples);
    return 0;
}
# Build the benchmark and both fitting programs with identical flags.
for program in \
    benchmark_fast_inverse_square_root \
    derive_fast_inverse_square_root_magic_number \
    derive_fast_inverse_square_root_three_params; do
    clang++ "${program}.cpp" -o "${program}.exe" -O3 -std=c++17
done
#include <cmath>
#include <stdint.h>
#include <stdio.h>
#include <inttypes.h>
#include <vector>
// Returns the raw IEEE-754 bit pattern of x as a signed 32-bit integer (the
// "quantised" value Q of the derivation). Union type punning is documented as
// supported by GCC and Clang; in ISO C++ std::memcpy is the portable form.
static inline int32_t quantise_float(float x) {
    union {
        float f32;
        int32_t i32;
    } y;
    y.f32 = x;
    return y.i32;
}
// clang++ main.cpp -o main -std=c++17 -O3 -march=native
// Fits the single magic number of the fast inverse square root.
//
// For each sample x the target is Jy = 2*quant(1/sqrt(x)) and the model is
//   Jy ~ 2^23*381 - k1 - Qx,   Qx = quant(x)
// so k1 is the only free parameter. It is found by gradient descent on the
// mean squared error in this integer domain, then reported as the log2 offset
// k0 and as the magic number k2 used in Qy ~ k2 - 0.5*Qx.
int main(int argc, char** argv) {
    // Qy ~ 2^23*(190.5 - 1.5*k0) - 0.5*Qx
    // 2*Qy ~ 2^23*381 - 2^23*3*k0 - Qx
    // Jy ~ 2^23*381 - 2^23*3*k0 - Qx
    // Jy ~ 2^23*381 - k1 - Qx
    // Working with Jy = 2*Qy keeps every quantity an exact integer.
    constexpr int64_t M0 = int64_t(381) << 23;
    std::vector<int64_t> QX;
    std::vector<int64_t> JY_target;
    const auto push_sample = [&](float x) {
        const float y = 1.0/std::sqrt(x);
        const int32_t Qx = quantise_float(x);
        const int32_t Qy = quantise_float(y);
        const int64_t Jy = int64_t(Qy)*2;
        QX.push_back(Qx);
        JY_target.push_back(Jy);
    };
    // Linear steps within each decade from 1e-7 to 1e8 (same grid as the benchmark).
    for (float x = 1e-7; x <= 1e-6; x += 1e-9) push_sample(x);
    for (float x = 1e-6; x <= 1e-5; x += 1e-8) push_sample(x);
    for (float x = 1e-5; x <= 1e-4; x += 1e-7) push_sample(x);
    for (float x = 1e-4; x <= 1e-3; x += 1e-6) push_sample(x);
    for (float x = 1e-3; x <= 1e-2; x += 1e-5) push_sample(x);
    for (float x = 1e-2; x <= 1e-1; x += 1e-4) push_sample(x);
    for (float x = 1e-1; x <= 1e0; x += 1e-3) push_sample(x);
    for (float x = 1e0; x <= 1e1; x += 1e-2) push_sample(x);
    for (float x = 1e1; x <= 1e2; x += 1e-1) push_sample(x);
    for (float x = 1e2; x <= 1e3; x += 1e0) push_sample(x);
    for (float x = 1e3; x <= 1e4; x += 1e1) push_sample(x);
    for (float x = 1e4; x <= 1e5; x += 1e2) push_sample(x);
    for (float x = 1e5; x <= 1e6; x += 1e3) push_sample(x);
    for (float x = 1e6; x <= 1e7; x += 1e4) push_sample(x);
    for (float x = 1e7; x <= 1e8; x += 1e5) push_sample(x);
    const size_t N = JY_target.size();
    printf("training over %zu samples\n", N);

    int64_t k1 = 0;
    int64_t best_k1 = k1;
    uint64_t best_mse = ~uint64_t(0);
    size_t best_iter = 0;
    constexpr size_t TOTAL_ITERATIONS = 1024;
    constexpr size_t PRINT_ITER = TOTAL_ITERATIONS / 32;
    for (size_t iter = 0; iter < TOTAL_ITERATIONS; iter++) {
        int64_t mean_error = 0;
        uint64_t mean_square_error = 0;
        for (size_t i = 0; i < N; i++) {
            // Jy' = 2^23*381 - k1 - Qx          (prediction)
            // e   = 0.5*(Jy' - Jy)^2
            // de/dk1 = -(Jy' - Jy)
            const int64_t Qx = QX[i];
            const int64_t Jy_target = JY_target[i];
            const int64_t Jy_pred = M0 - k1 - Qx;
            const int64_t error = Jy_pred - Jy_target;
            mean_error += error;
            mean_square_error += uint64_t(error*error);
        }
        mean_error /= int64_t(N);
        mean_square_error /= uint64_t(N);
        if (mean_square_error < best_mse) {
            best_mse = mean_square_error;
            best_k1 = k1;
            best_iter = iter;
        }
        // The mean gradient is -mean_error, so a unit step against it is
        // k1 += mean_error. Because the model is linear in k1, this lands on the
        // least-squares optimum up to integer truncation.
        const int64_t gradient = mean_error;
        k1 += gradient;
        if (iter % PRINT_ITER == 0 || mean_error == 0) {
            printf("[%.3zu] mse=%" PRIu64 ", mean_error=%" PRIi64 "\n", iter, mean_square_error, mean_error);
        }
        // Converged: a zero mean residual means further steps change nothing.
        if (mean_error == 0) {
            break;
        }
    }

    printf("\n[BEST RESULTS]\n");
    printf("iter=%zu\n", best_iter);
    printf("mse=%" PRIu64 "\n", best_mse);
    // %X expects an unsigned value; pass the low 32 bits as uint32_t.
    printf("k1=%" PRIi64 " (0x%08" PRIX32 ")\n", best_k1, uint32_t(best_k1));
    // k1 = 2^23 * 3 * k0
    const double k0 = double(best_k1)/double(int64_t(3) << 23);
    printf("k0=%.8f\n", k0);
    // k2 = (2^23*381 - k1) / 2
    const int32_t k2 = int32_t((M0 - best_k1)/2);
    printf("k2=%" PRIi32 " (0x%08" PRIX32 ")\n", k2, uint32_t(k2));
    // Qy = k2 - 0.5*Qx
    return 0;
}
#include <cmath>
#include <inttypes.h>
#include <random>
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include <thread>
#include <mutex>
#include <atomic>

// Cleared by the interrupt handler below; the training workers in main() keep
// iterating while it is set.
// std::atomic<bool> rather than `volatile bool`: the flag is written from a
// signal / console-control handler and read concurrently by several worker
// threads. volatile provides neither atomicity nor inter-thread visibility,
// whereas a lock-free std::atomic is thread-safe and is also permitted inside
// a signal handler.
static std::atomic<bool> is_running{true};
static_assert(std::atomic<bool>::is_always_lock_free,
              "is_running must be lock-free to be written from a signal handler");

#if _WIN32
#define NOMINMAX
#include <windows.h>
// Console control handler registered with SetConsoleCtrlHandler(). On Ctrl-C
// it asks the workers to stop and returns TRUE (event handled); any other
// control event is passed on by returning FALSE.
BOOL WINAPI sighandler(DWORD signum) {
    if (signum == CTRL_C_EVENT) {
        fprintf(stderr, "Signal caught, exiting!\n");
        is_running = false;
        return TRUE;
    }
    return FALSE;
}
#else
#include <errno.h>
#include <signal.h>
// POSIX signal handler installed with sigaction(): asks the workers to stop.
// NOTE(review): fprintf is not async-signal-safe; if the signal arrives while
// another thread holds the stdio lock this can deadlock. write(2) with a fixed
// message would be the safe alternative.
static void sighandler(int signum) {
    fprintf(stderr, "Signal caught, exiting! (%d)\n", signum);
    is_running = false;
}
#endif
// Returns the raw IEEE-754 bit pattern of x as a signed 32-bit integer (the
// "quantised" value Q of the derivation). Union type punning is documented as
// supported by GCC and Clang; in ISO C++ std::memcpy is the portable form.
static inline int32_t quantise_float(float x) {
    union {
        float f32;
        int32_t i32;
    } y;
    y.f32 = x;
    return y.i32;
}
// Inverse of quantise_float: reinterprets a 32-bit integer bit pattern as an
// IEEE-754 float. Union type punning is documented as supported by GCC and
// Clang; in ISO C++ std::memcpy is the portable form.
static inline float dequantise_i32(int32_t x) {
    union {
        float f32;
        int32_t i32;
    } y;
    y.i32 = x;
    return y.f32;
}
// clang++ main.cpp -o main -std=c++17 -O3 -march=native
// Three-parameter fit for the fast inverse square root:
//   Qy ~ (M0 - k0) - 0.5*Qx, y0 = dequant(Qy), y1 = y0*(k1*x*y0^2 + k2)
// Builds a training set over x in 1e-7..1e8, then runs one gradient-descent
// worker per hardware thread. Each worker restarts from random parameters
// ("versions") when it plateaus, until is_running is cleared by the interrupt
// handler. Finally prints the lowest mean absolute error found and its parameters.
// NOTE(review): several lines appear to be missing from this copy: closing
// braces, `#else`/`#endif`, the push_back calls in push_sample, the
// increments of curr_iter / curr_version / total_plateau, the `printf(`
// opener of the progress print, and the thread joins. The notes below mark
// where.
int main(int argc, char** argv) {
#if _WIN32
SetConsoleCtrlHandler(sighandler, TRUE);
// NOTE(review): an `#else` is expected here. The sigaction setup below is the
// POSIX path, and its matching `#endif` is not visible either.
struct sigaction sigact;
sigact.sa_handler = sighandler;
// NOTE(review): sa_mask is never initialised. POSIX code normally calls
// sigemptyset(&sigact.sa_mask) before sigaction().
sigact.sa_flags = 0;
sigaction(SIGINT, &sigact, nullptr);
sigaction(SIGTERM, &sigact, nullptr);
sigaction(SIGQUIT, &sigact, nullptr);
sigaction(SIGPIPE, &sigact, nullptr);
// Training set: inputs x, float targets 1/sqrt(x) and their bit patterns.
std::vector<float> X_in;
std::vector<float> Y_target;
std::vector<int32_t> Qy_target;
const auto push_sample = [&](float x) {
const float y = 1.0/std::sqrt(x);
// NOTE(review): nothing is appended to X_in / Y_target / Qy_target in this
// copy, and the closing `};` is not visible. Without those push_back calls,
// N below is 0.
// Linear steps within each decade from 1e-7 to 1e8 (x accumulated in float).
for (float x = 1e-7; x <= 1e-6; x += 1e-9) push_sample(x);
for (float x = 1e-6; x <= 1e-5; x += 1e-8) push_sample(x);
for (float x = 1e-5; x <= 1e-4; x += 1e-7) push_sample(x);
for (float x = 1e-4; x <= 1e-3; x += 1e-6) push_sample(x);
for (float x = 1e-3; x <= 1e-2; x += 1e-5) push_sample(x);
for (float x = 1e-2; x <= 1e-1; x += 1e-4) push_sample(x);
for (float x = 1e-1; x <= 1e0; x += 1e-3) push_sample(x);
for (float x = 1e0; x <= 1e1; x += 1e-2) push_sample(x);
for (float x = 1e1; x <= 1e2; x += 1e-1) push_sample(x);
for (float x = 1e2; x <= 1e3; x += 1e0) push_sample(x);
for (float x = 1e3; x <= 1e4; x += 1e1) push_sample(x);
for (float x = 1e4; x <= 1e5; x += 1e2) push_sample(x);
for (float x = 1e5; x <= 1e6; x += 1e3) push_sample(x);
for (float x = 1e6; x <= 1e7; x += 1e4) push_sample(x);
for (float x = 1e7; x <= 1e8; x += 1e5) push_sample(x);
const size_t N = X_in.size();
printf("training over %zu samples\n", N);
// Solve for the following approximation
// Qy ~ 2^22*381 - k0 - 0.5*Qx, k0' = 2^22*381 - k0
// y0 = dequant(Qy)
// y1 = y0*(k1*x*y0^2 + k2) | Custom Newton's method
constexpr int64_t M0 = int64_t(381) << 22;
// const double M1 = std::log(double(2.0))/double(int32_t(3) << 22);
// M1 stands in for the ln(2)/(3*2^22) factor of df/dk0 (see the line above).
// It is set to 1.0, so the k0 step size is presumably governed by
// scale_de_dk[0] below instead.
const double M1 = 1.0;
// Free parameters: k0 is the offset subtracted from M0 (magic number =
// M0 - k0, as printed at the end); k1 and k2 are the Newton-step coefficients.
struct Params {
int64_t k0 = 0;
double k1 = 0.0;
double k2 = 0.0;
// Random source for restarts. rng_f32 draws doubles in [0, 1) despite its name.
std::random_device rng_dev;
std::mt19937 rng_gen(rng_dev());
std::uniform_real_distribution<double> rng_f32(0.0, 1.0);
// Random restart point: k0 in [0, 1e8), k1 in (-5, 0], k2 in [0, 5).
const auto gen_rand_params = [&]() -> Params {
Params rand_params;
rand_params.k0 = int64_t(rng_f32(rng_gen)*1e8);
rand_params.k1 = rng_f32(rng_gen)*-5.0;
rand_params.k2 = rng_f32(rng_gen)*5.0;
// Based on Jan Kadlec's version
// rand_params.k0 = M0 - 0x5F1F'FFF9;
// rand_params.k1 = -0.703952253;
// rand_params.k2 = 0.703952253 * 2.38924456;
return rand_params;
// Best parameters seen, with the restart (version), iteration and worker that
// produced them.
struct Result {
Params params;
double mae = double(~uint64_t(0));
size_t version = 0;
size_t iter = 0;
size_t thread_id = 0;
// Global best across all workers, guarded by best_result_mutex.
Result best_result;
std::mutex best_result_mutex;
// Progress print period, and the number of non-improving iterations before a
// worker restarts from fresh random parameters.
constexpr size_t PRINT_ITER = 20'000;
constexpr size_t MAXIMUM_PLATEAU_RESTART = 1'000;
std::vector<std::thread> threads;
// NOTE(review): hardware_concurrency() may return 0 when the value is unknown,
// in which case no worker is started.
const size_t TOTAL_THREADS = std::thread::hardware_concurrency();
// const size_t TOTAL_THREADS = 1;
for (size_t thread_id = 0; thread_id < TOTAL_THREADS; thread_id++) {
threads.push_back(std::thread([&, thread_id]() {
// Worker: full-batch gradient descent from random starting parameters.
// best_ver_result tracks the current restart; best_thread_result tracks the
// best over all of this worker's restarts.
Result best_thread_result;
Result best_ver_result;
size_t total_plateau = 0;
size_t curr_version = 0;
size_t curr_iter = 0;
Params curr_params;
curr_params = gen_rand_params();
while (is_running) {
double mean_absolute_error = 0;
// Accumulates the NEGATIVE gradient (de_dg below is y_target - y1 = -de/dg),
// so it is added to the parameters after scaling.
double avg_de_dk[3] = {0};
for (size_t i = 0; i < N; i++) {
const float x_f32 = X_in[i];
const float y_f32 = Y_target[i];
// Qx ~ 2^23 * [log2(x) + 127 - k0']
// NOTE: in these inner comments k0' is the log2 offset of the derivation,
// not the magic number that the header comment above calls k0'.
const int32_t Qx = quantise_float(x_f32);
// Qy ~ 2^23*(190.5 - 1.5*k0') - 0.5*Qx
// Qy ~ 2^22*(381 - 3*k0') - 0.5*Qx
// Qy ~ 2^22*381 - k0 - 0.5*Qx, where k0 = 2^22*3*k0'
const int64_t Qy = M0 - curr_params.k0 - (Qx >> 1);
// Qy ~ 2^23 * [log2(y) + 127 - k0']
// log2(y) ~ Qy*2^-23 - 127 + k0'
// y = f(Qy) ~ 2^[ Qy*2^-23 - 127 + k0/(3*2^22) ]
const double y0 = double(dequantise_i32(int32_t(Qy)));
// y1 = g(y0) = y0*(k1*x*y0^2 + k2)
const double x_f64 = double(x_f32);
const double y1 = y0*(curr_params.k1*x_f64*y0*y0 + curr_params.k2);
// e = 0.5*(y_target - y1)^2
// e = 0.5*(y_target - g(y0))^2
// de/dg = -(y_target - g(y0))
const double y_f64 = double(y_f32);
const double de_dg = y_f64 - y1;
// g(y0) = y0*(k1*x*y0^2 + k2)
// dg/dk1 = x*y0^3
// dg/dk2 = y0
const double dg_dk1 = x_f64*y0*y0*y0;
const double dg_dk2 = y0;
// g(y0) = g(f(Qy)), y0 = f(Qy)
// g(Qy) = f(Qy)*(k1*x*f(Qy)^2 + k2)
// dg/df = 3*k1*x*f(Qy)^2 + k2
const double dg_df = 3.0*curr_params.k1*x_f64*y0*y0 + curr_params.k2;
// f(Qy) ~ 2^[ Qy*2^-23 - 127 + k0/(3*2^22) ]
// df/dk0 ~ ln(2)/(3*2^22) * f(Qy)
// NOTE(review): in the code Qy = M0 - k0 - Qx/2 DEcreases as k0 grows, so
// dy0/dk0 is negative (about -ln(2)*2^-23*y0), but df_dk0 below is positive.
// This term of the k0 step therefore appears to point uphill. Confirm before
// changing it.
const double df_dk0 = M1 * y0;
// chain rule
const double de_dk2 = de_dg*dg_dk2;
const double de_dk1 = de_dg*dg_dk1;
const double de_dk0 = de_dg*dg_df*df_dk0;
avg_de_dk[0] += de_dk0;
avg_de_dk[1] += de_dk1;
avg_de_dk[2] += de_dk2;
// Auxiliary loss pulling Qy towards the bit pattern of the exact result:
// e1 = 0.5*(Qy_target - Qy)^2
// e1 = 0.5*(Qy_target - M0 + k0 + Qx/2)^2
// de1/dk0 = (Qy_target - Qy)
// Subtracting 0.1*de1/dk0 adds a descent step on e1 to the k0 accumulator.
const double de1_dk0 = double(int64_t(Qy_target[i]) - Qy);
avg_de_dk[0] -= 1e-1*de1_dk0;
mean_absolute_error += std::abs(de_dg);
mean_absolute_error /= double(N);
// Track the best result of the current restart; any improvement resets the
// plateau counter.
if (mean_absolute_error < best_ver_result.mae) {
best_ver_result.mae = mean_absolute_error;
best_ver_result.version = curr_version;
best_ver_result.iter = curr_iter;
best_ver_result.params = curr_params;
best_ver_result.thread_id = thread_id;
total_plateau = 0;
} else {
// NOTE(review): total_plateau is never incremented in this copy, so the
// restart condition below can never become true as shown.
if (best_ver_result.mae < best_thread_result.mae) {
best_thread_result = best_ver_result;
bool is_reset = false;
if (total_plateau >= MAXIMUM_PLATEAU_RESTART) {
is_reset = true;
// Per-parameter learning rates, divided by N to average the summed gradients.
constexpr static double scale_de_dk[3] = {1e-3, 1e-6, 1e-6};
for (size_t i = 0; i < 3; i++) {
const double scale = scale_de_dk[i]/double(N);
avg_de_dk[i] *= scale;
curr_params.k0 += int64_t(avg_de_dk[0]);
curr_params.k1 += avg_de_dk[1];
curr_params.k2 += avg_de_dk[2];
// Progress print. It shows the scaled steps, not the raw gradients.
if (curr_iter % PRINT_ITER == 0 || is_reset) {
// NOTE(review): the `printf(` opener, the mae argument (presumably
// mean_absolute_error) and the closing `);` of this call are not visible here.
"[%02zu|%zu|%.3zu] mae=%.3e, de_dk[3]={%.2e,%.2e,%.2e}\n",
thread_id, curr_version, curr_iter,
avg_de_dk[0], avg_de_dk[1], avg_de_dk[2]
// Restart from fresh random parameters. curr_iter and curr_version are never
// incremented in this copy (NOTE(review)).
if (is_reset) {
total_plateau = 0;
curr_iter = 0;
curr_params = gen_rand_params();
best_ver_result = Result{};
// Merge this worker's best into the global best under the mutex. Its exact
// position relative to the loop cannot be confirmed from this copy.
auto lock = std::unique_lock<std::mutex>(best_result_mutex);
if (best_thread_result.mae < best_result.mae) {
best_result = best_thread_result;
// NOTE(review): the loop body (presumably thread.join()) is not visible here.
for (auto& thread: threads) {
printf("\n[BEST RESULT]\n");
printf("thread=%zu\n", best_result.thread_id);
printf("version=%zu\n", best_result.version);
printf("iter=%zu\n", best_result.iter);
printf("mae=%.6f\n", best_result.mae);
// k0_ = 2^22*381 - k0
// This is the magic number in the form used by quick_invsqrt (Qy ~ k0_ - 0.5*Qx).
const int64_t k0_ = M0 - best_result.params.k0;
printf("k0=%" PRIi64 " (0x%08X)\n", k0_, int32_t(k0_));
printf("k1=%.8f\n", best_result.params.k1);
printf("k2=%.8f\n", best_result.params.k2);
return 0;
