Created
August 21, 2014 17:43
-
-
Save cc7768/bc5b8b7b9052708f0c0a to your computer and use it in GitHub Desktop.
Example of numba loops
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from numba import jit
def f_original(x, y):
    """Return the outer product of ``x`` and ``y`` using explicit loops.

    This is the plain nested-loop baseline: element ``[i, j]`` of the result
    is ``x[i] * y[j]``.

    Parameters
    ----------
    x, y : numpy.ndarray
        Input arrays; indexed with a single integer, so they are assumed to
        be 1-D.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x.size, y.size)`` allocated with ``np.empty`` (so
        it uses NumPy's default float dtype).
    """
    n = x.size
    m = y.size
    z = np.empty((n, m))

    for xval in range(n):
        for yval in range(m):
            z[xval, yval] = x[xval] * y[yval]

    return z
def f(x, y):
    """Return the outer product of ``x`` and ``y`` with both loops inline.

    Same computation as a plain nested loop, except that the current ``y``
    element is first copied into a temporary, mirroring the structure of the
    split ``outer``/``inner`` pair so the timings are comparable.

    Parameters
    ----------
    x, y : numpy.ndarray
        Input arrays; indexed with a single integer, so they are assumed to
        be 1-D.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x.size, y.size)`` where element ``[i, j]`` is
        ``x[i] * y[j]``.
    """
    n = x.size
    m = y.size
    z = np.empty((n, m))

    for xval in range(n):
        for yval in range(m):
            # Deliberate temporary: keeps the loop body shaped like `inner`.
            temp = y[yval]
            z[xval, yval] = x[xval] * temp

    return z
# Numba-compiled version of the single-function implementation `f`.
jitf = jit(f)
def outer(x, y):
    """Return the outer product of ``x`` and ``y``, one row per ``inner`` call.

    The outer loop lives here; the loop over ``y`` is delegated to ``inner``.

    Parameters
    ----------
    x, y : numpy.ndarray
        Input arrays; ``x`` is indexed with a single integer and ``y`` is
        copied into a length-``y.size`` buffer, so both are assumed 1-D.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x.size, y.size)`` whose row ``i`` is ``x[i] * y``.
    """
    n = x.size
    m = y.size
    z = np.empty((n, m))
    temp = np.empty(m)

    for xval in range(n):
        # `inner` scales its array argument in place, so hand it a fresh copy
        # of `y` on every iteration instead of `y` itself.
        temp[:] = y
        z[xval, :] = inner(x[xval], temp, m)

    return z
def inner(xval, y, m):
    """Scale the first ``m`` elements of ``y`` by ``xval`` in place.

    Parameters
    ----------
    xval : scalar
        Multiplier applied to each element.
    y : numpy.ndarray
        Array that is overwritten with the scaled values. Callers that need
        the original contents must pass a copy.
    m : int
        Number of leading elements of ``y`` to scale.

    Returns
    -------
    numpy.ndarray
        The same ``y`` object that was passed in, after modification.
    """
    for yval in range(m):
        temp = y[yval]
        y[yval] = xval * temp

    return y
# Numba-compiled version of `inner`, called from the jitted outer loops below.
jitinner = jit(inner)
@jit
def jitouter(x, y):
    """Jitted outer loop that calls the jitted ``jitinner`` for each row.

    Parameters
    ----------
    x, y : numpy.ndarray
        Input arrays; ``x`` is indexed with a single integer and ``y`` is
        copied into a length-``y.size`` buffer, so both are assumed 1-D.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x.size, y.size)`` whose row ``i`` is ``x[i] * y``.
    """
    n = x.size
    m = y.size
    z = np.empty((n, m))
    temp = np.empty(m)

    for xval in range(n):
        # `jitinner` overwrites its array argument, so refresh the scratch
        # buffer from `y` before every call.
        temp[:] = y
        z[xval, :] = jitinner(x[xval], temp, m)

    return z
@jit
def jitouter_pass_inner(x, y, inner_func):
    """Jitted outer loop that receives the inner-loop function as an argument.

    Parameters
    ----------
    x, y : numpy.ndarray
        Input arrays; ``x`` is indexed with a single integer and ``y`` is
        copied into a length-``y.size`` buffer, so both are assumed 1-D.
    inner_func : callable
        Called as ``inner_func(x[i], buffer, y.size)``; its return value is
        stored as row ``i`` of the result. The script passes ``jitinner``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x.size, y.size)`` built from the rows returned by
        ``inner_func``.
    """
    n = x.size
    m = y.size
    z = np.empty((n, m))
    temp = np.empty(m)

    for xval in range(n):
        # Refresh the scratch buffer each iteration in case `inner_func`
        # modifies it in place (as `jitinner` does).
        temp[:] = y
        z[xval, :] = inner_func(x[xval], temp, m)

    return z
one = np.arange(5) | |
two = np.ones(10) * 2 | |
# Make sure they all agree | |
print(f(one, two)) | |
print(jitf(one, two)) | |
print(outer(one, two)) | |
print(jitouter(one, two)) | |
print(jitouter_pass_inner(one, two, jitinner)) | |
big = np.random.randn(200) | |
bigger = np.random.randn(5000) | |
print("This is vanilla python - all loops in one function") | |
%timeit f(big, bigger) # (~500ms) | |
print("This is numba python - all loops in one function") | |
%timeit jitf(big, bigger) # (~1.5ms) | |
print("This is vanilla python - loops broken into 2 functions") | |
%timeit outer(big, bigger) # (~390ms) | |
print("This is numba python - loops broken into 2 functions") | |
%timeit jitouter(big, bigger) # (~3.4ms) | |
print("This is numba python - loops broken into 2 functions with inner as argument") | |
%timeit jitouter_pass_inner(big, bigger, jitinner) # (~3.3ms) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment