Created
June 30, 2021 20:50
-
-
Save fanshi118/7e497a05690c2ffdc36dfd6c107009d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect

import numba
import numpy as np
import pandas as pd
# Square float64 matrices of increasing size used as benchmark inputs.
# A fixed seed makes the generated inputs (and therefore results/timings)
# reproducible from run to run.
# NOTE: the 10000 x 10000 matrix alone holds 10**8 float64 values (~800 MB).
_rng = np.random.default_rng(0)
test_data = [_rng.normal(size=(n, n)) for n in (3, 25, 100, 1000, 10000)]
@numba.vectorize(
    [
        "float32(float32, float32)",
        "float64(float64, float64)",
    ]
)
def custom_op_fn(x, y):
    """Element-wise maximum of ``x`` and ``y``, compiled as a Numba ufunc.

    Eager signatures are provided for (float32, float32) and
    (float64, float64) inputs.  ``y`` is returned whenever ``x > y`` is
    False, so if either operand is NaN the result is ``y`` (this differs
    from ``np.maximum``, which always propagates NaN).
    """
    return x if x > y else y
@numba.njit
def max_reduce_axis_1(x):
    """Return the maximum of each row of ``x`` using ``custom_op_fn``.

    Assumes ``x`` is 2-D (it indexes ``x.shape[1]``).  A float64 accumulator
    is seeded with ``-inf`` per row and each column of ``x`` is folded into
    it in turn.  The result is therefore always float64, and rows are
    ``-inf`` when ``x`` has no columns.  NaN handling follows
    ``custom_op_fn``, so it does not match ``np.maximum.reduce``.
    """
    columns = x.T
    running_max = np.full(x.shape[0], -np.inf, dtype=np.float64)
    for j in range(x.shape[1]):
        # Passing the accumulator as both input and ``out`` updates it in place.
        custom_op_fn(running_max, columns[j], running_max)
    return running_max
# Call the JIT-compiled reduction once on the small float64 input so Numba
# compiles it here rather than inside the timed benchmark below; the result
# is also kept for the correctness check against NumPy.
custom_op_res = max_reduce_axis_1(test_data[0])
def np_max_reduce_axis_1(x):
    """Reference row-wise maximum: NumPy's ``maximum`` ufunc reduced over axis 1."""
    reducer = np.maximum.reduce
    return reducer(x, axis=1)
# Reference result from plain NumPy on the same small input.
np_res = np_max_reduce_axis_1(test_data[0])

# Sanity check: the Numba reduction must agree with NumPy before either
# implementation is timed.
np.testing.assert_array_almost_equal(custom_op_res, np_res)

# Results table: one column per implementation; rows (keyed by the input's
# shape string) are added by the benchmark loop.
timeit_data = pd.DataFrame(index=pd.Index([], name="data shape"), columns=["Numba", "NumPy"])
def format_result(x):
    """Condense a ``%timeit -o`` result into a short table cell.

    Keeps the text of ``str(x)`` up to the first ``" per"`` (or all of it if
    absent) and appends ``x.loops`` in parentheses.
    """
    summary, _, _ = str(x).partition(" per")
    return f"{summary} ({x.loops})"
# Time every implementation on every input with IPython's %timeit magic.
# NOTE: `get_ipython()` only exists inside an IPython/Jupyter session, and the
# magic string resolves `fn` and `data` by name from this namespace, so those
# loop-variable names must stay as they are.
for col, fn in zip(list(timeit_data.columns), (max_reduce_axis_1, np_max_reduce_axis_1)):
    for data in test_data:
        print(f"Running data with shape={data.shape}")
        run_time = get_ipython().run_line_magic("timeit", "-o fn(data)")
        timeit_data.loc[str(data.shape), col] = format_result(run_time)

# NOTE(review): DataFrame.to_markdown needs the optional `tabulate` package.
print(timeit_data.to_markdown())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment