Last active
March 31, 2020 13:07
-
-
Save eguiraud/7c7981179d394633a50a906d446f79ef to your computer and use it in GitHub Desktop.
RVecs as numpy arrays via numba in RDataFrame
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numba as nb | |
from numba.types import CPointer, Array, void | |
import ROOT | |
import numpy as np | |
import ctypes | |
def c_friendly_signature(args, ret):
    """Take a normalized numba signature, return a C-friendly signature.

    Every array type is replaced by a pointer/length pair so that the
    function can be exposed through a plain C ABI (e.g. with ``nb.cfunc``).

    Parameters
    ----------
    args : sequence of numba types
        Argument types, as returned by ``nb.sigutils.normalize_signature``.
    ret : numba type
        Return type.

    Returns
    -------
    (list, numba type)
        ``cargs``: each array argument ``T[...]`` becomes the two parameters
        ``(T*, int32)``; other arguments are kept as they are. If ``ret`` is
        an array, two output parameters ``(T**, int32*)`` are appended.
        ``cret``: ``void`` when ``ret`` is an array, ``ret`` otherwise.

    Note that lengths are 32-bit signed integers (``nb.i4``), not ``size_t``:
    C/C++ callers must pass (and, for array returns, receive) an ``int``.
    """
    cargs = []
    for arg in args:
        if isinstance(arg, Array):
            # T[...] -> T* data pointer followed by an int32 element count
            cargs.extend((CPointer(arg.dtype), nb.i4))
        else:
            cargs.append(arg)
    if isinstance(ret, Array):
        # array result -> T** and int32* output parameters; the function returns void
        cargs.extend((CPointer(CPointer(ret.dtype)), CPointer(nb.i4)))
        cret = void
    else:
        cret = ret
    return cargs, cret
# Decorator factory: the returned decorator takes a Python function operating
# on numpy arrays and wraps it so that, instead of arrays, a data pointer and
# a length are exposed, making it easy to JIT-compile it and call it from C++.

# numba scalar type -> C++ spelling used in the generated C++ wrapper.
_CPP_TYPE_NAMES = {
    nb.float32: "float",
    nb.float64: "double",
    nb.int16: "short",
    nb.uint16: "unsigned short",
    nb.int32: "int",
    nb.uint32: "unsigned int",
    nb.int64: "Long64_t",
    nb.uint64: "ULong64_t",
}

# Source of the C-ABI wrapper compiled with nb.cfunc: it rebuilds a 1D numpy
# view from (pointer, length) and forwards it to the jitted user function.
# Only a single array argument is supported (checked in the decorator).
_C_FRIENDLY_FUNC_SRC = """def c_friendly_func(vptr, size):
    arr = nb.carray(vptr, (size,))
    return jitted_func(arr)
"""

# The generated C++ wrappers only hold raw addresses of the compiled
# callbacks, so the numba objects are kept referenced here to prevent them
# from being garbage-collected while C++ may still call them.
_cpp_callable_keepalive = []


def _cpp_type_name(numba_type):
    """Return the C++ spelling of a supported numba scalar type."""
    try:
        return _CPP_TYPE_NAMES[numba_type]
    except KeyError:
        raise NotImplementedError(
            "no C++ type known for numba type {}".format(numba_type)
        ) from None


def cpp_callable_with_array(sig):
    """Return a decorator exposing a numba-compiled function to C++.

    ``sig`` is a numba signature string such as ``"f4(f4[:])"``. Only a
    scalar return type and a single 1D array argument are supported, with
    element and return types among the keys of ``_CPP_TYPE_NAMES``.

    The decorated function is compiled with ``nb.jit`` and wrapped in an
    ``nb.cfunc`` taking ``(T*, int32)``. A C++ function
    ``cpp_callable::<name>(ROOT::RVec<T> &)`` forwarding to that callback is
    then declared in ROOT's interpreter, where ``<name>`` is the Python
    function's ``__name__``. The decorator returns ``func`` unchanged.

    Raises
    ------
    NotImplementedError
        For array return types and for unsupported argument lists or types.
    RuntimeError
        If ROOT fails to declare the generated C++ function (e.g. because a
        function with the same name was already declared).
    """
    def decorator(func):
        args, ret = nb.sigutils.normalize_signature(sig)
        cargs, cret = c_friendly_signature(args, ret)
        if cret == void:
            raise NotImplementedError("RVec as return type not supported yet")
        # _C_FRIENDLY_FUNC_SRC and the C++ wrapper below handle exactly one
        # 1D array; reject anything else instead of generating a mismatched call.
        if len(args) != 1 or not isinstance(args[0], Array) or args[0].ndim != 1:
            raise NotImplementedError(
                "only signatures with a single 1D array argument are supported"
            )
        elem_type = _cpp_type_name(args[0].dtype)
        ret_type = _cpp_type_name(ret)

        jitted = nb.jit(ret(*args))(func)
        # exec into a private namespace: the wrapper resolves `nb` and
        # `jitted_func` from it, so decorating several functions does not
        # clobber a shared module global, and we do not rely on exec writing
        # into locals() (which no longer works on Python 3.13+, PEP 667).
        namespace = {"nb": nb, "jitted_func": jitted}
        exec(_C_FRIENDLY_FUNC_SRC, namespace)
        cf = nb.cfunc(cret(*cargs), inline="always")(namespace["c_friendly_func"])
        _cpp_callable_keepalive.append((jitted, cf))

        # cpp_callable::<name>(RVec<T>&) forwards to the compiled callback;
        # the length is narrowed to int to match the nb.i4 parameter.
        cpp_code = """namespace cpp_callable {{
{ret} {name}(ROOT::RVec<{elem}> &v) {{
   return reinterpret_cast<{ret}(*)({elem}*, int)>({addr})(v.data(), static_cast<int>(v.size()));
}}
}}""".format(ret=ret_type, name=func.__name__, elem=elem_type, addr=cf.address)
        if not ROOT.gInterpreter.Declare(cpp_code):
            raise RuntimeError(
                "ROOT could not declare cpp_callable::{}".format(func.__name__)
            )
        return func
    return decorator
def test_ret_float():
    """Check that a jitted array reduction can be called from an RDataFrame Define."""
    # Named `sum` to match the cpp_callable::sum call used in the Define below.
    @cpp_callable_with_array("f4(f4[:])")
    def sum(arr):
        return arr.sum()

    # Each of the 10 entries holds {1, 2, 3}, so every "s" is 6 and so is the mean.
    rdf = ROOT.RDataFrame(10)
    with_vec = rdf.Define("v", "ROOT::RVec<float>{1,2,3}")
    with_sum = with_vec.Define("s", "cpp_callable::sum(v)")
    mean_of_sum = with_sum.Mean("s")
    assert np.isclose(mean_of_sum.GetValue(), 6)
def main():
    """Run each smoke test in turn, printing progress to stdout."""
    for test in (test_ret_float,):
        print("{}...".format(test.__name__))
        test()
    print("done")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
When the return type is an `RVec`, we need something like