Instantly share code, notes, and snippets.
Created
March 14, 2022 14:02
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save philipturner/9d349a6a4d3aae3e374ff57d63db6c32 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
////////////////////////////////////////////////////////////////////////////////
// Header file
////////////////////////////////////////////////////////////////////////////////
#include <metal_stdlib> | |
using namespace metal; | |
struct Double {
    // A value stored as the unevaluated sum hi + lo of two floats. The
    // arithmetic operators renormalize their result with fastTwoSum, leaving the
    // leading float in hi and the leftover rounding error in lo.
    float lo;
    float hi;

    // Construction (defined out of line below). The single-float constructor is
    // intentionally implicit so float and integer literals mix with Double.
    // Note the two-argument form takes the LOW part first.
    Double();
    Double(float lo, float hi);
    Double(float x);

    // Comparison, declared once per Metal address space so the operators can be
    // used on thread, device and constant objects alike.
    bool operator==(Double x) const thread;
    bool operator==(Double x) const device;
    bool operator==(Double x) const constant;
    bool operator!=(Double x) const thread;
    bool operator!=(Double x) const device;
    bool operator!=(Double x) const constant;
    bool operator<(Double x) const thread;
    bool operator<(Double x) const device;
    bool operator<(Double x) const constant;
    bool operator>(Double x) const thread;
    bool operator>(Double x) const device;
    bool operator>(Double x) const constant;
    bool operator<=(Double x) const thread;
    bool operator<=(Double x) const device;
    bool operator<=(Double x) const constant;
    bool operator>=(Double x) const thread;
    bool operator>=(Double x) const device;
    bool operator>=(Double x) const constant;

    // Error-free float addition: hi + lo of the result equals a + b exactly.
    // fastTwoSum is cheaper but expects b to be the larger-magnitude operand.
    static Double twoSum(float a, float b);
    static Double fastTwoSum(float a, float b);

    // Addition: the high parts are summed exactly, the low parts are folded into
    // the error term, and the pair is renormalized.
    friend Double operator+(Double a, Double b)
    {
        Double sum = twoSum(a.hi, b.hi);
        float tails = a.lo + b.lo;
        sum.lo += tails;
        return fastTwoSum(sum.lo, sum.hi);
    }
    friend Double operator+(float a, Double b)
    {
        Double sum = twoSum(a, b.hi);
        sum.lo += b.lo;
        return fastTwoSum(sum.lo, sum.hi);
    }
    friend Double operator+(Double a, float b)
    {
        Double sum = twoSum(a.hi, b);
        sum.lo += a.lo;
        return fastTwoSum(sum.lo, sum.hi);
    }
    void operator+=(Double x) thread;
    void operator+=(Double x) device;
    void operator+=(float x) thread;
    void operator+=(float x) device;

    // Subtraction: mirrors addition with the subtrahend negated.
    Double operator-() const thread;
    Double operator-() const device;
    Double operator-() const constant;
    friend Double operator-(Double a, Double b)
    {
        Double diff = twoSum(a.hi, -b.hi);
        float tails = a.lo - b.lo;
        diff.lo += tails;
        return fastTwoSum(diff.lo, diff.hi);
    }
    friend Double operator-(float a, Double b)
    {
        Double diff = twoSum(a, -b.hi);
        diff.lo -= b.lo;
        return fastTwoSum(diff.lo, diff.hi);
    }
    friend Double operator-(Double a, float b)
    {
        Double diff = twoSum(a.hi, -b);
        diff.lo += a.lo;
        return fastTwoSum(diff.lo, diff.hi);
    }
    void operator-=(Double x) thread;
    void operator-=(Double x) device;
    void operator-=(float x) thread;
    void operator-=(float x) device;

    // Multiplication: twoProduct yields the exact product of the high parts and
    // the cross terms are accumulated into its error term with fma.
    static Double twoProduct(float a, float b);
    friend Double operator*(Double a, Double b)
    {
        Double head = twoProduct(a.hi, b.hi);
        float cross = fma(a.lo, b.hi, head.lo);
        cross = fma(a.hi, b.lo, cross);
        cross = fma(a.lo, b.lo, cross);
        return fastTwoSum(cross, head.hi);
    }
    friend Double operator*(float a, Double b)
    {
        Double head = twoProduct(a, b.hi);
        float cross = fma(a, b.lo, head.lo);
        return fastTwoSum(cross, head.hi);
    }
    friend Double operator*(Double a, float b)
    {
        Double head = twoProduct(a.hi, b);
        float cross = fma(a.lo, b, head.lo);
        return fastTwoSum(cross, head.hi);
    }
    // Component-wise scaling with no renormalization step.
    static Double fastMultiply(Double a, float b);
    static Double fastMultiply(float a, Double b);
    void operator*=(Double x) thread;
    void operator*=(Double x) device;
    void operator*=(float x) thread;
    void operator*=(float x) device;

    // Division: float quotient q = a.hi * (1 / divisor.hi), then one correction
    // q + (1 / divisor.hi) * (a - divisor * q), using only the residual's
    // leading float.
    friend Double operator/(Double a, Double b)
    {
        float inv = precise::divide(1, b.hi);
        float q = a.hi * inv;
        float residual = (a - b * q).hi;
        return q + twoProduct(inv, residual);
    }
    friend Double operator/(float a, Double b)
    {
        float inv = precise::divide(1, b.hi);
        float q = a * inv;
        float residual = (a - b * q).hi;
        return q + twoProduct(inv, residual);
    }
    friend Double operator/(Double a, float b)
    {
        float inv = precise::divide(1, b);
        float q = a.hi * inv;
        float residual = (a - twoProduct(b, q)).hi;
        return q + twoProduct(inv, residual);
    }
    void operator/=(Double x) thread;
    void operator/=(Double x) device;
    void operator/=(float x) thread;
    void operator/=(float x) device;
};
#pragma mark - Global Functions

// Free-function overloads of the metal_stdlib math functions for Double.
Double recip(Double a);                          // 1 / a
Double abs(Double a);                            // |a|
Double positive_remainder(Double a, float b);    // a reduced modulo b
Double sqrt(Double a);
Double rsqrt(Double a);                          // 1 / sqrt(a)
Double exp(Double a);
Double log(Double a);                            // natural logarithm
Double pow(Double a, Double b);                  // computed as exp(log(a) * b)

// Trigonometric
Double __getsin(Double a);                       // sin Taylor series, no argument reduction
Double sincos(Double a, thread Double &cosval);  // returns sin(a), writes cos(a) to cosval
Double sin(Double a);
Double cos(Double a);
Double tan(Double a);

// Emulated fused multiply-add is much slower than doing the operations separately
Double precise_fma(Double a, Double b, Double c); // a * b + c
////////////////////////////////////////////////////////////////////////////////
// Implementation file
////////////////////////////////////////////////////////////////////////////////
// Initialization

// Leaves lo and hi uninitialized, like a plain float local.
Double::Double()
{
}

// Builds a Double directly from its components (LOW part first). No
// normalization is performed; callers are responsible for passing a pair whose
// hi is the leading float of hi + lo.
Double::Double(float lo, float hi)
{
    this->lo = lo;
    this->hi = hi;
}

// Widens a float exactly: the whole value lives in hi and lo is zero, which is
// already a normalized pair.
// (hi used to be rounded through `half`. That overflowed to inf above 65504,
// flushed small magnitudes out of hi — defeating the `a.hi == 0` fast paths
// and recip's 1 / a.hi seed — and produced non-normalized pairs that the
// hi-first comparison operators misorder.)
Double::Double(float x)
{
    hi = x;
    lo = 0;
}
// Comparison
//
// These assume both operands are normalized (hi is the leading float and lo the
// remainder), so ordering is decided by hi and lo only breaks ties.
// <= and >= are spelled out instead of being written as !(>) and !(<): the
// negated forms returned true whenever either operand was NaN, unlike the float
// comparisons they stand in for.
bool Double::operator==(Double x) const thread { return (this->lo == x.lo) && (this->hi == x.hi); }
bool Double::operator==(Double x) const device { return (this->lo == x.lo) && (this->hi == x.hi); }
bool Double::operator==(Double x) const constant { return (this->lo == x.lo) && (this->hi == x.hi); }
// NaN != anything is true, matching float semantics.
bool Double::operator!=(Double x) const thread { return !(*this == x); }
bool Double::operator!=(Double x) const device { return !(*this == x); }
bool Double::operator!=(Double x) const constant { return !(*this == x); }
bool Double::operator<(Double x) const thread { return (this->hi < x.hi) || (this->hi == x.hi && this->lo < x.lo); }
bool Double::operator<(Double x) const device { return (this->hi < x.hi) || (this->hi == x.hi && this->lo < x.lo); }
bool Double::operator<(Double x) const constant { return (this->hi < x.hi) || (this->hi == x.hi && this->lo < x.lo); }
bool Double::operator>(Double x) const thread { return x < *this; }
bool Double::operator>(Double x) const device { return x < *this; }
bool Double::operator>(Double x) const constant { return x < *this; }
bool Double::operator<=(Double x) const thread { return (this->hi < x.hi) || (this->hi == x.hi && this->lo <= x.lo); }
bool Double::operator<=(Double x) const device { return (this->hi < x.hi) || (this->hi == x.hi && this->lo <= x.lo); }
bool Double::operator<=(Double x) const constant { return (this->hi < x.hi) || (this->hi == x.hi && this->lo <= x.lo); }
bool Double::operator>=(Double x) const thread { return (this->hi > x.hi) || (this->hi == x.hi && this->lo >= x.lo); }
bool Double::operator>=(Double x) const device { return (this->hi > x.hi) || (this->hi == x.hi && this->lo >= x.lo); }
bool Double::operator>=(Double x) const constant { return (this->hi > x.hi) || (this->hi == x.hi && this->lo >= x.lo); }
// Normalization

// Knuth's TwoSum: returns (lo, hi) with hi = fl(a + b) and lo the exact
// rounding error, so hi + lo == a + b regardless of which operand is larger.
// NOTE(review): this only holds if the compiler does not reassociate the float
// operations below. Metal enables fast math by default
// (MTLCompileOptions.fastMathEnabled), which may fold the error term to zero —
// confirm the library is compiled with fast math disabled.
Double Double::twoSum(float a, float b)
{
    float s = a + b;
    float b_part = s - a;
    float err_a = a - (s - b_part);
    float err_b = b - b_part;
    return Double(err_a + err_b, s);
}
// Dekker's FastTwoSum: the cheaper variant of twoSum. The result is exact only
// when b is the larger-magnitude operand; it is used to renormalize a
// (lo, hi) pair, with the low-order term passed first.
Double Double::fastTwoSum(float a, float b)
{
    float s = b + a;
    return Double(a - (s - b), s);
}
// Addition

// Compound addition; each overload delegates to the matching binary operator+.
void Double::operator+=(Double x) thread
{
    *this = *this + x;
}
void Double::operator+=(Double x) device
{
    *this = *this + x;
}
void Double::operator+=(float x) thread
{
    *this = *this + x;
}
void Double::operator+=(float x) device
{
    *this = *this + x;
}
// Subtraction

// Negation flips the sign of both components, which is exact.
Double Double::operator-() const thread
{
    Double negated(-lo, -hi);
    return negated;
}
Double Double::operator-() const device
{
    Double negated(-lo, -hi);
    return negated;
}
Double Double::operator-() const constant
{
    Double negated(-lo, -hi);
    return negated;
}

// Compound subtraction; each overload delegates to the matching binary operator-.
void Double::operator-=(Double x) thread
{
    *this = *this - x;
}
void Double::operator-=(Double x) device
{
    *this = *this - x;
}
void Double::operator-=(float x) thread
{
    *this = *this - x;
}
void Double::operator-=(float x) device
{
    *this = *this - x;
}
// Multiplication

// Exact float product via FMA: hi = fl(a * b) and lo = a * b - hi evaluated
// without intermediate rounding, so hi + lo == a * b (barring over/underflow).
Double Double::twoProduct(float a, float b)
{
    float p = a * b;
    return Double(fma(a, b, -p), p);
}
// Scales both components by a float without renormalizing. Each component
// product is exact (barring underflow) when the factor is a power of two, such
// as the 0.5 that sqrt and exp pass. For other factors each product is rounded
// and the pair is not re-normalized.
Double Double::fastMultiply(Double a, float b)
{
    float scaled_hi = a.hi * b;
    float scaled_lo = a.lo * b;
    return Double(scaled_lo, scaled_hi);
}
Double Double::fastMultiply(float a, Double b)
{
    float scaled_hi = a * b.hi;
    float scaled_lo = a * b.lo;
    return Double(scaled_lo, scaled_hi);
}
// Compound multiplication; each overload delegates to the matching operator*.
void Double::operator*=(Double x) thread
{
    *this = *this * x;
}
void Double::operator*=(Double x) device
{
    *this = *this * x;
}
void Double::operator*=(float x) thread
{
    *this = *this * x;
}
void Double::operator*=(float x) device
{
    *this = *this * x;
}
// Division

// Compound division; each overload delegates to the matching operator/.
void Double::operator/=(Double x) thread
{
    *this = *this / x;
}
void Double::operator/=(Double x) device
{
    *this = *this / x;
}
void Double::operator/=(float x) thread
{
    *this = *this / x;
}
void Double::operator/=(float x) device
{
    *this = *this / x;
}
#pragma mark - Global Functions

// Reciprocal 1 / a: a float seed 1 / a.hi plus one correction step, the same
// scheme operator/ uses with a numerator of 1.
Double recip(Double a)
{
    float seed = precise::divide(1, a.hi);
    float residual = (1 - a * seed).hi;
    return seed + Double::twoProduct(seed, residual);
}
// Absolute value, decided by comparing the whole Double against zero.
// Zero (and NaN) take the negation branch.
Double abs(Double a)
{
    if (a > 0)
    {
        return a;
    }
    return -a;
}
// Reduces a modulo b: returns r = a - q * b with q = floor(a.hi / b), folded
// into [0, b) when b > 0.
//
// a: value to reduce.
// b: modulus; the [0, b) guarantee only applies for b > 0 (for other b the raw
//    a - q * b is returned, as before).
Double positive_remainder(Double a, float b)
{
    // floor() already yields an integral float. The previous round trip through
    // `long` added nothing and was undefined for non-finite quotients or
    // |a.hi / b| >= 2^63.
    float quotient = floor(a.hi / b);
    // twoProduct represents quotient * b exactly, so the subtraction is the only
    // rounding step.
    Double r = a - Double::twoProduct(quotient, b);
    // The quotient only looks at a.hi. A negative a.lo, or rounding in
    // a.hi / b, can leave r just below 0 or at/above b, so fold it back.
    if (b > 0)
    {
        if (r.hi < 0)
        {
            r += b;
        }
        else if (r >= Double(0.0f, b))
        {
            r -= b;
        }
    }
    return r;
}
// Square root using a Karp–Markstein-style correction. Let x = rsqrt(a.hi) and
// y = a.hi * x, both in float. The result is y + x * (a - y^2) / 2, with the
// residual a - y^2 taken against the exact square of y.
// Negative inputs yield NaN (from rsqrt).
Double sqrt(Double a)
{
    // rsqrt(0) is inf and 0 * inf is NaN, so zero needs its own path. Returning
    // `a` keeps the sign of a signed zero, as IEEE sqrt does.
    if (a.hi == 0)
    {
        return a;
    }
    float inv_root = precise::rsqrt(a.hi);
    float root = a.hi * inv_root;
    Double root_sq = Double::twoProduct(root, root);
    float residual = (a - root_sq).hi;
    Double correction = Double::fastMultiply(Double::twoProduct(inv_root, residual), 0.5f);
    return root + correction;
}
// Reciprocal square root, computed as recip(sqrt(a)).
Double rsqrt(Double a)
{
    Double root = sqrt(a);
    return recip(root);
}
// e^a in double-float precision.
//
// Range reduction: with k = rint(a / ln2) and r = a - k * ln2 (so |r| is about
// ln2 / 2 at most), e^a = 2^k * e^r. The Taylor series for e^r then converges
// in fewer than 20 terms, and scaling by 2^k with ldexp is exact.
// Without the reduction the series was evaluated on a itself. For moderate |a|
// its powers and factorials leave float range (inf/NaN results), and negative
// a lost all precision to cancellation between huge alternating terms.
Double exp(Double a)
{
    if (isnan(a.hi))
    {
        return a;
    }
    // Beyond these bounds e^a over/underflows float anyway. Clamping here also
    // keeps k well inside int range.
    if (a.hi > 89.0f)
    {
        return Double(0.0f, INFINITY);
    }
    if (a.hi < -104.0f)
    {
        return Double(0.0f, 0.0f);
    }

    // ln2 as a double-float: float head plus the float nearest the remainder.
    const Double ln2(-1.9046543e-9f, 0.6931471824645996f);
    float k = rint(a.hi * M_LOG2E_F);
    Double r = a - ln2 * k;

    // Taylor series for e^r. Each term is built from the previous one
    // (term_n = term_{n-1} * r / n), so nothing overflows. Terms shrink by at
    // least |r| / 2 per step, so the loop always terminates.
    Double sum = 1.0f + r;
    Double term = r;
    for (float n = 2.0f; abs(term.hi) > 1e-20f; n += 1.0f)
    {
        term = (term * r) / n;
        sum += term;
    }

    int scale = int(k);
    return Double(ldexp(sum.lo, scale), ldexp(sum.hi, scale));
}
// Natural logarithm: the float logarithm refined by one Newton step on
// f(x) = e^x - a, i.e. x' = x + a * e^(-x) - 1. Newton converges
// quadratically, so this roughly doubles the float seed's correct bits.
// Returns NaN for negative arguments.
// (The previous `xi != 1` guard tested a variable that was always 0, so it was
// always true; it has been removed.)
Double log(Double a)
{
    if (a.hi < 0)
    {
        return NAN;
    }
    // log(0) = -inf and log(+inf) = +inf. The Newton step would turn both into
    // NaN (a 0 * inf product), so return the float result directly.
    if (a.hi == 0 || isinf(a.hi))
    {
        return Double(0.0f, precise::log(a.hi));
    }
    Double xi(0.0f, precise::log(a.hi));
    xi += exp(-xi) * a - 1;
    return xi;
}
// a^b computed as exp(log(a) * b). Negative bases give NaN because log does.
Double pow(Double a, Double b)
{
    // log(0) is not finite, and pushing it through the Double multiply produces
    // NaN (its twoProduct error term is inf - inf). So a zero base, treated as
    // +0, is handled directly: 0^b is 0 for b > 0, 1 for b == 0 and +inf for
    // b < 0. A NaN exponent falls through and yields NaN.
    if (a.hi == 0)
    {
        if (b.hi > 0)
        {
            return Double(0.0f, 0.0f);
        }
        if (b.hi < 0)
        {
            return Double(0.0f, INFINITY);
        }
        if (b.hi == 0)
        {
            return Double(0.0f, 1.0f);
        }
    }
    return exp(log(a) * b);
}
// Trigonometric

// Taylor series sin(a) = a - a^3/3! + a^5/5! - ... with no argument reduction.
// The sin/cos/sincos callers pass an angle already reduced into [0, 2*pi).
//
// Each term is derived from the previous one:
//   t_k = t_{k-1} * (-a^2) / ((2k) * (2k + 1)).
// The old code kept a^(2k+1) and (2k+1)! separately. The factorial overflows
// float at 35!, after which the term became NaN. Because the exit test
// `abs(t.hi) < threshold` is false for NaN, the loop never ended; this hit
// every reduced angle above about 3.4 rad. A zero argument (threshold 0 with a
// strict `<`) never exited either. Hence the non-strict exit test and the
// iteration cap. About 22 terms suffice for |a| < 2*pi.
Double __getsin(Double a)
{
    Double x = -a * a;   // ratio between successive odd powers
    Double s = a;        // running sum, seeded with the a^1 / 1! term
    Double t = a;        // current signed term a^(2k+1) / (2k+1)!
    float threshold = 1e-20f * abs(a.hi);
    for (int k = 1; k <= 64; ++k)
    {
        float n = float(2 * k);
        // n * (n + 1) <= 128 * 129, so this divisor is an exact float.
        t = (t * x) / (n * (n + 1.0f));
        s += t;
        if (abs(t.hi) <= threshold)
        {
            break;
        }
    }
    return s;
}
// Returns sin(a) and stores cos(a) in cosval.
//
// a is reduced into [0, 2*pi) with the float constant 2 * M_PI_F and sin comes
// from the Taylor series. cos is recovered as +/-sqrt(1 - sin^2), with the sign
// taken from which half-plane the reduced angle lies in.
// NOTE(review): 1 - sin^2 cancels where |sin| is close to 1, so cos is less
// accurate near its zeros. The float 2*pi also limits accuracy for inputs
// outside [0, 2*pi).
Double sincos(Double a, thread Double &cosval)
{
    if (a.hi == 0)
    {
        cosval = 1;
        return 0;
    }
    Double b = positive_remainder(a, 2 * M_PI_F);
    Double sinval = __getsin(b);
    // Rounding can make 1 - sin^2 zero or slightly negative when |sin| ~ 1.
    // sqrt returns NaN there, but the true cosine is ~0.
    Double cos_sq = 1 - sinval * sinval;
    Double cos_abs = (cos_sq.hi > 0) ? sqrt(cos_sq) : Double(0.0f, 0.0f);
    bool cos_negative = (b.hi > M_PI_2_F) && (b.hi < 3 * M_PI_2_F);
    cosval = cos_negative ? -cos_abs : cos_abs;
    return sinval;
}
// sin(a): reduce a into [0, 2*pi) and evaluate the Taylor series there; an
// exact zero short-circuits to 0.
// NOTE(review): the reduction uses the float constant 2 * M_PI_F. That is
// about 1.7e-7 away from 2*pi, so inputs outside [0, 2*pi) only get float-level
// accuracy.
Double sin(Double a)
{
    return (a.hi == 0) ? Double(0.0f) : __getsin(positive_remainder(a, 2 * M_PI_F));
}
// cos(a) = +/-sqrt(1 - sin(b)^2), where b is a reduced into [0, 2*pi) with the
// float constant 2 * M_PI_F. The sign is negative for b in (pi/2, 3pi/2). An
// exact zero short-circuits to 1.
// NOTE(review): 1 - sin^2 cancels where |sin| is close to 1, so the result is
// less accurate near the zeros of cos. The float 2*pi also limits accuracy for
// inputs outside [0, 2*pi).
Double cos(Double a)
{
    if (a.hi == 0)
    {
        return 1;
    }
    Double b = positive_remainder(a, 2 * M_PI_F);
    Double sinval = __getsin(b);
    // Rounding can make 1 - sin^2 zero or slightly negative when |sin| ~ 1.
    // sqrt returns NaN there, but the true cosine is ~0.
    Double cos_sq = 1 - sinval * sinval;
    Double magnitude = (cos_sq.hi > 0) ? sqrt(cos_sq) : Double(0.0f, 0.0f);
    bool negative = (b.hi > M_PI_2_F) && (b.hi < 3 * M_PI_2_F);
    return negative ? -magnitude : magnitude;
}
// tan(a) = sin(a) / cos(a); sincos supplies both from one range reduction.
Double tan(Double a)
{
    Double c;
    Double s = sincos(a, c);
    return s / c;
}
// Emulated fused multiply-add is much slower than doing the operations separately

// Expands a * b (a Double times a float) into four float components that
// together carry the product. The exact products of each half are recombined
// with one twoSum and one fastTwoSum, so no rounding error is dropped,
// presumably assuming a is normalized so the fastTwoSum precondition holds.
// Components are ordered roughly from least to most significant; [3] is the
// rounded head.
inline float4 mul21(Double a, float b)
{
    Double lo_prod = Double::twoProduct(a.lo, b);
    Double hi_prod = Double::twoProduct(a.hi, b);
    Double mid = Double::twoSum(hi_prod.lo, lo_prod.hi);
    Double top = Double::fastTwoSum(mid.hi, hi_prod.hi);
    return float4(lo_prod.lo, mid.lo, top.lo, top.hi);
}
// Collapses components [0..2] of a mul21 expansion into a (lo, hi) pair,
// folding the bottom rounding error into lo. Component [3] is left to the caller.
inline Double fma_collapse3(float4 v)
{
    Double bottom = Double::fastTwoSum(v[0], v[1]);
    Double top = Double::fastTwoSum(bottom.hi, v[2]);
    top.lo += bottom.lo;
    return top;
}

// a * b + c. The product is expanded as a * b.lo and a * b.hi, then its pieces
// are added onto c.hi with a chain of twoSums. Every error term is collected
// (with c.lo) into one tail that is folded in by a final renormalization.
Double precise_fma(Double a, Double b, Double c)
{
    float4 low_part = mul21(a, b.lo);
    float4 high_part = mul21(a, b.hi);
    Double low3 = fma_collapse3(low_part);
    Double high3 = fma_collapse3(high_part);
    // Each twoSum passes its rounded sum upward and leaves its error in .lo.
    Double s1 = Double::twoSum(low3.lo, c.hi);
    Double s2 = Double::twoSum(low3.hi, s1.hi);
    Double s3 = Double::twoSum(high3.lo, s2.hi);
    Double s4 = Double::twoSum(low_part[3], s3.hi);
    Double s5 = Double::twoSum(high3.hi, s4.hi);
    Double s6 = Double::twoSum(high_part[3], s5.hi);
    // s6.hi holds the most significant part of the result.
    float tail = c.lo + s1.lo + s2.lo + s3.lo + s4.lo + s5.lo + s6.lo;
    return Double::fastTwoSum(tail, s6.hi);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment