Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Created October 26, 2021 15:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/5f870e349e0e7f2be2ac9b57e18bcb8a to your computer and use it in GitHub Desktop.
Calling `PyUFuncObject`'s `reduce` method from Numba
import numba as nb
import numpy as np
from numba.core.datamodel.models import StructModel
from numba.core.datamodel.registry import register_default
from numba.core.extending import intrinsic, overload, overload_method
@nb.njit
def add(x, y):
    """Return the sum ``x + y``.

    Scalar kernel that is turned into the ``vec_add`` DUFunc further down.
    """
    total = x + y
    return total
# Build a vectorize decorator with no explicit signatures.
# NOTE(review): per Numba's docs, signature-less vectorize yields a DUFunc
# that compiles loops lazily on first call — the `DUFunc` checks elsewhere in
# this file rely on that; confirm for the installed Numba version.
custom_vectorize = nb.vectorize(
    [],
    target="cpu",
    identity=None,
)
# Wrap the scalar `add` kernel as a ufunc-like object.
vec_add = custom_vectorize(add)
#
# [[file:numba/np/ufunc/_internal.c::} PyDUFuncObject;][PyDUFuncObject]]
#
@register_default(nb.np.ufunc.dufunc.DUFunc)
class PyDUFuncModel(StructModel):
    """A struct data model for `DUFunc`.

    Intended to mirror the C-level ``PyDUFuncObject`` (defined in
    ``numba/np/ufunc/_internal.c``). Each member is treated as an opaque
    Python object.

    NOTE(review): Numba's data-model manager looks models up by the *Numba
    type* class of a value. ``DUFunc`` is the Python runtime class, so this
    registration may never be consulted. Confirm which Numba type
    ``vec_add`` actually has (e.g. ``types.Function``).
    """

    # NOTE(review): nothing visible in this file reads this attribute.
    _element_type = NotImplemented

    def __init__(self, dmm, fe_type):
        # Registered models are instantiated as ``model_cls(dmm, fe_type)``.
        # ``StructModel.__init__`` expects ``(dmm, fe_type, members)``, so
        # both must be accepted here and forwarded.
        members = [
            ("_dispatcher", nb.types.pyobject),
            ("ufunc", nb.types.pyobject),
            ("_keepalive", nb.types.pyobject),
        ]
        super().__init__(dmm, fe_type, members)
#
# TODO: Should the `PyUFuncObject` stored in `PyDUFuncModel.ufunc` get its own data model too?
#
# [[file:../../../../apps/anaconda3/envs/numba-env/lib/python3.7/site-packages/numpy/core/include/numpy/ufuncobject.h::} PyUFuncObject;][PyUFuncObject]]
#
# @register_default(?)
# class PyUFuncModel(StructModel):
# _element_type = NotImplemented
#
# def __init__(self):
# members = [
# ("ptr", nb.types.pyobject),
# ("obj", nb.types.pyobject),
# ]
# super(PyUFuncModel, self).__init__(members)
@intrinsic
def intr_reduce(typcontext, ft, xt, yt, axist):
    """Work-in-progress intrinsic intended to call a DUFunc's ``reduce``.

    Typing: takes the DUFunc's Numba type ``ft`` plus ``xt``, ``yt`` and
    ``axist``, and declares an ``int64`` result.

    NOTE(review): ``codegen`` does not return an LLVM value yet. It only
    stops in ``breakpoint()`` so that its arguments can be inspected, so
    lowering presumably fails once execution continues. The TODO below must
    be implemented first.

    NOTE(review): the ``int64`` return type looks like a placeholder.
    Confirm it against the intended ``reduce`` result type.
    """
    sig = nb.types.int64(ft, xt, yt, axist)

    def codegen(context, builder, sig, args):
        # Numba type of the DUFunc argument; currently only used when
        # inspecting this frame in the debugger.
        ft = sig.args[0]
        # LLVM values of the four call arguments, in call order.
        f_ir, x_ir, y_ir, axis_ir = args
        # Create a usable reference to the underlying `PyDUFuncObject`?
        # NOTE(review): `cgutils` is not among this file's visible imports.
        # fn = cgutils.create_struct_proxy(ft)(context, builder)
        #
        # TODO: Call reduce from one of these references!
        #
        # Deliberate debugger stop during code generation, used to explore
        # `context`, `builder`, `sig` and the values above interactively.
        breakpoint()

    return sig, codegen
@overload_method(nb.types.Function, "reduce")
def dufunc_reduce(ft, xt, yt, axist):
    """Provide a ``reduce`` method for Numba function types backed by a DUFunc.

    An implementation is returned only when ``ft.typing_key`` is a
    ``DUFunc``. Otherwise the function returns ``None``, which Numba's
    overload machinery treats as "this overload does not apply".
    """
    if not isinstance(ft.typing_key, nb.np.ufunc.dufunc.DUFunc):
        return None

    # Parameter names must match the typing function's for Numba to accept it.
    def _reduce_impl(ft, xt, yt, axist):
        return intr_reduce(ft, xt, yt, axist)

    return _reduce_impl
@nb.njit
def test_fn(x, y):
    """Exercise the experimental ``reduce`` overload on ``vec_add`` along axis 0.

    NOTE(review): NumPy's ``ufunc.reduce`` is ``reduce(array, axis=0,
    dtype=None, out=None, ...)``. Passing a second array positionally only
    makes sense for the custom 4-argument overload in this file. Confirm
    the intended call shape.
    """
    result = vec_add.reduce(x, y, 0)
    return result
if __name__ == "__main__":
    # Guarded so that merely importing this module does not trigger JIT
    # compilation of `test_fn`, which currently stops in `breakpoint()`
    # inside `intr_reduce`'s codegen. Running the file as a script behaves
    # exactly as before.
    test_fn(np.arange(10), np.arange(10) * 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment