A script that benchmarks Aesara implementations of log-sum-exp compiled to Numba against hand-written Numba kernels and Aesara's C backend
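# Note: the timing sections below use IPython's `%timeit` via
# `get_ipython()`, so this script is meant to be run inside an IPython
# session (e.g. with `%run`), not with plain `python`.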
import inspect
import numba
import numpy as np
import pandas as pd
import aesara
import aesara.tensor as at
test_data = [
    np.random.normal(size=(3, 3)),
    np.random.normal(size=(25, 25)),
    np.random.normal(size=(100, 100)),
    np.random.normal(size=(1000, 1000)),
    np.random.normal(size=(10000, 10000)),
]
@numba.vectorize
def custom_op_fn(x):
    # Element-wise: replace infinities with 0.  `numba.vectorize` turns this
    # scalar function into a compiled NumPy ufunc.
    if np.isinf(x):
        return 0
    else:
        return x
# Make sure the underlying Numba function is compiled
custom_op_res = custom_op_fn(test_data[0])
X = at.matrix("X")
Y = at.switch(at.isinf(X), 0, X)
aesara_numba_fn = aesara.function([X], Y, mode="NUMBA")
aesara_c_fn = aesara.function([X], Y, mode="FAST_RUN")
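# Dig the jitted function for the whole graph out of the compiled Aesara
# function's thunk closure, so it can be timed without the overhead of
# Aesara's Python-level function wrapper.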
fn, *_ = aesara_numba_fn.maker.linker.make_all()
cl_vars = inspect.getclosurevars(fn)
thunk = cl_vars.nonlocals["thunks"][0]
thunk_signature = inspect.signature(thunk)
aesara_numba_direct_fn = thunk_signature.parameters["fgraph_jit"].default
# Make sure the underlying Numba function is compiled
aesara_numba_res = aesara_numba_fn(test_data[0])
aesara_numba_direct_res, = aesara_numba_direct_fn(test_data[0])
# Make sure the custom `Op` and the Aesara graph are equivalent
np.testing.assert_array_almost_equal(custom_op_res, aesara_numba_res)
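# Also confirm that the directly extracted jitted function matches the full
# Aesara function's output (this check is an addition; the original only
# compared the custom ufunc against the Aesara-Numba result).
np.testing.assert_array_almost_equal(aesara_numba_res, aesara_numba_direct_res)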
timeit_data = pd.DataFrame(
    columns=["Numba", "Aesara-C", "Aesara-Numba", "Aesara-Numba (direct)"],
    index=pd.Index([], name="data shape"),
)
def format_result(x):
    # Keep only the "mean ± std" part of the `%timeit` result string and
    # append the loop count.
    return str(x).split(" per")[0] + f" ({x.loops})"
for data in test_data:
    print(f"Running data with shape={data.shape}")
    numba_time = get_ipython().run_line_magic("timeit", "-o custom_op_fn(data)")
    aesara_c_time = get_ipython().run_line_magic("timeit", "-o aesara_c_fn(data)")
    aesara_numba_time = get_ipython().run_line_magic(
        "timeit", "-o aesara_numba_fn(data)"
    )
    aesara_numba_direct_time = get_ipython().run_line_magic(
        "timeit", "-o aesara_numba_direct_fn(data)"
    )
    timeit_data.loc[str(data.shape)] = [
        format_result(r)
        for r in [numba_time, aesara_c_time, aesara_numba_time, aesara_numba_direct_time]
    ]
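# Show the collected timings for the first comparison (presumably displayed
# as cell output in the original session; an explicit print also works when
# run as a script).
print(timeit_data)


# --- Second comparison: a custom Numba-backed log-sum-exp `Op` vs. an
# equivalent Aesara graph, compiled to both Numba and C ---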
import numba
import numpy as np
import pandas as pd
import aesara
import aesara.tensor as at
@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp(p, out):
    # Row-wise log-sum-exp, written out directly (no max-shift
    # stabilization, unlike `logsumexp2` below).
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2
    for i in numba.prange(n):
        res = 0
        for j in range(m):
            res += np.exp(p[i, j])
        res = np.log(res)
        out[i] = res
@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp_grad(p, out, dout, dp):
    # Backward pass: since out[i] = logsumexp(p[i, :]), the Jacobian row is
    # softmax(p[i, :]), i.e. d out[i] / d p[i, j] = exp(p[i, j] - out[i]).
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert len(dout) == n
    assert dout.ndim == 1
    assert dp.shape == p.shape
    for i in numba.prange(n):
        for j in range(m):
            dp[i, j] = np.exp(p[i, j] - out[i]) * dout[i]
class LogSumExp(aesara.graph.op.Op):
    """Row-wise log-sum-exp `Op` backed by a Numba kernel."""

    itypes = [at.dmatrix]
    otypes = [at.dvector]

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        n, m = x.shape
        out = np.zeros(n, dtype=x.dtype)
        numba_logsumexp(x, out)
        outputs[0][0] = out

    def grad(self, inputs, grads):
        (x,) = inputs
        (dout,) = grads
        logsumexp = self(x)
        return [LogSumExpGrad()(x, logsumexp, dout)]
class LogSumExpGrad(aesara.graph.op.Op):
    """Gradient of `LogSumExp`, also backed by a Numba kernel."""

    itypes = [at.dmatrix, at.dvector, at.dvector]
    otypes = [at.dmatrix]

    def perform(self, node, inputs, outputs):
        p, out, dout = inputs
        dp = np.zeros(p.shape, dtype=p.dtype)
        numba_logsumexp_grad(p, out, dout, dp)
        outputs[0][0] = dp
logsumexp = LogSumExp()
test_data = [
    np.random.normal(size=(3, 3)),
    np.random.normal(size=(25, 25)),
    np.random.normal(size=(100, 100)),
    np.random.normal(size=(1000, 1000)),
    np.random.normal(size=(10000, 10000)),
]
X = at.matrix("X")
custom_op_fn = aesara.function([X], logsumexp(X))
# Make sure the underlying Numba function is compiled
custom_op_res = custom_op_fn(test_data[0])
def logsumexp2(x, axis=None, keepdims=True):
    x_max = at.max(x, axis=axis, keepdims=True)
    x_max = at.switch(at.isinf(x_max), 0, x_max)
    res = at.log(at.sum(at.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return res if keepdims else res.squeeze()
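# `logsumexp2` relies on the identity
#     log(sum_j(exp(x_j))) == m + log(sum_j(exp(x_j - m)))  for any finite m,
# with m = max_j(x_j), which keeps every exponent non-positive and avoids
# overflow; the `isinf` switch handles rows whose max is itself infinite.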
aesara_numba_fn = aesara.function(
    [X], logsumexp2(X, axis=1, keepdims=False), mode="NUMBA"
)
aesara_c_fn = aesara.function(
    [X], logsumexp2(X, axis=1, keepdims=False), mode="FAST_RUN"
)
# Make sure the underlying Numba function is compiled
aesara_numba_res = aesara_numba_fn(test_data[0])
# Make sure the custom `Op` and the Aesara graph are equivalent
np.testing.assert_array_almost_equal(custom_op_res, aesara_numba_res)
timeit_data = pd.DataFrame(
    columns=["custom `Op`", "Aesara graph (Numba)", "Aesara graph (C)"],
    index=pd.Index([], name="data shape"),
)


def format_result(x):
    # Keep only the "mean ± std" part of the `%timeit` result string and
    # append the loop count.
    return str(x).split(" per")[0] + f" ({x.loops})"
for data in test_data:
    print(f"Running data with shape={data.shape}")
    custom_op_time = get_ipython().run_line_magic("timeit", "-o custom_op_fn(data)")
    aesara_numba_time = get_ipython().run_line_magic(
        "timeit", "-o aesara_numba_fn(data)"
    )
    aesara_c_time = get_ipython().run_line_magic("timeit", "-o aesara_c_fn(data)")
    timeit_data.loc[str(data.shape)] = [
        format_result(r) for r in [custom_op_time, aesara_numba_time, aesara_c_time]
    ]
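# Show the collected timings for the second comparison.
print(timeit_data)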