Last active
October 20, 2016 02:10
-
-
Save cwfitzgerald/3dd901efbdde151e3ba1738e8c1d0fda to your computer and use it in GitHub Desktop.
A power function using polynomials
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cinttypes>
#include <cmath>
#include <cstring>
#include <limits>
// Numeric constants shared by the approximation routines in this header
namespace magic_numbers {
    // log(x) polynomial coefficients; logc0 is the constant term, logc7 multiplies x^7
    static constexpr float logc0 = -2.4204054330123117482f;
    static constexpr float logc1 = 5.8848619015611924602f;
    static constexpr float logc2 = -7.4051206397067798695f;
    static constexpr float logc3 = 6.8875017302144077082f;
    static constexpr float logc4 = -4.3148398949978662797f;
    static constexpr float logc5 = 1.7263161508290053959f;
    static constexpr float logc6 = -3.9884945965699334907e-1f;
    static constexpr float logc7 = 4.0535645552263665381e-2f;

    // 2^x polynomial coefficients; pow2c0 is the constant term, pow2c7 multiplies x^7
    static constexpr float pow2c0 = 9.999999895522431340595307249206601173160e-1f;
    static constexpr float pow2c1 = 6.931471741231077374964550442946479169896e-1f;
    static constexpr float pow2c2 = 2.402268410587209655053240869450290503856e-1f;
    static constexpr float pow2c3 = 5.550417941535986729965741215786385074913e-2f;
    static constexpr float pow2c4 = 9.616460628061844000348512769113294090944e-3f;
    static constexpr float pow2c5 = 1.333137403472538461063687233613364614430e-3f;
    static constexpr float pow2c6 = 1.566982822625258739107007062218641048893e-4f;
    static constexpr float pow2c7 = 1.550905805985674282385631387486764679087e-5f;

    // Bit masks selecting the fields of a 32-bit float (bit 0 = least significant)
    static constexpr uint32_t sign_bit_mask = 0x80000000; // bit 31
    static constexpr uint32_t exponent_mask = 0x7f800000; // bits 23-30 (stored value is exponent + 127)
    static constexpr uint32_t mantissa_mask = 0x007fffff; // bits 0-22
    // Complements of the masks above: everything except the named field
    static constexpr uint32_t sign_bit_not_mask = ~sign_bit_mask;
    static constexpr uint32_t exponent_not_mask = ~exponent_mask;
    static constexpr uint32_t mantissa_not_mask = ~mantissa_mask;

    // Exponent field encoding an exponent of 0 (stored value 127, shifted into place)
    static constexpr uint32_t zero_exponent = 0x3f800000;

    // Precomputed values of common expressions
    static constexpr float sqrt_2       = 1.4142135623730950488016887242097f;   // sqrt(2)
    static constexpr float log_sqrt_2   = 0.34657359027997265470861606072909f;  // log(sqrt(2))
    static constexpr float recip_sqrt_2 = 0.70710678118654752440084436210485f;  // 1 / sqrt(2)
    static constexpr float log_2        = 0.69314718055994530941723212145818f;  // log(2)
    static constexpr float log2_e       = 1.442695040888963407359924681001892f; // log2(e)
    static constexpr float recip_log_2  = log2_e;                               // 1 / log(2)
    static constexpr float recip_log_10 = 0.43429448190325182765112891891661f;  // 1 / log(10)

    // Special floating point values
    static constexpr float float_nan = std::numeric_limits<float>::quiet_NaN();
    static constexpr float float_inf = std::numeric_limits<float>::infinity();
}
using namespace magic_numbers; | |
// Union overlaying a float, an unsigned 32-bit integer and a signed 32-bit integer.
// Each constructor initialises the member that matches its argument type.
union pun32 {
    int32_t si;
    uint32_t i;
    float f;

    pun32(int32_t sii) : si(sii) {}
    pun32(uint32_t ii) : i(ii) {}
    pun32(float fi) : f(fi) {}
};
// Type punning helper functions | |
// Returns the object representation of x reinterpreted as an unsigned 32-bit integer.
// The bytes are copied with std::memcpy: reading a union member other than the one
// last written is undefined behaviour in C++, whereas memcpy is well defined.
inline uint32_t as_uint(float x) {
    static_assert(sizeof(uint32_t) == sizeof(float), "float must be exactly 32 bits wide");
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}
// Returns the object representation of x reinterpreted as a signed 32-bit integer.
// The bytes are copied with std::memcpy: reading a union member other than the one
// last written is undefined behaviour in C++, whereas memcpy is well defined.
inline int32_t as_int(float x) {
    static_assert(sizeof(int32_t) == sizeof(float), "float must be exactly 32 bits wide");
    int32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}
// Returns the float whose object representation equals the bits of the signed integer x.
// The bytes are copied with std::memcpy: reading a union member other than the one
// last written is undefined behaviour in C++, whereas memcpy is well defined.
inline float as_float(int32_t x) {
    static_assert(sizeof(int32_t) == sizeof(float), "float must be exactly 32 bits wide");
    float value;
    std::memcpy(&value, &x, sizeof(value));
    return value;
}
// Returns the float whose object representation equals the bits of the unsigned integer x.
// The bytes are copied with std::memcpy: reading a union member other than the one
// last written is undefined behaviour in C++, whereas memcpy is well defined.
inline float as_float(uint32_t x) {
    static_assert(sizeof(uint32_t) == sizeof(float), "float must be exactly 32 bits wide");
    float value;
    std::memcpy(&value, &x, sizeof(value));
    return value;
}
// Utility functions for floating point | |
// True when x has no fractional part.
// NaN gives false (it never compares equal to anything); +/-infinity gives true
// because truncation leaves an infinity unchanged.
inline bool is_integer(float x) {
    const float whole_part = std::trunc(x);
    return whole_part == x;
}
// Checks if number is an odd integer | |
inline bool is_odd_integer(float x) { | |
return is_integer(x) && !is_integer(x * 0.5f); | |
} | |
// Check if two numbers are identical | |
// Behavior: | |
// -inf != inf | |
// -0 != 0 | |
// -Nan == Nan | |
inline bool is_identical(float a, float b) { | |
return (a == a && b == b) ? (as_int(a) == as_int(b)) : (a != a && b != b); | |
} | |
// https://stackoverflow.com/questions/10732034/how-are-logarithms-programmed | |
// m = mantissa | |
// p = exponent | |
// log(x) = log(m)+p*log(2) | |
inline float clog(float val) { | |
#ifndef __FAST_MATH__ | |
if (val < 0) { | |
return float_nan; | |
} | |
if (val == 0.0) { | |
return -float_inf; | |
} | |
if (val == float_inf || val != val) { | |
return val; | |
} | |
#endif | |
// Reinterpret the float as an int | |
uint32_t val_i = as_uint(val); | |
// Get the mantissa and exponent | |
int32_t orig_mantissa = val_i & mantissa_mask; | |
int32_t orig_exponent = val_i & exponent_mask; | |
// Add a zero exponent to the mantissa and convert back to float | |
float m = as_float(uint32_t(orig_mantissa) ^ zero_exponent); | |
// Shift exponent to the right and get actual value of the exponent | |
float p = static_cast<float>((orig_exponent >> 23) - 127); | |
// sqrt(2) is logarithmically between [1, 2) therefore | |
// if sqrt2 < m < 2: log(x / sqrt_2) + log(sqrt(2)) == log(x) | |
bool flip_flag = false; | |
if (m >= sqrt_2) { | |
flip_flag = true; | |
m = m * recip_sqrt_2; | |
} | |
// Calculate the logarithm of the mantissa using a Chebyshev polynomial in horner form | |
float logm; | |
logm = logc6 + m * logc7; | |
logm = logc5 + m * logm; | |
logm = logc4 + m * logm; | |
logm = logc3 + m * logm; | |
logm = logc2 + m * logm; | |
logm = logc1 + m * logm; | |
logm = logc0 + m * logm; | |
// Convert back from a the range reduced version | |
if (flip_flag) { | |
logm += log_sqrt_2; | |
} | |
// Return the logarithm | |
return logm + p * log_2; | |
} | |
inline float clog2(float val) { | |
return clog(val) * recip_log_2; | |
} | |
inline float clog10(float val) { | |
return clog(val) * recip_log_10; | |
} | |
// https://stackoverflow.com/questions/15350856/which-method-to-implement-exp-function-in-c | |
// e^x = 2^y iff y = (x*log2(e)) | |
inline float cexp(float val) { | |
#ifndef __FAST_MATH__ | |
if (val == -float_inf) { | |
return 0.0f; | |
} | |
if (val == float_inf || val != val) { | |
return val; | |
} | |
#endif | |
float y = val * log2_e; | |
// Separate integer and decimal parts | |
int32_t int_part = static_cast<int32_t>(std::trunc(y)); | |
float dec_part = y - static_cast<float>(int_part); | |
// Check for overflow | |
if (int_part <= -126 || 127 <= int_part) { | |
return float_inf; | |
} | |
// Put integer part into exponent of floating point number | |
float ret = as_float(((int_part + 127) << 23) & (sign_bit_not_mask)); | |
// Find 2 ^ dec_part | |
float pow2d; | |
pow2d = pow2c6 + dec_part * pow2c7; | |
pow2d = pow2c5 + dec_part * pow2d; | |
pow2d = pow2c4 + dec_part * pow2d; | |
pow2d = pow2c3 + dec_part * pow2d; | |
pow2d = pow2c2 + dec_part * pow2d; | |
pow2d = pow2c1 + dec_part * pow2d; | |
pow2d = pow2c0 + dec_part * pow2d; | |
return ret * pow2d; | |
} | |
// x ^ y == e ^ (y * log(x)) | |
inline float cpow(float base, float exp) { | |
if (base == 1.0f || exp == 0.0f) { | |
return 1.0; | |
} | |
#ifndef __FAST_MATH__ | |
if (is_identical(base, -0.0f) && is_odd_integer(exp)) { | |
if (exp < 0.0f) { | |
return -float_inf; | |
} | |
else { | |
return -0.0f; | |
} | |
} | |
if (exp == float_inf || exp == -float_inf) { | |
if (base == -1.0f) { | |
return 1.0f; | |
} | |
} | |
if (base == -float_inf) { | |
if (is_odd_integer(exp)) { | |
if (exp < 0.0f) { | |
return -0.0f; | |
} | |
else { | |
return -float_inf; | |
} | |
} | |
else { | |
if (exp < 0.0f) { | |
return 0.0f; | |
} | |
else { | |
return float_inf; | |
} | |
} | |
} | |
#endif | |
return cexp(exp * clog(base)); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "pow.hpp" | |
#include <algorithm> | |
#include <chrono> | |
#include <cinttypes> | |
#include <iomanip> | |
#include <iostream> | |
#include <numeric> | |
#include <tuple> | |
#include <utility> | |
#include <vector> | |
long double mantissa_compare(float a, float b) { | |
long double mantissa_a = as_float((as_int(a) & mantissa_mask) ^ zero_exponent); | |
long double mantissa_b = as_float((as_int(b) & mantissa_mask) ^ zero_exponent); | |
long double mantissa_diff = std::abs(mantissa_a - mantissa_b); | |
// The exponent is either exact or off by one, so this will keep things in a reasonable bounds | |
if (mantissa_diff > 0.5L) { | |
mantissa_diff = 1.0L - mantissa_diff; | |
} | |
return mantissa_diff; | |
} | |
int16_t exponent_compare(float a, float b) { | |
return static_cast<int16_t>( | |
std::abs(static_cast<int>(((as_int(a) & exponent_mask) >> 23) - ((as_int(b) & exponent_mask) >> 23)))); | |
} | |
void test_performance(size_t size) { | |
std::vector<float> input(size, 2.1412f); | |
std::vector<float> output(size, 0.0f); | |
using clock = std::chrono::high_resolution_clock; | |
using namespace std::chrono; | |
auto print_time = [](auto start, auto end, const char* name) { | |
auto time = duration_cast<microseconds>(end - start); | |
std::cout << name << ": " << static_cast<float>(time.count()) / 1'000.0f << "ms" << std::endl; | |
}; | |
////////////// | |
// std::log // | |
////////////// | |
auto start = clock::now(); | |
for (size_t i = 0; i < input.size(); ++i) { | |
output[i] = std::log(input[i]); | |
} | |
auto end = clock::now(); | |
print_time(start, end, "std::log"); | |
////////// | |
// clog // | |
////////// | |
start = clock::now(); | |
for (size_t i = 0; i < input.size(); ++i) { | |
output[i] = clog(input[i]); | |
} | |
end = clock::now(); | |
print_time(start, end, " clog"); | |
////////////// | |
// std::exp // | |
////////////// | |
start = clock::now(); | |
for (size_t i = 0; i < input.size(); ++i) { | |
output[i] = std::exp(input[i]); | |
} | |
end = clock::now(); | |
print_time(start, end, "std::exp"); | |
////////// | |
// cexp // | |
////////// | |
start = clock::now(); | |
for (size_t i = 0; i < input.size(); ++i) { | |
output[i] = cexp(input[i]); | |
} | |
end = clock::now(); | |
print_time(start, end, " cexp"); | |
////////////// | |
// std::pow // | |
////////////// | |
start = clock::now(); | |
for (size_t i = 0; i < input.size(); ++i) { | |
output[i] = std::pow(15.0f, input[i]); | |
} | |
end = clock::now(); | |
print_time(start, end, "std::pow"); | |
////////// | |
// cpow // | |
////////// | |
start = clock::now(); | |
for (size_t i = 0; i < input.size(); ++i) { | |
output[i] = cpow(15.0f, input[i]); | |
} | |
end = clock::now(); | |
print_time(start, end, " cpow"); | |
} | |
void test_accuracy(size_t count) { | |
std::vector<float> input(count, 0.0f); | |
std::vector<float> my_out(count, 0.0f); | |
std::vector<float> std_out(count, 0.0f); | |
std::vector<long double> man_diff(count, 0.0L); | |
std::vector<int16_t> exp_diff(count, 0); | |
// Lambda to find error numbers (my function vs the standard function) | |
// for the current data and print them out | |
auto find_error = [&](const char* name) { | |
for (size_t i = 0; i < count; ++i) { | |
man_diff[i] = mantissa_compare(std_out[i], my_out[i]); | |
} | |
for (size_t i = 0; i < count; ++i) { | |
exp_diff[i] = exponent_compare(std_out[i], my_out[i]); | |
} | |
// Find minimum and maximum mantissa error (I wanted to use structured bindings...but oh well) | |
std::vector<long double>::iterator man_min_it, man_max_it; | |
std::tie(man_min_it, man_max_it) = std::minmax_element(man_diff.begin(), man_diff.end()); | |
size_t man_min_index = std::distance(man_diff.begin(), man_min_it); | |
size_t man_max_index = std::distance(man_diff.begin(), man_max_it); | |
long double man_min = *man_min_it; | |
long double man_max = *man_max_it; | |
// 99th percentile mantissa error using nth element on copy of array | |
size_t man_99th_pt_index = static_cast<size_t>(double(count) * 0.99); | |
auto man_tmp = man_diff; | |
std::nth_element(man_tmp.begin(), man_tmp.begin() + man_99th_pt_index, man_tmp.end()); | |
auto man_99th_pt = man_tmp[man_99th_pt_index]; | |
// 1st percentile mantissa error using nth element on copy of array | |
size_t man_01st_pt_index = static_cast<size_t>(double(count) * 0.01); | |
man_tmp = man_diff; | |
std::nth_element(man_tmp.begin(), man_tmp.begin() + man_01st_pt_index, man_tmp.end()); | |
auto man_01st_pt = man_tmp[man_01st_pt_index]; | |
// Average mantissa error | |
long double man_avg = std::accumulate(man_diff.begin(), man_diff.end(), 0.0L) / static_cast<long double>(count); | |
// Check and log all incorrect exponents | |
std::vector<std::pair<size_t, int16_t>> incorrect_exps; | |
size_t exp_count = 0; | |
for (size_t i = 0; i < count; ++i) { | |
exp_count += !exp_diff[i]; | |
if (exp_diff[i] != 0) { | |
incorrect_exps.emplace_back(i, exp_diff[i]); | |
} | |
} | |
// Print function name | |
std::cout << "Accuracy of " << name << ": \n"; | |
// Print count of exponents correct | |
std::cout << "Exponents correct: " << exp_count << '\n'; | |
// Print if there are incorrect exponents | |
if (incorrect_exps.size()) { | |
std::cout << "Exponents incorrect: " << count - exp_count << '\n'; | |
// If there's a reasonable amount of errors, print them out their index and the exponent difference, | |
// and the inputs that corrispond with the errors | |
if (incorrect_exps.size() < 10) { | |
std::cout << "Indexes incorrect: \n"; | |
for (auto&& val : incorrect_exps) { | |
std::cout << "#" << std::setw(static_cast<int>(std::log10(count))) << val.first << ": ±" | |
<< std::setw(4) << val.second << std::setprecision(9) << ": " << name << "(" | |
<< input[val.first] << ") == std: " << std_out[val.first] << " == c" << name << ": " | |
<< my_out[val.first] << "\n"; | |
} | |
std::cout << '\n'; | |
} | |
} | |
// Print out stats about mantissa error | |
std::cout << "Mantissa error max: " << man_max << " at position " << man_max_index << " with value " | |
<< input[man_max_index] << '\n'; | |
std::cout << "Mantissa error 99th: " << man_99th_pt << '\n'; | |
std::cout << "Mantissa error avg: " << man_avg << '\n'; | |
std::cout << "Mantissa error 01st: " << man_01st_pt << '\n'; | |
std::cout << "Mantissa error min: " << man_min << " at position " << man_min_index << " with value " | |
<< input[man_min_index] << '\n'; | |
std::cout << std::endl; | |
}; | |
/////////////////// | |
// std::log/clog // | |
/////////////////// | |
float cur_val = 0; | |
float step = 100'000.0f / static_cast<float>(count); | |
std::generate(input.begin(), input.end(), [&] { return cur_val += step; }); | |
for (size_t i = 0; i < count; ++i) { | |
std_out[i] = std::log(input[i]); | |
} | |
for (size_t i = 0; i < count; ++i) { | |
my_out[i] = clog(input[i]); | |
} | |
find_error("log"); | |
/////////////////// | |
// std::exp/cexp // | |
/////////////////// | |
cur_val = 0; | |
step = 70.0f / static_cast<float>(count); | |
std::generate(input.begin(), input.end(), [&] { return cur_val += step; }); | |
for (size_t i = 0; i < count; ++i) { | |
std_out[i] = std::exp(input[i]); | |
} | |
for (size_t i = 0; i < count; ++i) { | |
my_out[i] = cexp(input[i]); | |
} | |
find_error("exp"); | |
/////////////////// | |
// std::pow/cpow // | |
/////////////////// | |
cur_val = 0; | |
step = 70.0f / static_cast<float>(count); | |
std::generate(input.begin(), input.end(), [&] { return cur_val += step; }); | |
for (size_t i = 0; i < count; ++i) { | |
std_out[i] = std::pow(2.5f, input[i]); | |
} | |
for (size_t i = 0; i < count; ++i) { | |
my_out[i] = cpow(2.5, input[i]); | |
} | |
find_error("pow"); | |
} | |
// Prints one result row: the case number, whether `my` and then `glib` are identical
// to the expected value `eq`, followed by the three values themselves.
void print_test_data(size_t num, float glib, float my, float eq) {
    const bool my_matches = is_identical(my, eq);
    const bool glib_matches = is_identical(glib, eq);

    std::cout << std::boolalpha
              << std::setw(2) << num << ": "
              << std::setw(5) << my_matches << " | "
              << std::setw(5) << glib_matches << ": "
              << std::setw(5) << my << " | "
              << std::setw(5) << glib << " == " << eq << '\n';
}
// Functions to collect test data and forward it | |
void test_log(size_t num, float val, float eq) { | |
auto glib = std::log(val); | |
auto my = clog(val); | |
print_test_data(num, glib, my, eq); | |
} | |
void test_exp(size_t num, float exp, float eq) { | |
auto glib = std::exp(exp); | |
auto my = cexp(exp); | |
print_test_data(num, glib, my, eq); | |
} | |
void test_pow(size_t num, float base, float exp, float eq) { | |
auto glib = std::pow(base, exp); | |
auto my = cpow(base, exp); | |
print_test_data(num, glib, my, eq); | |
} | |
// Runs the special-value tables for log, exp and pow, then the timing run and the
// accuracy run. Case ids repeat where several inputs exercise the same rule.
int main(int argc, char** argv) {
    (void) argc;
    (void) argv;

    // One-argument case: function input and the expected result
    struct unary_case {
        size_t id;
        float arg;
        float expected;
    };
    // Two-argument case for pow: base, exponent and the expected result
    struct binary_case {
        size_t id;
        float base;
        float exponent;
        float expected;
    };

    std::cout << std::setprecision(10);

    const unary_case log_cases[] = {
        {1, 0.0f, -float_inf},
        {1, -0.0f, -float_inf},
        {2, 1.0f, 0.0f},
        {3, -1.0f, float_nan},
        {4, float_inf, float_inf},
        {5, float_nan, float_nan},
    };
    std::cout << "Testing log IEEE compliance: \n"
              << " clog std clog std expected\n";
    for (const auto& tc : log_cases) {
        test_log(tc.id, tc.arg, tc.expected);
    }
    std::cout << '\n';

    const unary_case exp_cases[] = {
        {1, 0.0f, 1.0f},
        {1, -0.0f, 1.0f},
        {2, -float_inf, 0.0f},
        {3, float_inf, float_inf},
        {4, float_nan, float_nan},
    };
    std::cout << "Testing exp IEEE compliance: \n"
              << " cexp std cexp std expected\n";
    for (const auto& tc : exp_cases) {
        test_exp(tc.id, tc.arg, tc.expected);
    }
    std::cout << '\n';

    const binary_case pow_cases[] = {
        {1, 0.0f, -3.0f, float_inf},
        {2, -0.0f, -3.0f, -float_inf},
        {3, 0.0f, -2.0f, float_inf},
        {3, -0.0f, -2.0f, float_inf},
        {4, 0.0f, -float_inf, float_inf},
        {4, -0.0f, -float_inf, float_inf},
        {5, 0.0f, 3.0f, 0.0f},
        {6, -0.0f, 3.0f, -0.0f},
        {7, 0.0f, 2.0f, 0.0f},
        {7, -0.0f, 2.0f, 0.0f},
        {8, -1.0f, float_inf, 1.0f},
        {8, -1.0f, -float_inf, 1.0f},
        {9, 1.0f, 3.0f, 1.0f},
        {9, 1.0f, float_nan, 1.0f},
        {10, 2.0f, 0.0f, 1.0f},
        {10, 2.0f, -0.0f, 1.0f},
        {10, float_nan, 0.0f, 1.0f},
        {10, float_nan, -0.0f, 1.0f},
        {11, -2.0f, 3.4f, float_nan},
        {12, 0.5f, -float_inf, float_inf},
        {13, 1.5f, -float_inf, 0.0f},
        {14, 0.5f, float_inf, 0.0f},
        {15, 1.5f, float_inf, float_inf},
        {16, -float_inf, -3.0f, -0.0f},
        {17, -float_inf, -2.0f, 0.0f},
        {18, -float_inf, 3.0f, -float_inf},
        {19, -float_inf, 2.0f, float_inf},
        {20, float_inf, -3.0f, 0.0f},
        {21, float_inf, 3.0f, float_inf},
    };
    std::cout << "Testing pow IEEE compliance: \n"
              << " cpow std cpow std expected\n";
    for (const auto& tc : pow_cases) {
        test_pow(tc.id, tc.base, tc.exponent, tc.expected);
    }
    std::cout << std::endl;

    // Use a less ridiculous precision when printing times
    std::cout << std::setprecision(4);
    std::cout << "Testing speed on 1,000,000 floats\n";
    test_performance(1'000'000);

    std::cout << std::setprecision(10);
    std::cout << "\nTesting accuracy on 1'000'000 floats\n";
    test_accuracy(1'000'000);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment