Skip to content

Instantly share code, notes, and snippets.

@hugohadfield
Last active March 29, 2021 11:17
Show Gist options
  • Save hugohadfield/bff7c95fefb0937470a27c99d9b084da to your computer and use it in GitHub Desktop.
Save hugohadfield/bff7c95fefb0937470a27c99d9b084da to your computer and use it in GitHub Desktop.
A test of double-double arithmetic with NumPy and Numba
import operator
import pytest
import numba
from numba import types
import numpy as np
##### WE ARE GOING TO CREATE A CUSTOM NUMPY DTYPE #####
# A double-double value is stored as a structured record with two float64
# fields, 'x' and 'y'.
np_type = np.dtype([('x', np.float64), ('y', np.float64)])
# Numba's view of the record type, and of a 1-D array of such records.
numba_dtype = numba.from_dtype(np_type)
numba_array_type = numba_dtype[:]
# A single zero-initialised record of the custom dtype.
zero = np.zeros(1, dtype=np_type)[0]
##### THESE ARE THE DOUBLE DOUBLE ARITHMETIC FUNCTIONS THAT OUR TYPE WILL IMPLEMENT #####
@numba.njit
def _two_sum_quick(x, y):
    """Return ``x + y`` together with a correction term.

    The result is ``(total, err)`` with ``total = x + y`` and
    ``err = y - (total - x)``.

    NOTE(review): this is the Fast2Sum form, which is usually stated with
    the precondition ``|x| >= |y|`` -- confirm the callers guarantee it.
    """
    total = x + y
    err = y - (total - x)
    return total, err
@numba.njit
def _two_sum(x, y):
    """Return ``x + y`` together with a correction term (TwoSum form).

    The result is ``(total, err)``. Unlike ``_two_sum_quick`` the correction
    is assembled from a contribution for each operand, so the formula does
    not depend on which operand is passed first.
    """
    total = x + y
    y_virtual = total - x
    x_roundoff = x - (total - y_virtual)
    y_roundoff = y - y_virtual
    return total, x_roundoff + y_roundoff
@numba.njit
def _two_difference(x, y):
    """Return ``x - y`` together with a correction term.

    The result is ``(diff, err)``; it mirrors ``_two_sum`` with the sign of
    ``y`` flipped.
    """
    diff = x - y
    shift = diff - x
    err = (x - (diff - shift)) - (y + shift)
    return diff, err
@numba.njit
def _two_product(x, y):
    """Return ``x * y`` together with a correction term.

    Each operand is first split into a leading and a trailing piece by
    multiplying with the constant 134217729.0 (2**27 + 1); the correction is
    then accumulated from the partial products of those pieces.

    Returns ``(product, err)``.
    """
    splitter = 134217729.0  # 2**27 + 1

    # Split x into x_hi + x_lo.
    x_scaled = x * splitter
    x_hi = x_scaled - (x_scaled - x)
    x_lo = x - x_hi

    # Split y into y_hi + y_lo.
    y_scaled = y * splitter
    y_hi = y_scaled - (y_scaled - y)
    y_lo = y - y_hi

    product = x * y
    # The grouping and order of these terms is deliberate; do not rearrange.
    err = ((x_hi * y_hi - product) + x_hi * y_lo + x_lo * y_hi) + x_lo * y_lo
    return product, err
@numba.njit
def mul_double_double(ax, bx, ay, by):
    """Multiply the two-component values ``(ax, ay)`` and ``(bx, by)``.

    Note the argument order: both ``x`` components come first, then both
    ``y`` components. The ``ay * by`` term is not included in the result.

    Returns the product as a ``(r, e)`` pair.
    """
    prod, err = _two_product(ax, bx)
    # Fold the cross terms into the correction.
    err = err + ax * by + ay * bx
    return _two_sum_quick(prod, err)
@numba.njit
def rmul_double_double(ax, ay, other):
    """Multiply the two-component value ``(ax, ay)`` by the scalar ``other``.

    Returns the product as a ``(r, e)`` pair.
    """
    prod, err = _two_product(other, ax)
    err = err + other * ay
    return _two_sum_quick(prod, err)
@numba.njit
def add_double_double(ax, bx, ay, by):
    """Add the two-component values ``(ax, ay)`` and ``(bx, by)``.

    Note the argument order: both ``x`` components come first, then both
    ``y`` components.

    Returns the sum as a ``(r, e)`` pair.
    """
    total, err = _two_sum(ax, bx)
    # Accumulate both trailing components into the correction.
    err = err + ay + by
    return _two_sum_quick(total, err)
@numba.njit
def radd_double_double(ax, ay, other):
    """Add the scalar ``other`` to the two-component value ``(ax, ay)``.

    Returns the sum as a ``(r, e)`` pair.
    """
    total, err = _two_sum(other, ax)
    err = err + ay
    return _two_sum_quick(total, err)
@numba.njit
def numpy_rmul_double_double(a, b):
    """Multiply the record ``a`` (fields 'x' and 'y') by the scalar ``b``.

    Returns a freshly allocated record of the custom dtype holding the
    product.
    """
    result = np.zeros(1, dtype=numba_dtype)[0]
    lead, trail = rmul_double_double(a['x'], a['y'], b)
    result['x'] = lead
    result['y'] = trail
    return result
@numba.njit
def numpy_mul_double_double(a, b):
    """Multiply two records ``a`` and ``b`` (each with fields 'x' and 'y').

    Returns a freshly allocated record of the custom dtype holding the
    product.
    """
    result = np.zeros(1, dtype=numba_dtype)[0]
    lead, trail = mul_double_double(a['x'], b['x'], a['y'], b['y'])
    result['x'] = lead
    result['y'] = trail
    return result
@numba.njit
def numpy_add_double_double(a, b):
    """Add two records ``a`` and ``b`` (each with fields 'x' and 'y').

    Returns a freshly allocated record of the custom dtype holding the sum.
    """
    result = np.zeros(1, dtype=numba_dtype)[0]
    lead, trail = add_double_double(a['x'], b['x'], a['y'], b['y'])
    result['x'] = lead
    result['y'] = trail
    return result
@numba.njit
def numpy_radd_double_double(a, b):
    """Add the scalar ``b`` to the record ``a`` (fields 'x' and 'y').

    Returns a freshly allocated record of the custom dtype holding the sum.
    """
    result = np.zeros(1, dtype=numba_dtype)[0]
    lead, trail = radd_double_double(a['x'], a['y'], b)
    result['x'] = lead
    result['y'] = trail
    return result
# This is to allocate zeros for arrays of custom dtype
@numba.njit
def numpy_zeros_array_doubledouble(l):
    """Return a zero-filled array of the custom record dtype, sized by ``l``."""
    zeros = np.zeros(l, dtype=numba_dtype)
    return zeros
##### THIS IS WHERE WE DEFINE THE CUSTOM OVERLOADS #####
@numba.extending.overload(operator.mul)
def np_double_double_mul(a, b):
    """Numba overload of ``*`` for the double-double record dtype.

    Runs at typing time: ``a`` and ``b`` are Numba *types*. Depending on
    the operand types it returns the implementation to compile, or ``None``
    so that Numba falls back to its other ``*`` implementations.

    Supported combinations:

    * record * record, record * number, number * record
    * 1-D record array * number, number * 1-D record array
    * 1-D record array * 1-D record array (element-wise, equal lengths)

    The array implementations are one-dimensional, so only ``ndim == 1``
    arrays are accepted. The array * array implementation raises
    ``ValueError`` when the lengths differ instead of indexing the second
    operand past its end.
    """
    a_is_record = a == numba_dtype
    b_is_record = b == numba_dtype
    a_is_number = isinstance(a, types.abstract.Number)
    b_is_number = isinstance(b, types.abstract.Number)
    a_is_vector = (
        isinstance(a, types.Array) and a.ndim == 1 and a.dtype == numba_dtype
    )
    b_is_vector = (
        isinstance(b, types.Array) and b.ndim == 1 and b.dtype == numba_dtype
    )

    # These are the not array versions
    if a_is_record and b_is_record:
        def impl(a, b):
            return numpy_mul_double_double(a, b)
        return impl

    if a_is_record and b_is_number:
        def impl(a, b):
            return numpy_rmul_double_double(a, b)
        return impl

    if b_is_record and a_is_number:
        def impl(a, b):
            return numpy_rmul_double_double(b, a)
        return impl

    # Now the array versions
    if a_is_vector and b_is_number:
        def impl(a, b):
            n = a.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_rmul_double_double(a[i], b)
            return output
        return impl

    if a_is_vector and b_is_vector:
        def impl(a, b):
            n = a.shape[0]
            # nopython mode does not bounds-check by default, so a shorter
            # ``b`` would otherwise be read out of bounds.
            if b.shape[0] != n:
                raise ValueError(
                    "double-double arrays must have the same length")
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_mul_double_double(a[i], b[i])
            return output
        return impl

    if a_is_number and b_is_vector:
        def impl(a, b):
            n = b.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_rmul_double_double(b[i], a)
            return output
        return impl

    # Unsupported combination: let Numba try its other implementations.
    return None
@numba.extending.overload(operator.add)
def np_double_double_add(a, b):
    """Numba overload of ``+`` for the double-double record dtype.

    Runs at typing time: ``a`` and ``b`` are Numba *types*. Depending on
    the operand types it returns the implementation to compile, or ``None``
    so that Numba falls back to its other ``+`` implementations.

    Supported combinations:

    * record + record, record + number, number + record
    * 1-D record array + number, number + 1-D record array
    * 1-D record array + 1-D record array (element-wise, equal lengths)

    The array implementations are one-dimensional, so only ``ndim == 1``
    arrays are accepted. The array + array implementation raises
    ``ValueError`` when the lengths differ instead of indexing the second
    operand past its end.
    """
    a_is_record = a == numba_dtype
    b_is_record = b == numba_dtype
    a_is_number = isinstance(a, types.abstract.Number)
    b_is_number = isinstance(b, types.abstract.Number)
    a_is_vector = (
        isinstance(a, types.Array) and a.ndim == 1 and a.dtype == numba_dtype
    )
    b_is_vector = (
        isinstance(b, types.Array) and b.ndim == 1 and b.dtype == numba_dtype
    )

    # These are the not array versions
    if a_is_record and b_is_record:
        def impl(a, b):
            return numpy_add_double_double(a, b)
        return impl

    if a_is_record and b_is_number:
        def impl(a, b):
            return numpy_radd_double_double(a, b)
        return impl

    if b_is_record and a_is_number:
        def impl(a, b):
            return numpy_radd_double_double(b, a)
        return impl

    # Now the array versions
    if a_is_vector and b_is_number:
        def impl(a, b):
            n = a.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_radd_double_double(a[i], b)
            return output
        return impl

    if a_is_vector and b_is_vector:
        def impl(a, b):
            n = a.shape[0]
            # nopython mode does not bounds-check by default, so a shorter
            # ``b`` would otherwise be read out of bounds.
            if b.shape[0] != n:
                raise ValueError(
                    "double-double arrays must have the same length")
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_add_double_double(a[i], b[i])
            return output
        return impl

    if a_is_number and b_is_vector:
        def impl(a, b):
            n = b.shape[0]
            output = numpy_zeros_array_doubledouble(n)
            for i in range(n):
                output[i] = numpy_radd_double_double(b[i], a)
            return output
        return impl

    # Unsupported combination: let Numba try its other implementations.
    return None
##### THESE ARE SOME TEST UTILITIES #####
@pytest.fixture
def rng():
    """Provide a NumPy random generator seeded with the default test seed."""
    # Seed 1 is the default seed used to start the pseudo-random tests.
    return np.random.default_rng(1)
def gen_a_b(rng):
    """Draw two random records of the custom dtype.

    Each record gets a standard-normal sample in field 'x' and 0.0 in
    field 'y'. The sample for the first record is drawn before the sample
    for the second.

    Returns the tuple ``(a, b)``.
    """
    def draw_record():
        record = np.zeros(1, dtype=np_type)[0]
        record['x'] = rng.standard_normal()
        record['y'] = 0.0
        return record

    first = draw_record()
    second = draw_record()
    return (first, second)
##### THESE ARE THE TESTS #####
class TestDoubleDoubleNumpy:
    """Tests for the ``*`` and ``+`` overloads on the double-double dtype."""

    def test_mul(self, rng):
        """Chained scalar products agree with the raw arithmetic helpers."""
        @numba.njit
        def mul_test(a, b):
            return 3.0*a*b*2.0

        for _ in range(1000):
            a, b = gen_a_b(rng)
            r, e = mul_double_double(a['x'], b['x'], a['y'], b['y'])
            r, e = rmul_double_double(r, e, 2.0)
            r, e = rmul_double_double(r, e, 3.0)
            res = mul_test(a, b)
            np.testing.assert_allclose((res['x'], res['y']), (r, e))

    def test_array_mul(self, rng):
        """Array products agree element by element with record products."""
        @numba.njit
        def array_mul(c, d):
            e = 3.0*c
            f = d*2.0
            return e*f

        @numba.njit
        def elementwise_mul(c, d):
            e = (3.0*c[0], 3.0*c[1])
            f = (d[0]*2.0, d[1]*2.0)
            return (e[0]*f[0], e[1]*f[1])

        for _ in range(1000):
            a, b = gen_a_b(rng)
            c = np.array([a, b])
            d = np.array([b, b])
            res1 = array_mul(c, d)
            res2 = elementwise_mul(c, d)
            assert res1.shape == c.shape
            # Check every element, not only the first one.
            for k in range(len(c)):
                np.testing.assert_allclose(
                    (res1[k]['x'], res1[k]['y']),
                    (res2[k]['x'], res2[k]['y']),
                )

    def test_array_add(self, rng):
        """Array sums agree element by element with record sums."""
        @numba.njit
        def array_add(c, d):
            return 2.0 + c + d + 5

        @numba.njit
        def elementwise_add(c, d):
            return (2.0 + c[0] + d[0] + 5, 2.0 + c[1] + d[1] + 5)

        for _ in range(1000):
            a, b = gen_a_b(rng)
            c = np.array([a, b])
            d = np.array([b, b])
            res1 = array_add(c, d)
            res2 = elementwise_add(c, d)
            assert res1.shape == c.shape
            # Check every element, not only the first one.
            for k in range(len(c)):
                np.testing.assert_allclose(
                    (res1[k]['x'], res1[k]['y']),
                    (res2[k]['x'], res2[k]['y']),
                )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment