Skip to content

Instantly share code, notes, and snippets.

@fanshi118
Created June 30, 2021 20:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fanshi118/7e497a05690c2ffdc36dfd6c107009d2 to your computer and use it in GitHub Desktop.
Save fanshi118/7e497a05690c2ffdc36dfd6c107009d2 to your computer and use it in GitHub Desktop.
import inspect
import numba
import numpy as np
import pandas as pd
# Square float64 matrices of increasing size to benchmark on. A seeded
# Generator makes the inputs identical from run to run, so timings and the
# correctness check below are reproducible. Note the largest matrix alone is
# 10000 * 10000 * 8 bytes = 800 MB.
_rng = np.random.default_rng(42)
test_data = [
    _rng.standard_normal(size=(n, n))
    for n in (3, 25, 100, 1000, 10000)
]
@numba.vectorize(
    [
        "float32(float32, float32)",
        "float64(float64, float64)",
    ]
)
def custom_op_fn(x, y):
    """Element-wise maximum of ``x`` and ``y``, compiled as a NumPy ufunc.

    Intended to agree with ``np.maximum`` (the reference it is checked
    against in this script), including NaN handling: if either operand is
    NaN the result is NaN. On ties ``y`` is returned. Only float32 and
    float64 signatures are compiled.
    """
    # `x != x` holds only when x is NaN. Without this check a NaN in `x` is
    # dropped (NaN > y is False), so a running maximum would "forget" a NaN
    # it had already seen -- unlike np.maximum, which propagates it.
    # A NaN in `y` already falls through to `return y`.
    if x > y or x != x:
        return x
    return y
@numba.njit
def max_reduce_axis_1(x):
    """Row-wise reduction of a 2-D array with ``custom_op_fn``.

    A float64 accumulator with one slot per row of ``x`` starts at ``-inf``.
    Each column of ``x`` is then folded into it in place, so the result is an
    array of shape ``(x.shape[0],)`` (float64 regardless of ``x``'s dtype).

    ``x`` is presumably a 2-D float32/float64 array -- the only signatures
    ``custom_op_fn`` is compiled for; TODO confirm other inputs are not needed.
    """
    columns = x.T  # row j of the transpose is column j of x (a view, no copy)
    acc = np.full(x.shape[0], -np.inf, dtype=np.float64)
    for j in range(x.shape[1]):
        # Passing `acc` as the ufunc's output buffer updates it in place.
        custom_op_fn(acc, columns[j], acc)
    return acc
# Call the jitted reduction once up front so Numba compiles it here (for the
# float64 2-D argument type of test_data[0]) rather than inside the timing
# loop below. The ufunc's own `custom_op_fn.reduce(x, axis=1)` is an
# alternative formulation; this script benchmarks the explicit loop instead.
custom_op_res = max_reduce_axis_1(test_data[0])
def np_max_reduce_axis_1(x):
    """Return the row-wise maximum of ``x`` using NumPy's built-in ufunc.

    NumPy reference for ``max_reduce_axis_1``: applies ``np.maximum.reduce``
    along axis 1, so a 2-D input of shape ``(n, m)`` yields shape ``(n,)``.
    As with ``np.maximum``, NaNs propagate into the result.
    """
    reduce_max = np.maximum.reduce
    return reduce_max(x, 1)  # second positional argument of ufunc.reduce is `axis`
np_res = np_max_reduce_axis_1(test_data[0])
# Make sure the custom `Op` and the NumPy reduction agree. A max only selects
# existing elements (no arithmetic, hence no rounding), so the results must be
# identical: check exact equality rather than 6-decimal closeness.
np.testing.assert_array_equal(custom_op_res, np_res)
# Results table: one row per input shape (rows are added while benchmarking),
# one column per implementation; cells hold formatted timing strings.
timeit_data = pd.DataFrame(
    columns=["Numba", "NumPy"],
    index=pd.Index([], name="data shape"),
)
def format_result(x):
    """Condense a ``%timeit -o`` result into a short table cell.

    ``x`` is presumably an IPython ``TimeitResult`` (it comes from the
    ``%timeit -o`` call in the benchmark loop). Keeps the text of ``str(x)``
    before the first ``" per"`` (the "mean ± std" part) and appends the
    object's ``loops`` count in parentheses.
    """
    head, _, _ = str(x).partition(" per")
    return f"{head} ({x.loops})"
# Time each implementation on every input and store the formatted result in
# `timeit_data` (row = input shape, column = implementation). This only runs
# under IPython/Jupyter because it relies on `get_ipython()`.
# The loop variables MUST stay named `fn` and `data`: the %timeit magic
# evaluates the literal string "fn(data)" in this scope.
implementations = (max_reduce_axis_1, np_max_reduce_axis_1)
for col_name, fn in zip(tuple(timeit_data.columns), implementations):
    for data in test_data:
        print(f"Running data with shape={data.shape}")
        timing = get_ipython().run_line_magic("timeit", "-o fn(data)")
        row_label = str(data.shape)
        timeit_data.loc[row_label, col_name] = format_result(timing)
# pandas' to_markdown needs the optional `tabulate` package.
print(timeit_data.to_markdown())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment