Skip to content

Instantly share code, notes, and snippets.

@cc7768
Created August 21, 2014 17:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cc7768/bc5b8b7b9052708f0c0a to your computer and use it in GitHub Desktop.
Save cc7768/bc5b8b7b9052708f0c0a to your computer and use it in GitHub Desktop.
Example of numba loops
import numpy as np
from numba import jit
def f_original(x, y):
    """Compute the outer product of ``x`` and ``y`` with explicit Python loops.

    Entry ``[i, j]`` of the result is ``x[i] * y[j]``. The result is a new
    ``(x.size, y.size)`` array allocated with ``np.empty``'s default dtype
    (float64), so each product is cast on assignment. This is the pure-Python
    reference the other implementations are compared against.
    """
    n_rows = x.size
    n_cols = y.size
    result = np.empty((n_rows, n_cols))
    # ndindex walks (i, j) in row-major order: i outer, j inner.
    for i, j in np.ndindex(n_rows, n_cols):
        result[i, j] = x[i] * y[j]
    return result
def f(x, y):
    """Compute the outer product of ``x`` and ``y`` in one loop nest.

    Produces the same values as ``f_original``. The difference is that each
    ``y`` element is first copied into a temporary before multiplying,
    presumably to mirror the per-element work done by ``inner``. The result
    is a new ``(x.size, y.size)`` array with ``np.empty``'s default float64
    dtype.

    This function is also compiled with numba (``jitf``), so it sticks to
    plain ``range`` loops and scalar indexing.
    """
    rows = x.size
    cols = y.size
    out = np.empty((rows, cols))
    for i in range(rows):
        for j in range(cols):
            elem = y[j]
            out[i, j] = x[i] * elem
    return out
# Numba-compiled wrapper around f. No signature is given, so numba compiles
# lazily on the first call for each new combination of argument types.
jitf = jit(f)
def outer(x, y):
    """Compute the outer product of ``x`` and ``y`` one row at a time.

    Loop splitting variant of ``f``. For each element of ``x``, a scratch
    buffer is refilled with ``y`` and handed to ``inner``, which scales it in
    place. The scaled buffer is then stored as that row of the result. The
    result is a new ``(x.size, y.size)`` float64 array.
    """
    rows = x.size
    cols = y.size
    result = np.empty((rows, cols))
    scratch = np.empty(cols)
    for i in range(rows):
        # inner overwrites its input, so scratch must be reset to y every row.
        scratch[:] = y
        result[i, :] = inner(x[i], scratch, cols)
    return result
def inner(xval, y, m):
    """Multiply the first ``m`` entries of ``y`` by ``xval``, in place.

    Entries at index ``m`` and beyond are left untouched. Returns ``y``
    itself (the same, now modified, object) so the caller can assign the
    result directly into a row. Also compiled with numba as ``jitinner``, so
    the body sticks to a plain indexed loop.
    """
    for idx in range(m):
        y[idx] = xval * y[idx]
    return y
# Numba-compiled version of inner. jitouter calls this global directly, and
# the same object is passed explicitly to jitouter_pass_inner further down.
jitinner = jit(inner)
@jit
def jitouter(x, y):
    """Numba-compiled twin of ``outer``: build the outer product row by row.

    Each row is produced by refilling a scratch buffer with ``y`` and scaling
    it in place with the module-level ``jitinner``. The result is a new
    ``(x.size, y.size)`` float64 array.
    """
    rows = x.size
    cols = y.size
    result = np.empty((rows, cols))
    scratch = np.empty(cols)
    for i in range(rows):
        # jitinner mutates its buffer argument, so reset it for every row.
        scratch[:] = y
        result[i, :] = jitinner(x[i], scratch, cols)
    return result
@jit
def jitouter_pass_inner(x, y, inner_func):
    """Numba-compiled row-by-row outer product with a caller-supplied kernel.

    Same as ``jitouter``, except the per-row kernel is the ``inner_func``
    argument rather than a global. It is called as
    ``inner_func(x[i], buffer, y.size)`` and its return value is stored as
    row ``i``.
    """
    rows = x.size
    cols = y.size
    result = np.empty((rows, cols))
    scratch = np.empty(cols)
    for i in range(rows):
        # The kernel may overwrite its buffer, so refill it with y each row.
        scratch[:] = y
        result[i, :] = inner_func(x[i], scratch, cols)
    return result
one = np.arange(5)
two = np.ones(10) * 2
# Make sure they all agree. Each implementation's output is still printed,
# but it is now also compared numerically with the plain-loop reference
# f_original. A mismatch raises AssertionError instead of relying on the
# reader to eyeball five printed arrays.
reference = f_original(one, two)
implementations = [
    ("f", lambda: f(one, two)),
    ("jitf", lambda: jitf(one, two)),
    ("outer", lambda: outer(one, two)),
    ("jitouter", lambda: jitouter(one, two)),
    ("jitouter_pass_inner", lambda: jitouter_pass_inner(one, two, jitinner)),
]
for label, run in implementations:
    # Run lazily, one at a time, so calls and prints interleave as before.
    result = run()
    print(result)
    np.testing.assert_allclose(result, reference,
                               err_msg=label + " disagrees with f_original")
# Benchmark inputs: 1-D float64 vectors of length 200 and 5000, so each
# implementation builds a 200 x 5000 result array.
# NOTE(review): `one` above is an integer arange while these are float64, so
# the jitted functions presumably compile a fresh specialization on their
# first call here. Confirm %timeit's repeated runs keep that compile time
# out of the reported figure.
big = np.random.randn(200)
bigger = np.random.randn(5000)
# The `%timeit` lines are IPython magics. Run this in IPython/Jupyter, not
# with plain `python`. The trailing timings are the author's measurements
# and will vary by machine and library version.
print("This is vanilla python - all loops in one function")
%timeit f(big, bigger) # (~500ms)
print("This is numba python - all loops in one function")
%timeit jitf(big, bigger) # (~1.5ms)
print("This is vanilla python - loops broken into 2 functions")
%timeit outer(big, bigger) # (~390ms)
print("This is numba python - loops broken into 2 functions")
%timeit jitouter(big, bigger) # (~3.4ms)
print("This is numba python - loops broken into 2 functions with inner as argument")
%timeit jitouter_pass_inner(big, bigger, jitinner) # (~3.3ms)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment