Skip to content

Instantly share code, notes, and snippets.

@williamyang98
Last active June 5, 2024 18:27
Show Gist options
  • Save williamyang98/6766b7a8bc6bbecf7f7ece8fd6792d2a to your computer and use it in GitHub Desktop.
Save williamyang98/6766b7a8bc6bbecf7f7ece8fd6792d2a to your computer and use it in GitHub Desktop.
Fast inverse square root explanation
#include <cmath>
#include <stdint.h>
#include <stdio.h>
// | idx | mae      | description                          |
// |-----|----------|--------------------------------------|
// | 0   | 1.008427 | Naive with k0=0                      |
// | 1   | 0.144398 | Original Quake                       |
// | 2   | 0.099314 | Our grad descent (magic number only) |
// | 3   | 0.060105 | Jan Kadlec (wiki)                    |
// | 4   | 0.039234 | Our grad descent (three parameters)  |
#define IMPL_VERSION 4
#if IMPL_VERSION < 3
float quick_invsqrt(float x) {
    // Approximates 1/sqrt(x) for positive x using the "magic number" bit trick,
    // then refines the estimate with Newton's method.
    union {
        int32_t log_val;
        float f32_val;
    } x_pun;
    x_pun.f32_val = x;
    // IEEE-754 single precision (binary32) stores a float as three bit fields:
    // - field widths in bits: sign K: 1, exponent N: 8, mantissa M: 23
    // - for normal numbers the value interpreted as f32 is
    //   x = (-1)^K * 2^(N-127) * (1 + M*2^-23)
    //
    // Reading the same 32 bits as an integer (type punning through the union above)
    // gives the "quantised" value
    // - Q = K*2^31 + N*2^23 + M
    // y = 1/sqrt(x) requires x > 0, so assume K = 0:
    // - Q = N*2^23 + M
    //
    // Consider the expansion of log2(x) in terms of IEEE-754
    // x = 2^(N-127) * (1+M*2^-23)
    // - let n = N-127
    //       m = M*2^-23
    // x = 2^n * (1+m)
    // log2(x) = log2[2^n * (1+m)]
    // log2(x) = log2(2^n) + log2(1+m)
    // log2(x) = n + log2(1+m)
    //
    // Since m = M*2^-23
    // - 0 <= M <= 2^23-1
    // - 0 <= m < 1
    // - 1 <= 1+m < 2
    // - 0 <= log2(1+m) < 1
    //
    // We can approximate log2(1+m) ~ m + k0 (graphing it shows a close linear fit)
    // - k0 is a free parameter
    // - k0 = 0 fits the endpoints exactly: log2(1)=0, log2(2)=1
    // - a non-zero k0 can reduce the total error across the whole interval
    // log2(1+m) ~ m + k0
    //
    // Substituting the approximation for log2(1+m)
    // log2(x) = n + log2(1+m)
    // log2(x) ~ n + m + k0
    // - substitute n and m
    // log2(x) ~ N-127 + M*2^-23 + k0
    // log2(x)*2^23 ~ N*2^23 - 127*2^23 + k0*2^23 + M
    // - substitute the quantisation Q = N*2^23 + M
    // log2(x)*2^23 ~ Q - 127*2^23 + k0*2^23
    // Q ~ 2^23 * [ log2(x) + 127 - k0 ]
    //
    // Deriving an expression for y = 1/sqrt(x) in terms of log2
    // log2(y) = log2(x^-0.5)
    // log2(y) = -0.5*log2(x)
    //
    // Using the approximation of our quantisation Q
    // Qy ~ 2^23 * [ log2(y) + 127 - k0 ]
    // Qy ~ 2^23 * [ -0.5*log2(x) + 127 - k0 ]
    //
    // Qx ~ 2^23 * [ log2(x) + 127 - k0 ]
    // -0.5*Qx ~ 2^23 * [ -0.5*log2(x) - 63.5 + 0.5*k0 ]
    // -0.5*Qx ~ 2^23 * [ -0.5*log2(x) + 127 - k0 - 190.5 + 1.5*k0 ]
    // -0.5*Qx ~ 2^23 * [ -0.5*log2(x) + 127 - k0 ] - 2^23 * [ 190.5 - 1.5*k0 ]
    //
    // Substituting Qy
    // -0.5*Qx ~ Qy - 2^23 * [ 190.5 - 1.5*k0 ]
    // Qy ~ 2^23*(190.5 - 1.5*k0) - 0.5*Qx
    //
    // Let k1 = 2^23*(190.5 - 1.5*k0) be our new free parameter (LOG2_CONST below)
    // Qy ~ k1 - 0.5*Qx
#if IMPL_VERSION == 0
    // If k0=0 then k1 = 2^23 * 190.5 = 2^22 * 381
    constexpr static int32_t LOG2_CONST = 381 << 22; // Naive version
#elif IMPL_VERSION == 1
    constexpr static int32_t LOG2_CONST = 0x5F37'59DF; // Original Quake implementation
#elif IMPL_VERSION == 2
    constexpr static int32_t LOG2_CONST = 0x5F35093D; // Our gradient descent version
#else
#error "quick_invsqrt: this implementation only supports IMPL_VERSION 0, 1 or 2"
#endif
    // The arithmetic shift (log_val >> 1) computes 0.5*Qx for positive x.
    x_pun.log_val = LOG2_CONST - (x_pun.log_val >> 1); // Qy ~ k1 - 0.5*Qx
    float y_approx = x_pun.f32_val;
    // Refine with Newton's method on f(y) = 1/y^2 - x' = (1 - x'*y^2)/y^2
    // - x' = the original input x
    // - f(y) = 0 gives y = 1/sqrt(x'); y_approx is the initial guess
    // f'(y) = -2*y^-3 = -2/y^3
    // y1 = y0 - f(y0)/f'(y0)        | y1 = next approximation
    // - f(y)/f'(y) = -0.5*y*[1 - x'*y^2]
    // y1 = y0 + 0.5*y0*[1 - x'*y0^2]
    // y1 = 1.5*y0 - 0.5*x'*y0^3
    // y1 = y0*[1.5 - 0.5*x'*y0^2]
    const float x_half = x*0.5f;
    constexpr static size_t TOTAL_ITERATIONS = 1;
    for (size_t i = 0; i < TOTAL_ITERATIONS; i++) {
        y_approx = y_approx*(1.5f - x_half*y_approx*y_approx);
    }
    return y_approx;
}
#elif IMPL_VERSION == 3
// Source: https://en.wikipedia.org/wiki/Fast_inverse_square_root#Subsequent_improvements
// Jan Kadlec's version
float quick_invsqrt(float x) {
    // Reinterpret the float's bits as an integer. This uses union type punning,
    // which ISO C++ leaves undefined but GCC and Clang support.
    union {
        float as_float;
        int32_t as_bits;
    } pun = { x };
    // Initial estimate: Kadlec's magic constant minus half of the bit pattern.
    pun.as_bits = 0x5F1F'FFF9 - (pun.as_bits >> 1);
    const float y0 = pun.as_float;
    // A single Newton-style step with retuned coefficients.
    return y0 * (0.703952253f * (2.38924456f - x*y0*y0));
}
#elif IMPL_VERSION == 4
// Our version with three free parameters based on Jan Kadlec's reformulation
float quick_invsqrt(float x) {
    // Three fitted parameters (k0 magic constant, k1/k2 refinement coefficients):
    //   Qx = quant(x)              -- bit pattern of x as int32
    //   Qy ~ k0 - 0.5*Qx
    //   y0 = dequant(Qy)
    //   y1 = y0*(k1*x*y0^2 + k2)   -- generalised Newton step
    constexpr int32_t k0 = 0x5EDA97E8;
    constexpr float k1 = -2.13202330;
    constexpr float k2 = 2.43318741;
    // Union type punning (an extension supported by GCC and Clang).
    union {
        float as_float;
        int32_t as_bits;
    } pun = { x };
    pun.as_bits = k0 - (pun.as_bits >> 1);
    const float y0 = pun.as_float;
    return y0*(k1*x*y0*y0 + k2);
}
#endif
// clang++ main.cpp -o main -O3
int main(int argc, char** argv) {
    // Benchmarks quick_invsqrt against 1/std::sqrt over x in [1e-7, 1e8] and
    // prints the mean absolute error and the number of samples.
    //
    // The sum is kept in double. y = 1/sqrt(x) spans ~3162 down to 1e-4, so the
    // running sum reaches hundreds while the errors for large x are tiny. A float
    // accumulator silently rounds those small contributions away.
    double sum_absolute_error = 0.0;
    size_t total_samples = 0;
    const auto push_sample = [&](float x) {
        const float y_target = 1.0f/std::sqrt(x);
        const float y_pred = quick_invsqrt(x);
        const float error = std::abs(y_target-y_pred);
        sum_absolute_error += double(error);
        total_samples++;
    };
    // One row per decade: {start, end, step}. Each row expands to exactly
    // `for (float x = start; x <= end; x += step)`, including the float loop
    // variable, so the sample set is unchanged.
    struct SampleRange { double start; double end; double step; };
    constexpr SampleRange SAMPLE_RANGES[] = {
        {1e-7, 1e-6, 1e-9},
        {1e-6, 1e-5, 1e-8},
        {1e-5, 1e-4, 1e-7},
        {1e-4, 1e-3, 1e-6},
        {1e-3, 1e-2, 1e-5},
        {1e-2, 1e-1, 1e-4},
        {1e-1, 1e0, 1e-3},
        {1e0, 1e1, 1e-2},
        {1e1, 1e2, 1e-1},
        {1e2, 1e3, 1e0},
        {1e3, 1e4, 1e1},
        {1e4, 1e5, 1e2},
        {1e5, 1e6, 1e3},
        {1e6, 1e7, 1e4},
        {1e7, 1e8, 1e5},
    };
    for (const SampleRange& range: SAMPLE_RANGES) {
        for (float x = float(range.start); x <= range.end; x += range.step) push_sample(x);
    }
    const double mean_absolute_error = sum_absolute_error / double(total_samples);
    printf("mae=%.6f\n", mean_absolute_error);
    printf("total_samples=%zu\n", total_samples);
    return 0;
}
# Build every program in this gist with the same clang++ flags.
for prog in \
    benchmark_fast_inverse_square_root \
    derive_fast_inverse_square_root_magic_number \
    derive_fast_inverse_square_root_three_params
do
    clang++ "${prog}.cpp" -o "${prog}.exe" -O3 -std=c++17
done
#include <cmath>
#include <stdint.h>
#include <stdio.h>
#include <inttypes.h>
#include <vector>
static inline int32_t quantise_float(float x) {
    // Returns the IEEE-754 bit pattern of x reinterpreted as a signed 32-bit integer.
    // This uses union type punning, an extension supported by GCC and Clang.
    union {
        float as_f32;
        int32_t as_i32;
    } bits = { x };
    return bits.as_i32;
}
// clang++ main.cpp -o main -std=c++17 -O3 -march=native
int main(int argc, char** argv) {
    // Fits the magic constant of the fast inverse square root by gradient descent.
    //
    // Start from the approximation (from log2(1+m) ~ m + k0):
    //   Qy ~ 2^23*(190.5 - 1.5*k0) - 0.5*Qx
    // Doubling both sides keeps every term an exact integer:
    //   2*Qy ~ 2^23*381 - 2^23*3*k0 - Qx
    //   Jy   ~ 2^23*381 - k1 - Qx,      with Jy = 2*Qy and k1 = 2^23*3*k0
    // k1 is then chosen to minimise the squared error of Jy over the sample set.
    constexpr int64_t M0 = int64_t(381) << 23;
    std::vector<int64_t> QX;
    std::vector<int64_t> JY_target;
    const auto push_sample = [&](float x) {
        const float y = 1.0/std::sqrt(x);
        const int32_t Qx = quantise_float(x);
        const int32_t Qy = quantise_float(y);
        const int64_t Jy = int64_t(Qy)*2;
        QX.push_back(int64_t(Qx));
        JY_target.push_back(Jy);
    };
    // One row per decade: {start, end, step}. Each row expands to exactly
    // `for (float x = start; x <= end; x += step)`, including the float loop
    // variable, so the sample set is unchanged.
    struct SampleRange { double start; double end; double step; };
    constexpr SampleRange SAMPLE_RANGES[] = {
        {1e-7, 1e-6, 1e-9},
        {1e-6, 1e-5, 1e-8},
        {1e-5, 1e-4, 1e-7},
        {1e-4, 1e-3, 1e-6},
        {1e-3, 1e-2, 1e-5},
        {1e-2, 1e-1, 1e-4},
        {1e-1, 1e0, 1e-3},
        {1e0, 1e1, 1e-2},
        {1e1, 1e2, 1e-1},
        {1e2, 1e3, 1e0},
        {1e3, 1e4, 1e1},
        {1e4, 1e5, 1e2},
        {1e5, 1e6, 1e3},
        {1e6, 1e7, 1e4},
        {1e7, 1e8, 1e5},
    };
    for (const SampleRange& range: SAMPLE_RANGES) {
        for (float x = float(range.start); x <= range.end; x += range.step) push_sample(x);
    }
    const size_t N = JY_target.size();
    printf("training over %zu samples\n", N);
    int64_t k1 = 0;
    int64_t best_k1 = k1;
    uint64_t best_mse = ~uint64_t(0);
    size_t best_iter = 0;
    constexpr size_t TOTAL_ITERATIONS = 1024;
    constexpr size_t PRINT_ITER = TOTAL_ITERATIONS / 32;
    for (size_t iter = 0; iter < TOTAL_ITERATIONS; iter++) {
        int64_t mean_error = 0;
        uint64_t mean_square_error = 0;
        for (size_t i = 0; i < N; i++) {
            // Jy_pred = 2^23*381 - k1 - Qx
            // e = 0.5*(Jy_pred - Jy_target)^2
            // de/dk1 = -(Jy_pred - Jy_target) = -error
            // A descent step on k1 therefore adds the mean error (done below).
            const int64_t Qx = QX[i];
            const int64_t Jy_target = JY_target[i];
            const int64_t Jy_pred = M0 - k1 - Qx;
            const int64_t error = Jy_pred - Jy_target;
            mean_error += error;
            mean_square_error += uint64_t(error*error);
        }
        mean_error /= int64_t(N);
        mean_square_error /= uint64_t(N);
        if (mean_square_error < best_mse) {
            best_mse = mean_square_error;
            best_k1 = k1;
            best_iter = iter;
        }
        const int64_t gradient = mean_error;
        k1 += gradient;
        if (iter % PRINT_ITER == 0 || mean_error == 0) {
            printf("[%.3zu] mse=%" PRIu64 ", mean_error=%" PRIi64 "\n", iter, mean_square_error, mean_error);
        }
        // Integer truncation makes mean_error exactly 0 once |sum of errors| < N.
        if (mean_error == 0) {
            break;
        }
    }
    printf("\n[BEST RESULTS]\n");
    printf("iter=%zu\n", best_iter);
    printf("mse=%" PRIu64 "\n", best_mse);
    // %X requires an unsigned argument; print the low 32 bits as uint32_t.
    printf("k1=%" PRIi64 " (0x%08" PRIX32 ")\n", best_k1, uint32_t(best_k1));
    // k1 = 2^23 * 3 * k0
    const double k0 = double(best_k1)/double(int64_t(3) << 23);
    printf("k0=%.8f\n", k0);
    // k2 = (2^23*381 - k1) / 2, the magic constant used as Qy = k2 - 0.5*Qx
    const int32_t k2 = int32_t((M0 - best_k1)/2);
    printf("k2=%" PRIi32 " (0x%08" PRIX32 ")\n", k2, uint32_t(k2));
    return 0;
}
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <cmath>
#include <mutex>
#include <random>
#include <thread>
#include <vector>
#define USE_SIGNAL_HANDLER 1
// Run flag polled by the worker threads in main(); cleared to request shutdown.
// It is declared outside `#if USE_SIGNAL_HANDLER` because main() reads it
// unconditionally. Previously, setting USE_SIGNAL_HANDLER to 0 left it undeclared.
// std::atomic<bool> replaces `volatile bool`: the flag is written by the handler and
// read concurrently by several threads, and volatile gives no inter-thread ordering.
// A lock-free atomic is also permitted inside a signal handler (checked below).
static std::atomic<bool> is_running{true};
static_assert(std::atomic<bool>::is_always_lock_free,
              "is_running must be lock-free to be written from a signal handler");
#if USE_SIGNAL_HANDLER
#if _WIN32
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
// Console control handler: on Ctrl+C, request shutdown and report the event as handled.
BOOL WINAPI sighandler(DWORD signum) {
    if (signum == CTRL_C_EVENT) {
        fprintf(stderr, "Signal caught, exiting!\n");
        is_running = false;
        return TRUE;
    }
    return FALSE;
}
#else
#include <errno.h>
#include <signal.h>
// POSIX signal handler: report the signal and request shutdown.
// errno is saved and restored so the interrupted code never sees it clobbered.
// NOTE(review): fprintf is not on POSIX's async-signal-safe list; a write(2)-based
// message would be strictly correct.
static void sighandler(int signum) {
    const int saved_errno = errno;
    fprintf(stderr, "Signal caught, exiting! (%d)\n", signum);
    is_running = false;
    errno = saved_errno;
}
#endif
#endif
static inline int32_t quantise_float(float x) {
    // Returns the IEEE-754 bit pattern of x reinterpreted as a signed 32-bit integer.
    // This uses union type punning, an extension supported by GCC and Clang.
    union {
        float as_f32;
        int32_t as_i32;
    } bits = { x };
    return bits.as_i32;
}
static inline float dequantise_i32(int32_t x) {
    // Inverse of quantise_float: reinterprets a 32-bit pattern as an IEEE-754 float.
    // This uses union type punning, an extension supported by GCC and Clang.
    union {
        int32_t as_i32;
        float as_f32;
    } bits = { x };
    return bits.as_f32;
}
// clang++ main.cpp -o main -std=c++17 -O3 -march=native
int main(int argc, char** argv) {
    // Searches for the three parameters (k0, k1, k2) of
    //   Qy = 2^22*381 - k0 - 0.5*Qx,  y0 = dequant(Qy),  y1 = y0*(k1*x*y0^2 + k2)
    // It runs gradient descent with random restarts on every hardware thread until
    // interrupted, then prints the parameters with the lowest mean absolute error.
#if USE_SIGNAL_HANDLER
#if _WIN32
    SetConsoleCtrlHandler(sighandler, TRUE);
#else
    // Value-initialise so that any fields not assigned below are zero rather than
    // indeterminate.
    struct sigaction sigact{};
    sigact.sa_handler = sighandler;
    sigemptyset(&sigact.sa_mask);
    sigact.sa_flags = 0;
    sigaction(SIGINT, &sigact, nullptr);
    sigaction(SIGTERM, &sigact, nullptr);
    sigaction(SIGQUIT, &sigact, nullptr);
    sigaction(SIGPIPE, &sigact, nullptr);
#endif
#endif
    std::vector<float> X_in;
    std::vector<float> Y_target;
    std::vector<int32_t> Qy_target;
    const auto push_sample = [&](float x) {
        const float y = 1.0/std::sqrt(x);
        X_in.push_back(x);
        Y_target.push_back(y);
        Qy_target.push_back(quantise_float(y));
    };
    // One row per decade: {start, end, step}. Each row expands to exactly
    // `for (float x = start; x <= end; x += step)`, including the float loop
    // variable, so the sample set is unchanged.
    struct SampleRange { double start; double end; double step; };
    constexpr SampleRange SAMPLE_RANGES[] = {
        {1e-7, 1e-6, 1e-9},
        {1e-6, 1e-5, 1e-8},
        {1e-5, 1e-4, 1e-7},
        {1e-4, 1e-3, 1e-6},
        {1e-3, 1e-2, 1e-5},
        {1e-2, 1e-1, 1e-4},
        {1e-1, 1e0, 1e-3},
        {1e0, 1e1, 1e-2},
        {1e1, 1e2, 1e-1},
        {1e2, 1e3, 1e0},
        {1e3, 1e4, 1e1},
        {1e4, 1e5, 1e2},
        {1e5, 1e6, 1e3},
        {1e6, 1e7, 1e4},
        {1e7, 1e8, 1e5},
    };
    for (const SampleRange& range: SAMPLE_RANGES) {
        for (float x = float(range.start); x <= range.end; x += range.step) push_sample(x);
    }
    const size_t N = X_in.size();
    printf("training over %zu samples\n", N);
    // Solve for the following approximation
    // Qy ~ 2^22*381 - k0 - 0.5*Qx, k0' = 2^22*381 - k0
    // y0 = dequant(Qy)
    // y1 = y0*(k1*x*y0^2 + k2) | Custom Newton's method
    constexpr int64_t M0 = int64_t(381) << 22;
    // const double M1 = std::log(double(2.0))/double(int32_t(3) << 22);
    const double M1 = 1.0;
    struct Params {
        int64_t k0 = 0;
        double k1 = 0.0;
        double k2 = 0.0;
    };
    // Draws a random starting point from the caller's own engine. Engines are
    // per-thread: std::mt19937 and its distributions are not safe to share across
    // threads without a lock (the previous shared engine was a data race).
    const auto gen_rand_params = [](std::mt19937& rng_gen) -> Params {
        std::uniform_real_distribution<double> rng_f32(0.0, 1.0);
        Params rand_params;
        rand_params.k0 = int64_t(rng_f32(rng_gen)*1e8);
        rand_params.k1 = rng_f32(rng_gen)*-5.0;
        rand_params.k2 = rng_f32(rng_gen)*5.0;
        // Based on Jan Kadlec's version
        // rand_params.k0 = M0 - 0x5F1F'FFF9;
        // rand_params.k1 = -0.703952253;
        // rand_params.k2 = 0.703952253 * 2.38924456;
        return rand_params;
    };
    struct Result {
        Params params;
        double mae = double(~uint64_t(0));
        size_t version = 0;
        size_t iter = 0;
        size_t thread_id = 0;
    };
    Result best_result;
    std::mutex best_result_mutex;
    constexpr size_t PRINT_ITER = 20'000;
    constexpr size_t MAXIMUM_PLATEAU_RESTART = 1'000;
    std::vector<std::thread> threads;
    // hardware_concurrency() returns 0 when the count is unknown; always run at
    // least one worker so that best_result actually gets populated.
    const size_t hw_threads = std::thread::hardware_concurrency();
    const size_t TOTAL_THREADS = (hw_threads > 0) ? hw_threads : 1;
    // const size_t TOTAL_THREADS = 1;
    threads.reserve(TOTAL_THREADS);
    std::random_device rng_dev;
    for (size_t thread_id = 0; thread_id < TOTAL_THREADS; thread_id++) {
        // Seeds are drawn here on the main thread so random_device is never shared.
        const unsigned int seed = rng_dev();
        threads.emplace_back([&, thread_id, seed]() {
            std::mt19937 rng_gen(seed);
            Result best_thread_result;
            Result best_ver_result;
            size_t total_plateau = 0;
            size_t curr_version = 0;
            size_t curr_iter = 0;
            Params curr_params;
            curr_params = gen_rand_params(rng_gen);
            while (is_running) {
                double mean_absolute_error = 0;
                double avg_de_dk[3] = {0};
                for (size_t i = 0; i < N; i++) {
                    const float x_f32 = X_in[i];
                    const float y_f32 = Y_target[i];
                    // Qx ~ 2^23 * [log2(x) + 127 - k0']
                    const int32_t Qx = quantise_float(x_f32);
                    // Qy ~ 2^23*(190.5 - 1.5*k0') - 0.5*Qx
                    // Qy ~ 2^22*(381 - 3*k0') - 0.5*Qx
                    // Qy ~ 2^22*381 - k0 - 0.5*Qx, where k0 = 2^22*3*k0'
                    const int64_t Qy = M0 - curr_params.k0 - (Qx >> 1);
                    // Qy ~ 2^23 * [log2(y) + 127 - k0']
                    // log2(y) ~ Qy*2^-23 - 127 + k0'
                    // y = f(Qy) ~ 2^[ Qy*2^-23 - 127 + k0/(3*2^22) ]
                    const double y0 = double(dequantise_i32(int32_t(Qy)));
                    // y1 = g(y0) = y0*(k1*x*y0^2 + k2)
                    const double x_f64 = double(x_f32);
                    const double y1 = y0*(curr_params.k1*x_f64*y0*y0 + curr_params.k2);
                    // e = 0.5*(y_target - y1)^2, so de/dg = -(y_target - g(y0)).
                    // de_dg stores the NEGATED derivative (y_target - y1). The
                    // accumulated terms are therefore descent directions and are
                    // added to the parameters below.
                    const double y_f64 = double(y_f32);
                    const double de_dg = y_f64 - y1;
                    // g(y0) = y0*(k1*x*y0^2 + k2)
                    // dg/dk1 = x*y0^3
                    // dg/dk2 = y0
                    const double dg_dk1 = x_f64*y0*y0*y0;
                    const double dg_dk2 = y0;
                    // g(y0) = g(f(Qy)), y0 = f(Qy)
                    // g(Qy) = f(Qy)*(k1*x*f(Qy)^2 + k2)
                    // dg/df = 3*k1*x*f(Qy)^2 + k2
                    const double dg_df = 3.0*curr_params.k1*x_f64*y0*y0 + curr_params.k2;
                    // f(Qy) ~ 2^[ Qy*2^-23 - 127 + k0/(3*2^22) ]
                    // df/dk0 ~ ln(2)/(3*2^22) * f(Qy); the constant factor is replaced by M1.
                    // NOTE(review): in this code k0 enters through Qy = M0 - k0 - Qx/2, so
                    // y0 decreases as k0 grows (dQy/dk0 = -1). Combined with de_dg holding
                    // the negated derivative, de_dk0 may point uphill for k0; the Qy-fit
                    // term below and the random restarts may be what keeps the search
                    // converging. Confirm the intended sign before relying on it.
                    const double df_dk0 = M1 * y0;
                    // chain rule
                    const double de_dk2 = de_dg*dg_dk2;
                    const double de_dk1 = de_dg*dg_dk1;
                    const double de_dk0 = de_dg*dg_df*df_dk0;
                    avg_de_dk[0] += de_dk0;
                    avg_de_dk[1] += de_dk1;
                    avg_de_dk[2] += de_dk2;
                    // Auxiliary loss pulling Qy toward the bit pattern of the target:
                    // e1 = 0.5*(Qy_target - Qy)^2, with Qy = M0 - k0 - Qx/2
                    // de1/dk0 = (Qy_target - Qy); subtracted (descent) with weight 0.1
                    const double de1_dk0 = double(int64_t(Qy_target[i]) - Qy);
                    avg_de_dk[0] -= 1e-1*de1_dk0;
                    mean_absolute_error += std::abs(de_dg);
                }
                mean_absolute_error /= double(N);
                if (mean_absolute_error < best_ver_result.mae) {
                    best_ver_result.mae = mean_absolute_error;
                    best_ver_result.version = curr_version;
                    best_ver_result.iter = curr_iter;
                    best_ver_result.params = curr_params;
                    best_ver_result.thread_id = thread_id;
                    total_plateau = 0;
                } else {
                    total_plateau++;
                }
                if (best_ver_result.mae < best_thread_result.mae) {
                    best_thread_result = best_ver_result;
                }
                // Restart from a fresh random point after too many non-improving steps.
                const bool is_reset = (total_plateau >= MAXIMUM_PLATEAU_RESTART);
                constexpr static double scale_de_dk[3] = {1e-3, 1e-6, 1e-6};
                for (size_t i = 0; i < 3; i++) {
                    const double scale = scale_de_dk[i]/double(N);
                    avg_de_dk[i] *= scale;
                }
                curr_params.k0 += int64_t(avg_de_dk[0]);
                curr_params.k1 += avg_de_dk[1];
                curr_params.k2 += avg_de_dk[2];
                if (curr_iter % PRINT_ITER == 0 || is_reset) {
                    printf(
                        "[%02zu|%zu|%.3zu] mae=%.3e, de_dk[3]={%.2e,%.2e,%.2e}\n",
                        thread_id, curr_version, curr_iter,
                        mean_absolute_error,
                        avg_de_dk[0], avg_de_dk[1], avg_de_dk[2]
                    );
                }
                if (is_reset) {
                    total_plateau = 0;
                    curr_version++;
                    curr_iter = 0;
                    curr_params = gen_rand_params(rng_gen);
                    best_ver_result = Result{};
                    continue;
                }
                curr_iter++;
            }
            {
                const std::lock_guard<std::mutex> lock(best_result_mutex);
                if (best_thread_result.mae < best_result.mae) {
                    best_result = best_thread_result;
                }
            }
        });
    }
    for (auto& thread: threads) {
        thread.join();
    }
    printf("\n[BEST RESULT]\n");
    printf("thread=%zu\n", best_result.thread_id);
    printf("version=%zu\n", best_result.version);
    printf("iter=%zu\n", best_result.iter);
    printf("mae=%.6f\n", best_result.mae);
    // k0_ = 2^22*381 - k0
    const int64_t k0_ = M0 - best_result.params.k0;
    // %X requires an unsigned argument; print the low 32 bits as uint32_t.
    printf("k0=%" PRIi64 " (0x%08" PRIX32 ")\n", k0_, uint32_t(k0_));
    printf("k1=%.8f\n", best_result.params.k1);
    printf("k2=%.8f\n", best_result.params.k2);
    return 0;
}
@williamyang98
Copy link
Author

  • Run build.sh to compile the programs with clang.
  • For the lowest mean absolute error, use the implementation based on our own derived three-parameter solution (IMPL_VERSION 4).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment