Skip to content

Instantly share code, notes, and snippets.

@kc611
Last active November 23, 2021 19:05
Show Gist options
  • Save kc611/22760dca36cc9062a401da89ee60ced8 to your computer and use it in GitHub Desktop.
Save kc611/22760dca36cc9062a401da89ee60ced8 to your computer and use it in GitHub Desktop.
from contextlib import contextmanager
import llvmlite.binding as llvm
import numba
import numpy as np
from numba import prange
from aesara.link.utils import compile_function_src
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily replace the cheap optimization pass with a better one.

    While the ``with`` block is active, the ``_mpm_cheap`` module pass
    manager of Numba's CPU code generator is swapped for one built with
    ``loop_vectorize=True``, ``slp_vectorize=True`` and ``opt=3``.  The
    previous pass manager is put back on exit, including when the body
    raises.

    Any positional or keyword arguments are accepted and ignored.

    NOTE: this reaches into private Numba attributes
    (``_internal_codegen``, ``_mpm_cheap``, ``_module_pass_manager``), so it
    may break between Numba releases.

    Usage
    =====
    with use_optimized_cheap_pass():
        # Compile Numba function...
    """
    # Imported lazily so the registry is only touched when the context
    # manager is actually entered.
    from numba.core.registry import cpu_target

    codegen = cpu_target.target_context._internal_codegen
    previous_pm = codegen._mpm_cheap
    codegen._mpm_cheap = codegen._module_pass_manager(
        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
    )
    try:
        yield
    finally:
        # Always restore the original pass manager, even on error.
        codegen._mpm_cheap = previous_pm
# --- Problem setup ---------------------------------------------------------
# Random 3-D input; the max-reduction is taken along `axis`.
test_array = np.random.random(size=(5, 5, 5))
axis = 1

# Derive rank and output shape from the input rather than hard-coding them,
# so changing `test_array` or `axis` above keeps everything consistent.
ndim = test_array.ndim
result_shape = tuple(
    dim for pos, dim in enumerate(test_array.shape) if pos != axis
)
total_size = np.prod(test_array.shape)
# Integer division keeps this exact (no float round-trip).
result_size = int(total_size // test_array.shape[axis])

# --- Source generation -----------------------------------------------------
reduce_elemwise_fn_name = "_reduce_elemwise"
global_dict = {"prange": prange}

# Build the chained index expressions used inside the generated function:
#   * the reduced axis of `x` is indexed by the loop variable `i`;
#   * every other axis is indexed by successive entries of `idx_arr`, which
#     also index the result array `res`.
res_cells = ""
arr_cells = ""
count = 0
for i in range(ndim):
    if i == axis:
        arr_cells += "[i]"
    else:
        res_cells += f"[idx_arr[{count}]]"
        arr_cells += f"[idx_arr[{count}]]"
        count = count + 1

# The body indentation inside this template is significant: it becomes the
# source text of the generated function.
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(res, x, axis, idx_arr):
    for i in range(x.shape[axis]):
        res{res_cells} = max(res{res_cells}, x{arr_cells})
    return
"""

print("Generated function: ")
print(reduce_elemwise_def_src)

# Turn the generated source into a Python function, then JIT-compile it.
reduce_elemwise_fn_py = compile_function_src(
    reduce_elemwise_def_src, reduce_elemwise_fn_name, global_dict
)
reduce_elemwise_fn = numba.njit(reduce_elemwise_fn_py)
@numba.njit(boundscheck=False, fastmath=True)
def careduce_axis(x, axis):
    """Return the maximum of ``x`` along ``axis``.

    The output is allocated with the module-level ``result_shape`` and
    seeded with ``-inf``; for every output index the generated
    ``reduce_elemwise_fn`` folds the values along ``axis`` into that cell
    using ``max``.

    Parameters
    ----------
    x
        Input array to reduce.
    axis
        Axis of ``x`` to reduce over.

    Returns
    -------
    A ``float64`` array of shape ``result_shape``.
    """
    # NOTE(review): fastmath=True may let LLVM assume no infinities, while
    # the accumulator is seeded with -inf -- confirm this is acceptable.
    res = np.full(result_shape, -np.inf, dtype=np.float64)
    # `m` is the tuple index of one output cell.
    for m in np.ndindex(result_shape):
        reduce_elemwise_fn(res, x, axis, m)
    return res
# The first call triggers JIT compilation, so it must happen while the
# replacement pass manager is installed.
with use_optimized_cheap_pass():
    numba_res = careduce_axis(test_array, axis)

# Cross-check against NumPy's reference implementation.  Unlike a bare
# `assert`, this check is not stripped when Python runs with -O.
np_res = np.max(test_array, axis=axis)
np.testing.assert_array_equal(numba_res, np_res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment