Skip to content

Instantly share code, notes, and snippets.

@philipturner
Created March 14, 2022 14:02
Show Gist options
  • Save philipturner/9d349a6a4d3aae3e374ff57d63db6c32 to your computer and use it in GitHub Desktop.
Save philipturner/9d349a6a4d3aae3e374ff57d63db6c32 to your computer and use it in GitHub Desktop.
////////////////////////////////////////////////////////////////////////////////
// Header file
////////////////////////////////////////////////////////////////////////////////
#include <metal_stdlib>
using namespace metal;
// Extended-precision value stored as an unevaluated sum of two floats
// (hi + lo). The operators below compute with `hi` as the leading part and
// fold rounding errors into `lo`.
struct Double {
// Low-order part: the correction / rounding-error term.
float lo;
// High-order part: the leading approximation of the value.
float hi;
// Default constructor (defined out of line).
Double();
// Component-wise constructor. Note the argument order: low part first.
Double(float lo, float hi);
// Converting constructor from a single float. It is not `explicit`, so float
// and integer operands convert implicitly in mixed expressions.
Double(float x);
// Comparison
// Every member operator is declared once per address-space qualifier
// (`thread`, `device`, `constant`) so it can be called on objects living in
// any of them. Mutating operators have no `constant` overload.
bool operator==(Double x) const thread;
bool operator==(Double x) const device;
bool operator==(Double x) const constant;
bool operator!=(Double x) const thread;
bool operator!=(Double x) const device;
bool operator!=(Double x) const constant;
bool operator<(Double x) const thread;
bool operator<(Double x) const device;
bool operator<(Double x) const constant;
bool operator>(Double x) const thread;
bool operator>(Double x) const device;
bool operator>(Double x) const constant;
bool operator<=(Double x) const thread;
bool operator<=(Double x) const device;
bool operator<=(Double x) const constant;
bool operator>=(Double x) const thread;
bool operator>=(Double x) const device;
bool operator>=(Double x) const constant;
// Normalization
// twoSum: returns the float sum a + b in `hi` and its rounding error in `lo`.
static Double twoSum(float a, float b);
// fastTwoSum: shorter variant of twoSum; every caller in this struct passes
// the low-order term as `a` and the high-order term as `b`.
static Double fastTwoSum(float a, float b);
// Addition
// Double + Double: add the high parts with twoSum, fold both low parts into
// the error term, then renormalize with fastTwoSum.
friend Double operator+(Double a, Double b)
{
Double lo_hi = twoSum(a.hi, b.hi);
lo_hi.lo += a.lo + b.lo;
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// float + Double: same scheme; the float operand has no low part to fold in.
friend Double operator+(float a, Double b)
{
Double lo_hi = twoSum(a, b.hi);
lo_hi.lo += b.lo;
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// Double + float: mirror image of the overload above.
friend Double operator+(Double a, float b)
{
Double lo_hi = twoSum(a.hi, b);
lo_hi.lo += a.lo;
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// Compound addition; assigns (*this) + x back to *this.
void operator+=(Double x) thread;
void operator+=(Double x) device;
void operator+=(float x) thread;
void operator+=(float x) device;
// Subtraction
// Unary minus: negates both components.
Double operator-() const thread;
Double operator-() const device;
Double operator-() const constant;
// Double - Double: twoSum of a.hi and -b.hi, with (a.lo - b.lo) folded into
// the error term before renormalizing.
friend Double operator-(Double a, Double b)
{
Double lo_hi = twoSum(a.hi, -b.hi);
lo_hi.lo += a.lo - b.lo;
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// float - Double: the float has no low part, so only b.lo is subtracted.
friend Double operator-(float a, Double b)
{
Double lo_hi = twoSum(a, -b.hi);
lo_hi.lo -= b.lo;
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// Double - float: only a.lo is folded into the error term.
friend Double operator-(Double a, float b)
{
Double lo_hi = twoSum(a.hi, -b);
lo_hi.lo += a.lo;
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// Compound subtraction; assigns (*this) - x back to *this.
void operator-=(Double x) thread;
void operator-=(Double x) device;
void operator-=(float x) thread;
void operator-=(float x) device;
// Multiplication
// twoProduct: returns the float product a * b in `hi` and, via fma, the
// rounding error of that product in `lo`.
static Double twoProduct(float a, float b);
// Double * Double: product of the high parts plus its rounding error, then
// the cross terms a.lo*b.hi, a.hi*b.lo and a.lo*b.lo are accumulated into the
// low part with fma before renormalizing.
friend Double operator*(Double a, Double b)
{
Double lo_hi = twoProduct(a.hi, b.hi);
lo_hi.lo = fma(a.lo, b.hi, lo_hi.lo);
lo_hi.lo = fma(a.hi, b.lo, lo_hi.lo);
lo_hi.lo = fma(a.lo, b.lo, lo_hi.lo);
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// float * Double: a single cross term (a * b.lo) is needed.
friend Double operator*(float a, Double b)
{
Double lo_hi = twoProduct(a, b.hi);
lo_hi.lo = fma(a, b.lo, lo_hi.lo);
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// Double * float: mirror image of the overload above.
friend Double operator*(Double a, float b)
{
Double lo_hi = twoProduct(a.hi, b);
lo_hi.lo = fma(a.lo, b, lo_hi.lo);
return fastTwoSum(lo_hi.lo, lo_hi.hi);
}
// fastMultiply: scales both components by the float independently, with no
// error term and no renormalization (see the out-of-line definitions).
static Double fastMultiply(Double a, float b);
static Double fastMultiply(float a, Double b);
// Compound multiplication; assigns (*this) * x back to *this.
void operator*=(Double x) thread;
void operator*=(Double x) device;
void operator*=(float x) thread;
void operator*=(float x) device;
// Division
// All three overloads follow the same scheme:
//   xn   = float reciprocal of the divisor's high part,
//   yn   = first quotient estimate (dividend's high part times xn),
//   diff = high part of the remainder (dividend - divisor * yn),
//   result = yn + xn * diff, i.e. the estimate plus one correction term.
friend Double operator/(Double a, Double b)
{
float xn = precise::divide(1, b.hi);
float yn = a.hi * xn;
float diff = (a - b * yn).hi;
Double prod = twoProduct(xn, diff);
return yn + prod;
}
// float / Double.
friend Double operator/(float a, Double b)
{
float xn = precise::divide(1, b.hi);
float yn = a * xn;
float diff = (a - b * yn).hi;
Double prod = twoProduct(xn, diff);
return yn + prod;
}
// Double / float: the divisor * estimate product is formed with twoProduct.
friend Double operator/(Double a, float b)
{
float xn = precise::divide(1, b);
float yn = a.hi * xn;
float diff = (a - twoProduct(b, yn)).hi;
Double prod = twoProduct(xn, diff);
return yn + prod;
}
// Compound division; assigns (*this) / x back to *this.
void operator/=(Double x) thread;
void operator/=(Double x) device;
void operator/=(float x) thread;
void operator/=(float x) device;
};
#pragma mark - Global Functions
// Reciprocal 1 / a: float estimate from a.hi plus one correction term.
Double recip(Double a);
// Magnitude of a: returns a when a > 0, otherwise -a.
Double abs(Double a);
// a minus floor(a.hi / b) whole multiples of b.
Double positive_remainder(Double a, float b);
// Square root: float rsqrt estimate from a.hi plus one correction term.
Double sqrt(Double a);
// Reciprocal square root, computed as recip(sqrt(a)).
Double rsqrt(Double a);
// Exponential, evaluated as a power series in a.
Double exp(Double a);
// Natural logarithm: float log estimate from a.hi refined with one step that
// uses exp; returns NaN when a.hi is negative.
Double log(Double a);
// a raised to the power b, computed as exp(log(a) * b).
Double pow(Double a, Double b);
// Trigonometric
// Power-series sine of a; performs no range reduction itself.
Double __getsin(Double a);
// Returns sin(a) and stores cos(a) in `cosval`.
Double sincos(Double a, thread Double &cosval);
// Sine of a; the argument is first reduced with positive_remainder.
Double sin(Double a);
// Cosine of a, derived from the sine of the reduced argument.
Double cos(Double a);
// Tangent of a, computed as sin / cos via sincos.
Double tan(Double a);
// Emulated fused multiply-add is much slower than doing the operations separately
Double precise_fma(Double a, Double b, Double c);
////////////////////////////////////////////////////////////////////////////////
// Implementation file
////////////////////////////////////////////////////////////////////////////////
// Initialization
// Default constructor: intentionally leaves `lo` and `hi` unset so that
// declaring a Double that is assigned immediately afterwards costs nothing.
Double::Double()
{
}
// Component-wise constructor. The LOW part comes first; the components are
// stored exactly as given, without renormalization.
Double::Double(float lo, float hi)
{
  this->lo = lo;
  this->hi = hi;
}
// Promotes a single float. A float is exactly representable by the pair, so
// the whole value goes into `hi` and the correction term is zero.
//
// The previous implementation rounded `hi` to half precision and put the
// remainder in `lo`. That produced a non-normalized pair (breaking the
// hi-then-lo ordering used by the comparison operators), overflowed to
// (hi = inf, lo = -inf) for |x| > 65504, and flushed `hi` to zero for small x,
// which in turn broke every routine that seeds an estimate from `hi`.
Double::Double(float x)
{
  hi = x;
  lo = 0;
}
// Comparison
// Equality: both components must match exactly.
bool Double::operator==(Double x) const thread
{
  return (lo == x.lo) && (hi == x.hi);
}
bool Double::operator==(Double x) const device
{
  return (lo == x.lo) && (hi == x.hi);
}
bool Double::operator==(Double x) const constant
{
  return (lo == x.lo) && (hi == x.hi);
}
// Inequality: true as soon as either component differs.
bool Double::operator!=(Double x) const thread
{
  return (lo != x.lo) || (hi != x.hi);
}
bool Double::operator!=(Double x) const device
{
  return (lo != x.lo) || (hi != x.hi);
}
bool Double::operator!=(Double x) const constant
{
  return (lo != x.lo) || (hi != x.hi);
}
// Less-than: ordered by the high part, with the low part as tie-breaker.
bool Double::operator<(Double x) const thread
{
  return (hi < x.hi) || (hi == x.hi && lo < x.lo);
}
bool Double::operator<(Double x) const device
{
  return (hi < x.hi) || (hi == x.hi && lo < x.lo);
}
bool Double::operator<(Double x) const constant
{
  return (hi < x.hi) || (hi == x.hi && lo < x.lo);
}
// Greater-than: the mirrored form of less-than.
bool Double::operator>(Double x) const thread
{
  return (hi > x.hi) || (hi == x.hi && lo > x.lo);
}
bool Double::operator>(Double x) const device
{
  return (hi > x.hi) || (hi == x.hi && lo > x.lo);
}
bool Double::operator>(Double x) const constant
{
  return (hi > x.hi) || (hi == x.hi && lo > x.lo);
}
// Less-or-equal: defined as the negation of greater-than.
bool Double::operator<=(Double x) const thread
{
  return !((hi > x.hi) || (hi == x.hi && lo > x.lo));
}
bool Double::operator<=(Double x) const device
{
  return !((hi > x.hi) || (hi == x.hi && lo > x.lo));
}
bool Double::operator<=(Double x) const constant
{
  return !((hi > x.hi) || (hi == x.hi && lo > x.lo));
}
// Greater-or-equal: defined as the negation of less-than.
bool Double::operator>=(Double x) const thread
{
  return !((hi < x.hi) || (hi == x.hi && lo < x.lo));
}
bool Double::operator>=(Double x) const device
{
  return !((hi < x.hi) || (hi == x.hi && lo < x.lo));
}
bool Double::operator>=(Double x) const constant
{
  return !((hi < x.hi) || (hi == x.hi && lo < x.lo));
}
// Normalization
// Adds two floats and also recovers the rounding error of that addition.
// Returns a Double whose `hi` is the float sum and whose `lo` is the error.
Double Double::twoSum(float a, float b)
{
  const float sum = a + b;
  // Reconstruct the portion of the sum contributed by each operand.
  const float bPart = sum - a;
  const float aPart = sum - bPart;
  // What each operand lost when it was rounded into `sum`.
  const float aError = a - aPart;
  const float bError = b - bPart;
  return Double(aError + bError, sum);
}
// Shorter sum-with-error: `hi` is the float sum b + a, `lo` is what remains
// of `a` after removing the part of it that made it into the sum.
// Callers in this file pass the low-order term as `a` and the high-order
// term as `b`.
Double Double::fastTwoSum(float a, float b)
{
  const float sum = b + a;
  const float aPart = sum - b;
  return Double(a - aPart, sum);
}
// Addition
// Compound addition: evaluates the binary operator+ and stores the sum back
// into the object. One overload per writable address space and operand type.
void Double::operator+=(Double x) thread
{
  *this = *this + x;
}
void Double::operator+=(Double x) device
{
  *this = *this + x;
}
void Double::operator+=(float x) thread
{
  *this = *this + x;
}
void Double::operator+=(float x) device
{
  *this = *this + x;
}
// Subtraction
// Unary minus: flips the sign of both components.
Double Double::operator-() const thread
{
  const float negatedLo = -lo;
  const float negatedHi = -hi;
  return Double(negatedLo, negatedHi);
}
Double Double::operator-() const device
{
  const float negatedLo = -lo;
  const float negatedHi = -hi;
  return Double(negatedLo, negatedHi);
}
Double Double::operator-() const constant
{
  const float negatedLo = -lo;
  const float negatedHi = -hi;
  return Double(negatedLo, negatedHi);
}
// Compound subtraction: evaluates the binary operator- and stores the
// difference back into the object.
void Double::operator-=(Double x) thread
{
  *this = *this - x;
}
void Double::operator-=(Double x) device
{
  *this = *this - x;
}
void Double::operator-=(float x) thread
{
  *this = *this - x;
}
void Double::operator-=(float x) device
{
  *this = *this - x;
}
// Multiplication
// Multiplies two floats and recovers the rounding error of the product.
// `hi` is the float product; `lo` is fma(a, b, -hi), i.e. the part of a * b
// that did not fit into `hi`.
Double Double::twoProduct(float a, float b)
{
  const float product = a * b;
  const float error = fma(a, b, -product);
  return Double(error, product);
}
// Scales each component of `a` by `b` independently. No error term is
// computed and the result is not renormalized.
Double Double::fastMultiply(Double a, float b)
{
  const float scaledLo = a.lo * b;
  const float scaledHi = a.hi * b;
  return Double(scaledLo, scaledHi);
}
// Same as above with the operands in the opposite order.
Double Double::fastMultiply(float a, Double b)
{
  const float scaledLo = a * b.lo;
  const float scaledHi = a * b.hi;
  return Double(scaledLo, scaledHi);
}
// Compound multiplication: evaluates the binary operator* and stores the
// product back into the object.
void Double::operator*=(Double x) thread
{
  *this = *this * x;
}
void Double::operator*=(Double x) device
{
  *this = *this * x;
}
void Double::operator*=(float x) thread
{
  *this = *this * x;
}
void Double::operator*=(float x) device
{
  *this = *this * x;
}
// Division
// Compound division: evaluates the binary operator/ and stores the quotient
// back into the object.
void Double::operator/=(Double x) thread
{
  *this = *this / x;
}
void Double::operator/=(Double x) device
{
  *this = *this / x;
}
void Double::operator/=(float x) thread
{
  *this = *this / x;
}
void Double::operator/=(float x) device
{
  *this = *this / x;
}
#pragma mark - Global Functions
// Reciprocal of `a`: starts from the float reciprocal of a.hi and adds one
// correction term built from the residual 1 - a * estimate.
Double recip(Double a)
{
  const float estimate = precise::divide(1, a.hi);
  // Only the leading part of the residual is used for the correction.
  Double residual = 1 - a * estimate;
  Double correction = Double::twoProduct(estimate, residual.hi);
  return estimate + correction;
}
// Magnitude of `a`: strictly positive values are returned unchanged, every
// other value is returned negated.
Double abs(Double a)
{
  if (a > 0)
  {
    return a;
  }
  return -a;
}
// Subtracts floor(a.hi / b) whole multiples of `b` from `a`.
// The multiple count is taken from the float quotient of the high part only.
Double positive_remainder(Double a, float b)
{
  const long wholeMultiples = long(floor(a.hi / b));
  Double removed = Double(float(wholeMultiples)) * b;
  return a - removed;
}
// Square root of `a`.
//
// Starts from the float estimate yn = a.hi * rsqrt(a.hi) and adds one
// correction term 0.5 * rsqrt(a.hi) * (a - yn^2).
//
// Zero and +infinity are returned directly: for those inputs the estimate
// would be 0 * inf or inf * 0, which is NaN rather than the correct result.
// (This matters for callers such as cos(), which evaluate sqrt(1 - s*s) with
// s*s possibly equal to 1.) Negative inputs still yield NaN through rsqrt.
Double sqrt(Double a)
{
  if (a.hi == 0 || a.hi == INFINITY)
  {
    // Preserves the sign of zero; the low part carries no information here.
    return Double(0.0f, a.hi);
  }
  float xn = precise::rsqrt(a.hi);
  float yn = a.hi * xn;
  // Exact square of the estimate, so the residual below is meaningful.
  Double ynsqr = Double::twoProduct(yn, yn);
  float diff = (a - ynsqr).hi;
  Double prod = Double::fastMultiply(Double::twoProduct(xn, diff), 0.5);
  return yn + prod;
}
// Reciprocal square root, obtained by inverting the extended-precision
// square root.
Double rsqrt(Double a)
{
  Double root = sqrt(a);
  return recip(root);
}
// Exponential of `a`, evaluated as the power series
//   1 + a + a^2/2! + a^3/3! + ...
// Terms are added until the leading part of the next term drops to
// 1e-20 * exp(a.hi) or below; that final term is still included in the result.
//
// Each term is derived from the previous one (term_m = term_{m-1} * a / m)
// instead of dividing a running power by a running factorial. The factorial
// exceeds the float range at 35!, which previously turned the quotient into
// NaN for arguments of only moderate size.
Double exp(Double a)
{
  Double sum = 1 + a;
  // Second-order term a^2 / 2!.
  Double term = Double::fastMultiply(a * a, 0.5);
  float order = 2;
  float threshold = 1e-20 * precise::exp(a.hi);
  while (abs(term.hi) > threshold)
  {
    sum += term;
    order += 1;
    term = (term * a) / order;
  }
  return sum + term;
}
// Natural logarithm of `a`.
//
// Takes the float logarithm of a.hi as a first estimate xi and refines it
// with one step xi += a * exp(-xi) - 1.
//
// Special cases, decided on the high part:
//   a.hi <  0     -> NaN
//   a.hi == 0     -> -infinity
//   a.hi == +inf  -> +infinity
//   a == 1        -> exactly 0
// The original guard compared the freshly zeroed result (not `a`) against 1
// and was therefore always taken; zero and infinite inputs then ran through
// the refinement step and produced NaN.
Double log(Double a)
{
  if (a.hi < 0)
  {
    return Double(NAN, NAN);
  }
  if (a.hi == 0)
  {
    return Double(0.0f, -INFINITY);
  }
  if (a.hi == INFINITY)
  {
    return Double(0.0f, INFINITY);
  }
  if (a == 1)
  {
    return Double(0.0f, 0.0f);
  }
  // First estimate lives entirely in the high part.
  Double xi = Double(0.0f, precise::log(a.hi));
  xi += exp(-xi) * a - 1;
  return xi;
}
// Raises `a` to the power `b` through the identity a^b = exp(b * log(a)).
Double pow(Double a, Double b)
{
  Double exponent = log(a) * b;
  return exp(exponent);
}
// Trigonometric
// Sine of `a`, evaluated as the power series
//   a - a^3/3! + a^5/5! - ...
// No range reduction is performed here; callers pass an already reduced
// argument. Summation stops after the first term whose leading part is no
// larger than 1e-20 * |a.hi|.
//
// Each term is derived from the previous one,
//   term_m = term_{m-2} * (-a^2) / (m * (m - 1)),
// instead of dividing a running power by a running factorial: the factorial
// exceeds the float range at 35!, after which every term became NaN.
// The exit test is written so that a NaN term, or a zero term with a zero
// threshold (a == 0), also ends the loop; previously both cases spun forever.
Double __getsin(Double a)
{
  Double negSquare = -a * a;
  Double sum = a;
  Double term = a;
  float order = 1;
  float threshold = 1e-20 * abs(a.hi);
  while (true)
  {
    order += 2;
    term = (term * negSquare) / (order * (order - 1));
    sum += term;
    // Negated comparison on purpose: it is also true when term.hi is NaN.
    if (!(abs(term.hi) > threshold))
    {
      break;
    }
  }
  return sum;
}
// Computes sine and cosine together. Returns sin(a) and writes cos(a) to
// `cosval`. The cosine magnitude comes from sqrt(1 - sin^2); its sign is
// negative when the reduced argument lies strictly between pi/2 and 3*pi/2.
Double sincos(Double a, thread Double &cosval)
{
  // Exact shortcut when the high part is zero.
  if (a.hi == 0)
  {
    cosval = 1;
    return 0;
  }
  Double reduced = positive_remainder(a, 2 * M_PI_F);
  Double sine = __getsin(reduced);
  Double cosine = sqrt(1 - sine * sine);
  bool cosineIsNegative = (reduced.hi > M_PI_2_F) && (reduced.hi < 3 * M_PI_2_F);
  if (cosineIsNegative)
  {
    cosine = -cosine;
  }
  cosval = cosine;
  return sine;
}
// Sine of `a`: reduces the argument with positive_remainder (period 2*pi as a
// float) and evaluates the series. A zero high part short-circuits to 0.
Double sin(Double a)
{
  if (a.hi != 0)
  {
    Double reduced = positive_remainder(a, 2 * M_PI_F);
    return __getsin(reduced);
  }
  return 0;
}
// Cosine of `a`, derived from the sine of the reduced argument:
// magnitude sqrt(1 - sin^2), negated when the reduced argument lies strictly
// between pi/2 and 3*pi/2. A zero high part short-circuits to 1.
Double cos(Double a)
{
  if (a.hi == 0)
  {
    return 1;
  }
  Double reduced = positive_remainder(a, 2 * M_PI_F);
  Double sine = __getsin(reduced);
  Double cosine = sqrt(1 - sine * sine);
  bool cosineIsNegative = (reduced.hi > M_PI_2_F) && (reduced.hi < 3 * M_PI_2_F);
  if (cosineIsNegative)
  {
    cosine = -cosine;
  }
  return cosine;
}
// Tangent of `a`: obtains sine and cosine in one sincos call and divides.
Double tan(Double a)
{
  Double cosine;
  Double sine = sincos(a, cosine);
  return sine / cosine;
}
// Emulated fused multiply-add is much slower than doing the operations separately
// Multiplies a Double by a float and returns the product as four floats,
// ordered from the least significant piece (index 0) to the most significant
// piece (index 3).
inline float4 mul21(Double a, float b)
{
  // Products of each component with b, each with its own rounding error.
  Double lowProduct = Double::twoProduct(a.lo, b);
  Double highProduct = Double::twoProduct(a.hi, b);
  // Combine the error of the high product with the value of the low product.
  Double middle = Double::twoSum(highProduct.lo, lowProduct.hi);
  // Fold that sum into the high product's value.
  Double top = Double::fastTwoSum(middle.hi, highProduct.hi);
  return float4(lowProduct.lo, middle.lo, top.lo, top.hi);
}
// Combines a, b and c in a single pass: the four-float expansions of
// a * b.lo and a * b.hi (from mul21) are accumulated together with c, and the
// result is renormalized once at the end. Presumably this is meant to be
// a * b + c with less intermediate rounding than `a * b + c` written with the
// operators — NOTE(review): accuracy claim not verifiable from this file.
Double precise_fma(Double a, Double b, Double c)
{
// Expansions of the two partial products; index 3 is the leading piece.
float4 q = mul21(a, b.lo);
float4 p = mul21(a, b.hi);
// p3 holds the most significant part of the result
// Compress the three trailing pieces of q into two (q1_q2.hi, q1_q2.lo).
auto q0_q1 = Double::fastTwoSum(q[0], q[1]);
auto q1_q2 = Double::fastTwoSum(q0_q1.hi, q[2]);
q1_q2.lo += q0_q1.lo;
// Same compression for the trailing pieces of p.
auto p0_p1 = Double::fastTwoSum(p[0], p[1]);
auto p1_p2 = Double::fastTwoSum(p0_p1.hi, p[2]);
p1_p2.lo += p0_p1.lo;
// the two products are now just (q1, q2, q3), (p1, p2, p3)
// Running sum: start from c.hi and add the remaining pieces one at a time
// with twoSum, interleaving q and p pieces in the order written below. Each
// step's `hi` feeds the next step; each step's `lo` is its rounding error.
auto q1_hi = Double::twoSum(q1_q2.lo, c.hi);
auto q2_hi = Double::twoSum(q1_q2.hi, q1_hi.hi);
auto p1_hi = Double::twoSum(p1_p2.lo, q2_hi.hi);
auto q3_hi = Double::twoSum( q[3], p1_hi.hi);
auto p2_hi = Double::twoSum(p1_p2.hi, q3_hi.hi);
auto p3_hi = Double::twoSum( p[3], p2_hi.hi);
// All rounding errors collected along the chain, plus c.lo, form the low part.
auto lo = c.lo + q1_hi.lo + q2_hi.lo + p1_hi.lo
+ q3_hi.lo + p2_hi.lo + p3_hi.lo;
// Final renormalization of (accumulated errors, leading sum).
return Double::fastTwoSum(lo, p3_hi.hi);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment