Skip to content

Instantly share code, notes, and snippets.

@cwfitzgerald
Last active October 20, 2016 02:10
Show Gist options
  • Save cwfitzgerald/3dd901efbdde151e3ba1738e8c1d0fda to your computer and use it in GitHub Desktop.
Save cwfitzgerald/3dd901efbdde151e3ba1738e8c1d0fda to your computer and use it in GitHub Desktop.
A power function using polynomials
#include <cinttypes>
#include <cmath>
#include <cstring>
#include <limits>
// Numeric constants shared by the approximation routines in this file.
namespace magic_numbers {
    // Polynomial coefficients used to approximate log(x); logc0 is the
    // constant term and logc7 multiplies the highest power.
    static constexpr float logc0 = -2.4204054330123117482f;
    static constexpr float logc1 = 5.8848619015611924602f;
    static constexpr float logc2 = -7.4051206397067798695f;
    static constexpr float logc3 = 6.8875017302144077082f;
    static constexpr float logc4 = -4.3148398949978662797f;
    static constexpr float logc5 = 1.7263161508290053959f;
    static constexpr float logc6 = -3.9884945965699334907e-1f;
    static constexpr float logc7 = 4.0535645552263665381e-2f;

    // Polynomial coefficients used to approximate 2 ^ x; pow2c0 is the
    // constant term and pow2c7 multiplies the highest power.
    static constexpr float pow2c0 = 9.999999895522431340595307249206601173160e-1f;
    static constexpr float pow2c1 = 6.931471741231077374964550442946479169896e-1f;
    static constexpr float pow2c2 = 2.402268410587209655053240869450290503856e-1f;
    static constexpr float pow2c3 = 5.550417941535986729965741215786385074913e-2f;
    static constexpr float pow2c4 = 9.616460628061844000348512769113294090944e-3f;
    static constexpr float pow2c5 = 1.333137403472538461063687233613364614430e-3f;
    static constexpr float pow2c6 = 1.566982822625258739107007062218641048893e-4f;
    static constexpr float pow2c7 = 1.550905805985674282385631387486764679087e-5f;

    // Bit masks selecting the fields of a 32-bit float
    static constexpr uint32_t sign_bit_mask = 0x80000000; // top bit
    static constexpr uint32_t exponent_mask = 0x7f800000; // next 8 bits (stored value is exponent + 127)
    static constexpr uint32_t mantissa_mask = 0x007fffff; // low 23 bits

    // Complements of the masks above (everything except the named field)
    static constexpr uint32_t sign_bit_not_mask = ~sign_bit_mask;
    static constexpr uint32_t exponent_not_mask = ~exponent_mask;
    static constexpr uint32_t mantissa_not_mask = ~mantissa_mask;

    // Exponent field encoding an exponent of 0 (the bit pattern of 1.0f)
    static constexpr uint32_t zero_exponent = 0x3f800000;

    // Precomputed mathematical values
    static constexpr float sqrt_2 = 1.4142135623730950488016887242097f;        // sqrt(2)
    static constexpr float log_sqrt_2 = 0.34657359027997265470861606072909f;   // log(sqrt(2))
    static constexpr float recip_sqrt_2 = 0.70710678118654752440084436210485f; // 1 / sqrt(2)
    static constexpr float log_2 = 0.69314718055994530941723212145818f;        // log(2)
    static constexpr float log2_e = 1.442695040888963407359924681001892f;      // log2(e)
    static constexpr float recip_log_2 = log2_e;                               // 1 / log(2)
    static constexpr float recip_log_10 = 0.43429448190325182765112891891661f; // 1 / log(10)

    // Special floating point values
    static constexpr float float_nan = std::numeric_limits<float>::quiet_NaN();
    static constexpr float float_inf = std::numeric_limits<float>::infinity();
}
using namespace magic_numbers;
// Overlays a float with 32-bit unsigned and signed integers so the same
// storage can be viewed as any of the three. Each constructor stores its
// argument into the member of the matching type.
union pun32 {
    float f;
    uint32_t i;
    int32_t si;

    pun32(float fi) : f{fi} {}
    pun32(uint32_t ii) : i{ii} {}
    pun32(int32_t sii) : si{sii} {}
};
// Type punning helper functions
// Returns the bit pattern of x reinterpreted as an unsigned 32-bit integer.
// The bytes are copied with std::memcpy: reading a union member other than
// the one last written is undefined behavior in C++, whereas memcpy is
// well defined and compiles down to a plain register move.
inline uint32_t as_uint(float x) {
    static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}
// Returns the bit pattern of x reinterpreted as a signed 32-bit integer.
// The bytes are copied with std::memcpy: reading a union member other than
// the one last written is undefined behavior in C++, whereas memcpy is
// well defined and compiles down to a plain register move.
inline int32_t as_int(float x) {
    static_assert(sizeof(int32_t) == sizeof(float), "float must be 32 bits wide");
    int32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}
// Returns the float whose bit pattern equals the signed 32-bit integer x.
// The bytes are copied with std::memcpy: reading a union member other than
// the one last written is undefined behavior in C++, whereas memcpy is
// well defined and compiles down to a plain register move.
inline float as_float(int32_t x) {
    static_assert(sizeof(int32_t) == sizeof(float), "float must be 32 bits wide");
    float value;
    std::memcpy(&value, &x, sizeof(value));
    return value;
}
// Returns the float whose bit pattern equals the unsigned 32-bit integer x.
// The bytes are copied with std::memcpy: reading a union member other than
// the one last written is undefined behavior in C++, whereas memcpy is
// well defined and compiles down to a plain register move.
inline float as_float(uint32_t x) {
    static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");
    float value;
    std::memcpy(&value, &x, sizeof(value));
    return value;
}
// Utility functions for floating point
// Reports whether x is a whole number, i.e. truncating it toward zero
// leaves the value unchanged. NaN never equals itself, so it yields false.
inline bool is_integer(float x) {
    const float whole_part = std::trunc(x);
    return whole_part == x;
}
// Reports whether x is a whole number that is not divisible by two.
inline bool is_odd_integer(float x) {
    if (!is_integer(x)) {
        return false;
    }
    // An integer is odd exactly when its half is no longer an integer
    const float half = x * 0.5f;
    return !is_integer(half);
}
// Strict equality test for floats.
// Two non-NaN values match only when their bit patterns match, so:
//   -inf != inf
//   -0   != 0
// Any NaN matches any other NaN regardless of sign or payload:
//   -Nan == Nan
inline bool is_identical(float a, float b) {
    const bool a_is_nan = (a != a);
    const bool b_is_nan = (b != b);
    if (a_is_nan || b_is_nan) {
        return a_is_nan && b_is_nan;
    }
    return as_int(a) == as_int(b);
}
// https://stackoverflow.com/questions/10732034/how-are-logarithms-programmed
// m = mantissa
// p = exponent
// log(x) = log(m)+p*log(2)
inline float clog(float val) {
#ifndef __FAST_MATH__
if (val < 0) {
return float_nan;
}
if (val == 0.0) {
return -float_inf;
}
if (val == float_inf || val != val) {
return val;
}
#endif
// Reinterpret the float as an int
uint32_t val_i = as_uint(val);
// Get the mantissa and exponent
int32_t orig_mantissa = val_i & mantissa_mask;
int32_t orig_exponent = val_i & exponent_mask;
// Add a zero exponent to the mantissa and convert back to float
float m = as_float(uint32_t(orig_mantissa) ^ zero_exponent);
// Shift exponent to the right and get actual value of the exponent
float p = static_cast<float>((orig_exponent >> 23) - 127);
// sqrt(2) is logarithmically between [1, 2) therefore
// if sqrt2 < m < 2: log(x / sqrt_2) + log(sqrt(2)) == log(x)
bool flip_flag = false;
if (m >= sqrt_2) {
flip_flag = true;
m = m * recip_sqrt_2;
}
// Calculate the logarithm of the mantissa using a Chebyshev polynomial in horner form
float logm;
logm = logc6 + m * logc7;
logm = logc5 + m * logm;
logm = logc4 + m * logm;
logm = logc3 + m * logm;
logm = logc2 + m * logm;
logm = logc1 + m * logm;
logm = logc0 + m * logm;
// Convert back from a the range reduced version
if (flip_flag) {
logm += log_sqrt_2;
}
// Return the logarithm
return logm + p * log_2;
}
// Base-2 logarithm: the natural logarithm from clog scaled by 1 / log(2).
inline float clog2(float val) {
    const float natural_log = clog(val);
    return natural_log * recip_log_2;
}
// Base-10 logarithm: the natural logarithm from clog scaled by 1 / log(10).
inline float clog10(float val) {
    const float natural_log = clog(val);
    return natural_log * recip_log_10;
}
// Approximates e ^ val.
// Background: https://stackoverflow.com/questions/15350856/which-method-to-implement-exp-function-in-c
//
// Uses e^x = 2^y with y = x * log2(e). y is split into an integer part,
// written directly into the exponent field of a float, and a fractional
// part, evaluated with a degree-7 polynomial for 2 ^ fraction.
//
// Special values (compiled out when __FAST_MATH__ is defined):
//   val == -inf -> 0
//   val == +inf -> +inf
//   val is NaN  -> NaN
// Results too large for a float return +inf; results too small return 0,
// passing through the subnormal range on the way down.
inline float cexp(float val) {
#ifndef __FAST_MATH__
    if (val == -float_inf) {
        return 0.0f;
    }
    if (val == float_inf || val != val) {
        return val;
    }
#endif
    const float y = val * log2_e;
    // Range checks are done on the float so the integer conversion below can
    // never be out of range (which would be undefined behavior).
    // 2^128 and above overflows a float.
    if (y >= 128.0f) {
        return float_inf;
    }
    // Below 2^-150 the result rounds to zero even as a subnormal.
    if (y < -150.0f) {
        return 0.0f;
    }
    // Separate integer and fractional parts (fraction has the sign of y)
    int32_t int_part = static_cast<int32_t>(std::trunc(y));
    const float dec_part = y - static_cast<float>(int_part);
    // The exponent field cannot encode powers below 2^-126. For those, build
    // the result 2^24 too large and scale it down at the end, letting the
    // final multiplication round into the subnormal range.
    float scale = 1.0f;
    if (int_part < -126) {
        int_part += 24;
        scale = 5.9604644775390625e-8f; // 2^-24
    }
    // Put the biased integer part into the exponent field; int_part is in
    // [-126, 127] here, so the biased value is in [1, 254]
    const float ret = as_float(static_cast<uint32_t>(int_part + 127) << 23);
    // Evaluate the polynomial for 2 ^ dec_part in Horner form
    float pow2d;
    pow2d = pow2c6 + dec_part * pow2c7;
    pow2d = pow2c5 + dec_part * pow2d;
    pow2d = pow2c4 + dec_part * pow2d;
    pow2d = pow2c3 + dec_part * pow2d;
    pow2d = pow2c2 + dec_part * pow2d;
    pow2d = pow2c1 + dec_part * pow2d;
    pow2d = pow2c0 + dec_part * pow2d;
    return ret * pow2d * scale;
}
// x ^ y == e ^ (y * log(x))
inline float cpow(float base, float exp) {
if (base == 1.0f || exp == 0.0f) {
return 1.0;
}
#ifndef __FAST_MATH__
if (is_identical(base, -0.0f) && is_odd_integer(exp)) {
if (exp < 0.0f) {
return -float_inf;
}
else {
return -0.0f;
}
}
if (exp == float_inf || exp == -float_inf) {
if (base == -1.0f) {
return 1.0f;
}
}
if (base == -float_inf) {
if (is_odd_integer(exp)) {
if (exp < 0.0f) {
return -0.0f;
}
else {
return -float_inf;
}
}
else {
if (exp < 0.0f) {
return 0.0f;
}
else {
return float_inf;
}
}
}
#endif
return cexp(exp * clog(base));
}
#include "pow.hpp"
#include <algorithm>
#include <chrono>
#include <cinttypes>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>
long double mantissa_compare(float a, float b) {
long double mantissa_a = as_float((as_int(a) & mantissa_mask) ^ zero_exponent);
long double mantissa_b = as_float((as_int(b) & mantissa_mask) ^ zero_exponent);
long double mantissa_diff = std::abs(mantissa_a - mantissa_b);
// The exponent is either exact or off by one, so this will keep things in a reasonable bounds
if (mantissa_diff > 0.5L) {
mantissa_diff = 1.0L - mantissa_diff;
}
return mantissa_diff;
}
// Returns the absolute difference between the exponent fields of a and b.
int16_t exponent_compare(float a, float b) {
    const auto field_a = (as_int(a) & exponent_mask) >> 23;
    const auto field_b = (as_int(b) & exponent_mask) >> 23;
    const int delta = static_cast<int>(field_a - field_b);
    return static_cast<int16_t>(std::abs(delta));
}
// Times std::log, clog, std::exp, cexp, std::pow and cpow (in that order)
// over `size` copies of the same input value and prints each elapsed time
// in milliseconds to std::cout.
void test_performance(size_t size) {
    using namespace std::chrono;
    using timer = high_resolution_clock;

    std::vector<float> input(size, 2.1412f);
    std::vector<float> output(size, 0.0f);

    // Applies fn to every input element, storing into output, and prints
    // how long the whole pass took under the given label
    auto run = [&](const char* name, auto&& fn) {
        const auto begin = timer::now();
        for (size_t i = 0; i < input.size(); ++i) {
            output[i] = fn(input[i]);
        }
        const auto finish = timer::now();
        const auto elapsed = duration_cast<microseconds>(finish - begin);
        std::cout << name << ": " << static_cast<float>(elapsed.count()) / 1'000.0f << "ms" << std::endl;
    };

    // Logarithm
    run("std::log", [](float v) { return std::log(v); });
    run(" clog", [](float v) { return clog(v); });

    // Exponential
    run("std::exp", [](float v) { return std::exp(v); });
    run(" cexp", [](float v) { return cexp(v); });

    // Power with a fixed base
    run("std::pow", [](float v) { return std::pow(15.0f, v); });
    run(" cpow", [](float v) { return cpow(15.0f, v); });
}
// Compares clog, cexp and cpow against std::log, std::exp and std::pow on
// `count` sample inputs each and prints error statistics to std::cout:
// how many results have a matching exponent field, and the min / 1st
// percentile / average / 99th percentile / max mantissa difference.
// Does nothing when count is 0, because every statistic below divides by
// count or indexes into a count-sized vector.
void test_accuracy(size_t count) {
    if (count == 0) {
        return;
    }
    std::vector<float> input(count, 0.0f);
    std::vector<float> my_out(count, 0.0f);
    std::vector<float> std_out(count, 0.0f);
    std::vector<long double> man_diff(count, 0.0L);
    std::vector<int16_t> exp_diff(count, 0);
    // Lambda to find error numbers (my function vs the standard function)
    // for the current contents of input / std_out / my_out and print them
    auto find_error = [&](const char* name) {
        for (size_t i = 0; i < count; ++i) {
            man_diff[i] = mantissa_compare(std_out[i], my_out[i]);
            exp_diff[i] = exponent_compare(std_out[i], my_out[i]);
        }
        // Find minimum and maximum mantissa error (I wanted to use structured bindings...but oh well)
        std::vector<long double>::iterator man_min_it, man_max_it;
        std::tie(man_min_it, man_max_it) = std::minmax_element(man_diff.begin(), man_diff.end());
        size_t man_min_index = std::distance(man_diff.begin(), man_min_it);
        size_t man_max_index = std::distance(man_diff.begin(), man_max_it);
        long double man_min = *man_min_it;
        long double man_max = *man_max_it;
        // 99th percentile mantissa error using nth element on copy of array
        size_t man_99th_pt_index = static_cast<size_t>(double(count) * 0.99);
        auto man_tmp = man_diff;
        std::nth_element(man_tmp.begin(), man_tmp.begin() + man_99th_pt_index, man_tmp.end());
        auto man_99th_pt = man_tmp[man_99th_pt_index];
        // 1st percentile mantissa error using nth element on copy of array
        size_t man_01st_pt_index = static_cast<size_t>(double(count) * 0.01);
        man_tmp = man_diff;
        std::nth_element(man_tmp.begin(), man_tmp.begin() + man_01st_pt_index, man_tmp.end());
        auto man_01st_pt = man_tmp[man_01st_pt_index];
        // Average mantissa error
        long double man_avg = std::accumulate(man_diff.begin(), man_diff.end(), 0.0L) / static_cast<long double>(count);
        // Count correct exponents and log every incorrect one as (index, difference)
        std::vector<std::pair<size_t, int16_t>> incorrect_exps;
        size_t exp_count = 0;
        for (size_t i = 0; i < count; ++i) {
            if (exp_diff[i] == 0) {
                ++exp_count;
            }
            else {
                incorrect_exps.emplace_back(i, exp_diff[i]);
            }
        }
        // Print function name
        std::cout << "Accuracy of " << name << ": \n";
        // Print count of exponents correct
        std::cout << "Exponents correct: " << exp_count << '\n';
        // Print if there are incorrect exponents
        if (incorrect_exps.size()) {
            std::cout << "Exponents incorrect: " << count - exp_count << '\n';
            // If there's a reasonable amount of errors, print out their index and the exponent difference,
            // and the inputs that correspond with the errors
            if (incorrect_exps.size() < 10) {
                std::cout << "Indexes incorrect: \n";
                for (auto&& val : incorrect_exps) {
                    std::cout << "#" << std::setw(static_cast<int>(std::log10(count))) << val.first << ": ±"
                              << std::setw(4) << val.second << std::setprecision(9) << ": " << name << "("
                              << input[val.first] << ") == std: " << std_out[val.first] << " == c" << name << ": "
                              << my_out[val.first] << "\n";
                }
                std::cout << '\n';
            }
        }
        // Print out stats about mantissa error
        std::cout << "Mantissa error max: " << man_max << " at position " << man_max_index << " with value "
                  << input[man_max_index] << '\n';
        std::cout << "Mantissa error 99th: " << man_99th_pt << '\n';
        std::cout << "Mantissa error avg: " << man_avg << '\n';
        std::cout << "Mantissa error 01st: " << man_01st_pt << '\n';
        std::cout << "Mantissa error min: " << man_min << " at position " << man_min_index << " with value "
                  << input[man_min_index] << '\n';
        std::cout << std::endl;
    };
    // Fills input with a running sum of (upper / count), starting at one step
    auto fill_input = [&](float upper) {
        float cur_val = 0;
        const float step = upper / static_cast<float>(count);
        std::generate(input.begin(), input.end(), [&] { return cur_val += step; });
    };
    ///////////////////
    // std::log/clog //
    ///////////////////
    fill_input(100'000.0f);
    for (size_t i = 0; i < count; ++i) {
        std_out[i] = std::log(input[i]);
        my_out[i] = clog(input[i]);
    }
    find_error("log");
    ///////////////////
    // std::exp/cexp //
    ///////////////////
    fill_input(70.0f);
    for (size_t i = 0; i < count; ++i) {
        std_out[i] = std::exp(input[i]);
        my_out[i] = cexp(input[i]);
    }
    find_error("exp");
    ///////////////////
    // std::pow/cpow //
    ///////////////////
    fill_input(70.0f);
    for (size_t i = 0; i < count; ++i) {
        std_out[i] = std::pow(2.5f, input[i]);
        my_out[i] = cpow(2.5f, input[i]);
    }
    find_error("pow");
}
// Prints one row of a compliance table: the case number, whether the
// custom result and the standard library result are each identical to the
// expected value, and then the three values themselves.
void print_test_data(size_t num, float glib, float my, float eq) {
    const bool my_matches = is_identical(my, eq);
    const bool glib_matches = is_identical(glib, eq);
    std::cout << std::boolalpha << std::setw(2) << num << ": "
              << std::setw(5) << my_matches << " | "
              << std::setw(5) << glib_matches << ": "
              << std::setw(5) << my << " | "
              << std::setw(5) << glib << " == " << eq << '\n';
}
// Functions to collect test data and forward it
// Runs std::log and clog on val and prints both against the expected value.
void test_log(size_t num, float val, float eq) {
    const float reference = std::log(val);
    const float candidate = clog(val);
    print_test_data(num, reference, candidate, eq);
}
// Runs std::exp and cexp on exp and prints both against the expected value.
void test_exp(size_t num, float exp, float eq) {
    const float reference = std::exp(exp);
    const float candidate = cexp(exp);
    print_test_data(num, reference, candidate, eq);
}
// Runs std::pow and cpow on (base, exp) and prints both against the
// expected value.
void test_pow(size_t num, float base, float exp, float eq) {
    const float reference = std::pow(base, exp);
    const float candidate = cpow(base, exp);
    print_test_data(num, reference, candidate, eq);
}
// Entry point: prints special-value compliance tables for log, exp and
// pow, then runs the speed benchmark and the accuracy comparison.
int main(int argc, char** argv) {
    (void) argc;
    (void) argv;

    // One row of a single-argument compliance table
    struct unary_case {
        size_t num;
        float arg;
        float expected;
    };
    // One row of the pow compliance table
    struct binary_case {
        size_t num;
        float base;
        float exp;
        float expected;
    };

    const unary_case log_cases[] = {
        {1, 0.0f, -float_inf},
        {1, -0.0f, -float_inf},
        {2, 1.0f, 0.0f},
        {3, -1.0f, float_nan},
        {4, float_inf, float_inf},
        {5, float_nan, float_nan},
    };
    const unary_case exp_cases[] = {
        {1, 0.0f, 1.0f},
        {1, -0.0f, 1.0f},
        {2, -float_inf, 0.0f},
        {3, float_inf, float_inf},
        {4, float_nan, float_nan},
    };
    const binary_case pow_cases[] = {
        {1, 0.0f, -3.0f, float_inf},
        {2, -0.0f, -3.0f, -float_inf},
        {3, 0.0f, -2.0f, float_inf},
        {3, -0.0f, -2.0f, float_inf},
        {4, 0.0f, -float_inf, float_inf},
        {4, -0.0f, -float_inf, float_inf},
        {5, 0.0f, 3.0f, 0.0f},
        {6, -0.0f, 3.0f, -0.0f},
        {7, 0.0f, 2.0f, 0.0f},
        {7, -0.0f, 2.0f, 0.0f},
        {8, -1.0f, float_inf, 1.0f},
        {8, -1.0f, -float_inf, 1.0f},
        {9, 1.0f, 3.0f, 1.0f},
        {9, 1.0f, float_nan, 1.0f},
        {10, 2.0f, 0.0f, 1.0f},
        {10, 2.0f, -0.0f, 1.0f},
        {10, float_nan, 0.0f, 1.0f},
        {10, float_nan, -0.0f, 1.0f},
        {11, -2.0f, 3.4f, float_nan},
        {12, 0.5f, -float_inf, float_inf},
        {13, 1.5f, -float_inf, 0.0f},
        {14, 0.5f, float_inf, 0.0f},
        {15, 1.5f, float_inf, float_inf},
        {16, -float_inf, -3.0f, -0.0f},
        {17, -float_inf, -2.0f, 0.0f},
        {18, -float_inf, 3.0f, -float_inf},
        {19, -float_inf, 2.0f, float_inf},
        {20, float_inf, -3.0f, 0.0f},
        {21, float_inf, 3.0f, float_inf},
    };

    std::cout << std::setprecision(10);

    std::cout << "Testing log IEEE compliance: \n"
              << " clog std clog std expected\n";
    for (const auto& c : log_cases) {
        test_log(c.num, c.arg, c.expected);
    }
    std::cout << '\n';

    std::cout << "Testing exp IEEE compliance: \n"
              << " cexp std cexp std expected\n";
    for (const auto& c : exp_cases) {
        test_exp(c.num, c.arg, c.expected);
    }
    std::cout << '\n';

    std::cout << "Testing pow IEEE compliance: \n"
              << " cpow std cpow std expected\n";
    for (const auto& c : pow_cases) {
        test_pow(c.num, c.base, c.exp, c.expected);
    }
    std::cout << std::endl;

    // Use a lower precision when printing times
    std::cout << std::setprecision(4);
    std::cout << "Testing speed on 1,000,000 floats\n";
    test_performance(1'000'000);

    std::cout << std::setprecision(10);
    std::cout << "\nTesting accuracy on 1'000'000 floats\n";
    test_accuracy(1'000'000);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment