Created
July 16, 2013 10:15
-
-
Save indranilsinharoy/6007508 to your computer and use it in GitHub Desktop.
I was expecting the Numba implementation to be faster; however, the standard NumPy vectorized implementation turns out to be faster. Sample output:
jinc function: standard NumPy implementation
Time taken = 49.9999523163 ms
jinc function: numba 'style' implementation
Time taken = 884.999990463 ms
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division, print_function

import time
import timeit

import numpy as np
from numba import double, jit
from scipy.special import j1
def jinc1(rho):
    """Vectorized NumPy implementation of the jinc function over a radial grid.

    ``jinc(rho) = 2*J_1(2*pi*rho) / (2*pi*rho)``, which simplifies to
    ``J_1(2*pi*rho) / (pi*rho)``; the removable singularity is defined as
    ``jinc(0) = 1.0``.

    Parameters
    ----------
    rho : array_like
        Radial coordinates of any shape (lists are accepted and converted
        with ``np.asarray``).

    Returns
    -------
    ndarray of float64
        jinc values, same shape as `rho`.
    """
    rho = np.asarray(rho)
    # Start from 1.0 everywhere so entries where rho == 0 already hold jinc(0).
    result = np.ones(rho.shape)
    mask = rho != 0.0
    nonzero = rho[mask]  # gather once; reused for numerator and denominator
    result[mask] = j1(2.0*np.pi*nonzero)/(np.pi*nonzero)
    return result
def jinc2(rho):
    """Loop-based implementation of the jinc function, written for Numba's @jit.

    ``jinc(rho) = 2*J_1(2*pi*rho) / (2*pi*rho)``, which simplifies to
    ``J_1(2*pi*rho) / (pi*rho)``; the removable singularity is defined as
    ``jinc(0) = 1.0``.

    Parameters
    ----------
    rho : 2-D ndarray of float
        Radial coordinates, indexed as ``rho[i, j]``.

    Returns
    -------
    ndarray of float64
        jinc values, shape ``rho.shape``.
    """
    M, N = rho.shape
    result = np.empty((M, N))
    for i in range(M):
        for j in range(N):
            r = rho[i, j]
            if r == 0.0:
                # Limit value at the origin; the general formula would be 0/0
                # here (NaN in NumPy, possibly ZeroDivisionError under Numba).
                result[i, j] = 1.0
            else:
                result[i, j] = j1(2.0*np.pi*r)/(np.pi*r)
    return result
# Per the Numba docs, passing an explicit signature compiles `jinc2` eagerly
# here, so the timings below should not include JIT compilation.
# NOTE(review): if Numba cannot compile scipy's `j1` in nopython mode it may
# fall back to object mode, which would explain a slow loop version; passing
# nopython=True would surface that as an error instead.
fast_jinc = jit(double[:,:](double[:,:]))(jinc2)

# Prepare the data: a 500x500 radial grid over [-1, 1] x [-1, 1]
x = np.linspace(-1, 1, 500)
y = x.copy()
X, Y = np.meshgrid(x, y)
rho = np.hypot(X, Y)

# timeit.default_timer picks the highest-resolution clock on each platform
# (time.time() can tick in ~15 ms steps on some systems).
print("jinc function: standard NumPy implementation")
start = timeit.default_timer()
result1 = jinc1(rho)
duration = timeit.default_timer() - start
print("Time taken = %s ms" % (duration*1000.0))

print("jinc function: numba 'style' implementation")
start = timeit.default_timer()
result2 = fast_jinc(rho)
duration = timeit.default_timer() - start
print("Time taken = %s ms" % (duration*1000.0))

# Verify that the results are same
np.set_printoptions(precision=4, linewidth=90)
print("Result from standard NumPy implementation")
print(result1[0:10,0:10])
print("Result from numba 'style' implementation")
print(result2[0:10,0:10])
# Compare the full arrays, not just the printed 10x10 corner.
print("Results agree: %s" % np.allclose(result1, result2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment