-
-
Save kc611/22760dca36cc9062a401da89ee60ced8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager

import llvmlite.binding as llvm
import numba
import numpy as np
from aesara.link.utils import compile_function_src
from numba import prange
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily replace the cheap optimization pass with a better one.

    While the context is active, the ``_mpm_cheap`` module pass manager of
    Numba's CPU code generator is replaced by one built with
    ``loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"``.
    The previous pass manager is restored on exit, even if the body raises.

    Any positional or keyword arguments are accepted and ignored.

    Usage
    =====
    with use_optimized_cheap_pass():
        # Compile Numba function...
    """
    # Imported lazily so that merely defining this helper does not require
    # Numba's target registry to be importable.
    from numba.core.registry import cpu_target

    context = cpu_target.target_context._internal_codegen
    old_pm = context._mpm_cheap
    new_pm = context._module_pass_manager(
        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
    )
    context._mpm_cheap = new_pm
    try:
        yield
    finally:
        # Always put the original pass manager back.
        context._mpm_cheap = old_pm
# Demo input: a random 5x5x5 array that is max-reduced along ``axis``.
test_array = np.random.random(size=(5, 5, 5))
axis = 1
# Derive the rank and output shape from the array itself so they cannot
# drift out of sync with ``test_array`` / ``axis``.
ndim = test_array.ndim
result_shape = tuple(
    dim for pos, dim in enumerate(test_array.shape) if pos != axis
)
total_size = np.prod(test_array.shape)
# Floor division keeps the element count an exact integer instead of
# round-tripping through a float.
result_size = int(total_size // test_array.shape[axis])
# Name of the generated reduction kernel and the globals it is compiled with.
reduce_elemwise_fn_name = "_reduce_elemwise"
global_dict = {"prange": prange}

# Build the index expressions used by the generated kernel:
#   * ``res_cells`` indexes the result with one ``idx_arr`` entry per
#     non-reduced dimension;
#   * ``arr_cells`` indexes the input the same way, except that the reduced
#     dimension is indexed by the generated loop variable ``i``.
res_index_parts = []
arr_index_parts = []
count = 0
for i in range(ndim):
    if i == axis:
        arr_index_parts.append("[i]")
        continue
    cell = f"[idx_arr[{count}]]"
    res_index_parts.append(cell)
    arr_index_parts.append(cell)
    count += 1
res_cells = "".join(res_index_parts)
arr_cells = "".join(arr_index_parts)
# Source of the per-output-cell kernel: it walks the reduced axis of ``x``
# and folds each element into the matching cell of ``res`` with ``max``.
# The indentation inside the template is significant, since it becomes the
# body of the generated function.
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(res, x, axis, idx_arr):
    for i in range(x.shape[axis]):
        res{res_cells} = max(res{res_cells}, x{arr_cells})
    return
"""

print("Generated function: ")
print(reduce_elemwise_def_src)

# Turn the source into a Python function, then JIT-compile it with Numba.
reduce_elemwise_fn_py = compile_function_src(
    reduce_elemwise_def_src, reduce_elemwise_fn_name, global_dict
)
reduce_elemwise_fn = numba.njit(reduce_elemwise_fn_py)
# NOTE(review): ``fastmath=True`` lets LLVM assume there are no infs/NaNs;
# confirm that seeding the result with ``-np.inf`` is safe under that flag.
@numba.njit(boundscheck=False, fastmath=True)
def careduce_axis(x, axis):
    """Reduce ``x`` along ``axis`` by calling ``reduce_elemwise_fn`` per cell.

    The result buffer has the module-level ``result_shape`` and starts at
    ``-inf``; ``reduce_elemwise_fn`` is then called once for every index of
    that shape and updates the corresponding cell in place.

    Parameters
    ----------
    x
        Input array to reduce.
    axis
        Axis of ``x`` that is reduced away.

    Returns
    -------
    A float64 array of shape ``result_shape`` holding the reduced values.
    """
    res = np.full(result_shape, -np.inf, dtype=np.float64)
    for m in np.ndindex(result_shape):
        reduce_elemwise_fn(res, x, axis, m)
    return res
# The first call triggers compilation, so it must happen while the
# replacement pass manager is installed.
with use_optimized_cheap_pass():
    numba_res = careduce_axis(test_array, axis)

# Check the compiled reduction against NumPy's reference implementation.
# ``np.testing`` is used instead of a bare ``assert`` so the check still runs
# under ``python -O`` and reports the mismatching values on failure.
np_res = np.max(test_array, axis=axis)
np.testing.assert_array_equal(numba_res, np_res)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment