Skip to content

Instantly share code, notes, and snippets.

@wolfv
Last active April 18, 2017 04:26
Show Gist options
Save wolfv/200f2599452ad7e5bcfc83fed40b77f0 to your computer and use it in GitHub Desktop.
Snippet to generate tests from numpy for xtensor
import numpy as np
# Map numpy dtype names (as produced by str(arr.dtype)) to the C++ element
# types used in the generated xtensor container declarations.
# NOTE(review): 'long' / 'unsigned long' are 64-bit on LP64 platforms
# (Linux/macOS) but only 32-bit on LLP64 (Windows/MSVC); switch to
# std::int64_t / std::uint64_t if the generated tests must build there.
dtype_map = {
    'bool': 'bool',
    'int8': 'signed char',
    'int16': 'short',
    'int32': 'int',
    'int64': 'long',
    'uint8': 'unsigned char',
    'uint16': 'unsigned short',
    'uint32': 'unsigned int',
    'uint64': 'unsigned long',
    'float32': 'float',
    'float64': 'double',
    'complex64': 'std::complex<float>',
    'complex128': 'std::complex<double>',
}

# When True, get_xtype emits fixed-rank ``xtensor<T, N>`` declarations;
# otherwise it emits dynamic-rank ``xarray<T>`` declarations.
USE_XTENSOR = False
def get_xtype(arr):
    """Return the C++ container type name to declare a copy of *arr* with.

    The element type is looked up in ``dtype_map`` using ``str(arr.dtype)``
    (a missing dtype raises KeyError). If the module flag ``USE_XTENSOR`` is
    set, the fixed-rank spelling ``xtensor<T, N>`` is used with
    ``N = arr.ndim``. Otherwise the dynamic-rank spelling ``xarray<T>`` is
    used.
    """
    elem_type = dtype_map[str(arr.dtype)]
    if not USE_XTENSOR:
        return "xarray<{}>".format(elem_type)
    return "xtensor<{}, {}>".format(elem_type, arr.ndim)
def get_cpp_initlist(arr, name):
    """Return a C++ statement that initializes *name* with the values of *arr*.

    The result has the form ``<name> = {...};``. Multi-dimensional arrays
    become nested braces. Continuation lines are indented by
    ``len(name) + 3`` so that inner rows line up under the first one.
    *name* may include a type, e.g. ``"xarray<double> a"``.

    Notes:
        * Complex values are written with an ``i`` suffix (``1.+2.i``). This
          presumably relies on C++14 ``std::complex_literals`` being in scope
          in the generated test; confirm.
        * Values are printed with numpy's current print precision, so the
          C++ literals may be rounded.
        * ``nan``/``inf`` are emitted verbatim and are not valid C++ literals.
    """
    a = np.asarray(arr)
    # An explicit separator is required: np.array_str separates elements with
    # spaces only, which is not a valid C++ initializer list. threshold=a.size
    # disables numpy's "..." summarization of large arrays.
    s = np.array2string(a, separator=', ', threshold=a.size)
    s = s.replace('[', '{').replace(']', '}')
    # numpy prints booleans as True/False; C++ spells them true/false.
    s = s.replace('True', 'true').replace('False', 'false')
    # numpy's imaginary unit is 'j'; the C++ complex literal suffix is 'i'.
    s = s.replace('j', 'i')
    s += ';'
    # Indent wrapped lines past "<name> = " so nested rows stay aligned.
    s = s.replace('\n', '\n' + ' ' * (len(name) + 3))
    return name + ' = ' + s
def make_test(fn, xt_fn, *args, prefix=None):
    """Generate a gtest snippet checking an xtensor function against numpy.

    Calls ``fn(*args)`` to get the numpy reference result, then emits C++
    that:
      1. declares each input,
      2. calls ``xt_fn`` on the inputs,
      3. declares the numpy result(s) as expected values,
      4. checks them with ``EXPECT_TRUE(all(isclose(actual, expected)))``.

    The emitted type names are unqualified (``xarray``/``xtensor``), so the
    generated test presumably has ``using namespace xt`` in scope; confirm.

    Args:
        fn: numpy callable producing the reference result.
        xt_fn: C++ name of the xtensor function under test, e.g. "xt::dot".
        *args: arrays passed to both ``fn`` and the generated ``xt_fn`` call.
        prefix: string prepended to every generated C++ variable name. If
            omitted, the module-level ``PREF`` (set by the ``__main__``
            section) is used, or "" if ``PREF`` is not defined.

    Returns:
        The generated snippet. It is also printed.
    """
    if prefix is None:
        # Backward compatibility: the script configures the prefix through
        # the module-level PREF global.
        prefix = globals().get('PREF', '')

    args = [np.asarray(arg) for arg in args]
    lines = []
    names = []
    for i, arg in enumerate(args):
        arg_name = prefix + "arg_" + str(i)
        names.append(arg_name)
        # Declare inputs with their real dtype/rank, not always xarray<double>.
        lines.append(get_cpp_initlist(arg, get_xtype(arg) + " " + arg_name))

    res = fn(*args)
    res_name = prefix + "res"
    lines.append("auto " + res_name + " = " + xt_fn + "(" + ", ".join(names) + ");")

    checks = []
    if isinstance(res, tuple):
        # Multi-output numpy functions (e.g. np.linalg.eig) return a tuple.
        # Each element gets its own expected_<i> and its own check against
        # std::get<i>(res). This assumes xt_fn returns a std::tuple; confirm.
        # NOTE(review): eigenvectors are only defined up to scale/sign, so an
        # elementwise check may be fragile if the two libraries normalize
        # them differently.
        for i, el in enumerate(res):
            exp_name = prefix + "expected_" + str(i)
            lines.append(get_cpp_initlist(el, get_xtype(el) + " " + exp_name))
            checks.append("EXPECT_TRUE(all(isclose(std::get<" + str(i) + ">("
                          + res_name + "), " + exp_name + ")));")
    else:
        exp_name = prefix + "expected"
        # Pass the type as part of the name so continuation lines align.
        lines.append(get_cpp_initlist(res, get_xtype(res) + " " + exp_name))
        # isclose takes the two operands separately (not their difference).
        checks.append("EXPECT_TRUE(all(isclose(" + res_name + ", " + exp_name + ")));")

    res_str = "\n".join(lines + checks) + "\n"
    print(res_str)
    return res_str
if __name__ == '__main__':
    # Script entry point: print a gtest snippet for each xtensor function.
    # PREF and USE_XTENSOR are assigned directly. At module level a `global`
    # statement is a no-op. `global USE_XTENSOR` after its module-level
    # assignment is a SyntaxError on Python 3.6+ ("name ... is assigned to
    # before global declaration").
    # Inputs come from np.random without a seed, so each run generates
    # different test data.
    PREF = ""
    make_test(np.dot, "xt::dot", np.random.rand(5), np.random.rand(5))

    # Trailing underscore for consistency with "eig_" below.
    PREF = "cross_"
    # NOTE(review): NumPy 2.0 deprecates 2-element vectors in np.cross.
    make_test(np.cross, "xt::cross", np.random.rand(2), np.random.rand(3))

    USE_XTENSOR = True
    PREF = "eig_"
    make_test(np.linalg.eig, "xt::linalg::eig", np.random.rand(5, 5))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment