Created
August 26, 2015 21:02
-
-
Save jan-glx/e39f2b0ef23b10ee7e13 to your computer and use it in GitHub Desktop.
no wonder C is faster
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__author__ = 'jan' | |
import numpy as np | |
import scipy.weave | |
def grad_dist2(ls, x1, x2=None):
    """Return the pairwise scaled differences between rows of x1 and x2.

    Both inputs are first divided by ``ls``. The result ``gX`` has shape
    ``(x1.shape[0], x2.shape[0], x1.shape[1])`` and holds
    ``gX[i, j, d] = (2 / ls) * (x1[i, d] - x2[j, d])`` on the rescaled
    inputs. An inlined C loop (``scipy.weave``) is tried first; if that
    fails for any reason, an equivalent NumPy loop fills ``gX`` instead.

    Args:
        ls: Length scale used both for the rescaling and the final factor.
        x1: 2-D array whose rows are compared against the rows of ``x2``.
        x2: Optional 2-D array; when ``None`` the rows of ``x1`` are
            compared against themselves.

    Returns:
        The 3-D array ``gX`` described above.
    """
    if x2 is None:
        x2 = x1

    # Rescale.
    x1 = x1 / ls
    x2 = x2 / ls

    # N, M and D are handed to the C snippet by name, so they must exist
    # as locals even though the Python fallback could use the shapes directly.
    N = x1.shape[0]
    M = x2.shape[0]
    D = x1.shape[1]
    gX = np.zeros((N, M, D))

    # NOTE(review): the snippet indexes ls(d), so it presumably needs ``ls``
    # to be a length-D array; a scalar ``ls`` likely always ends up in the
    # fallback below -- confirm.
    code = """
    for (int i=0; i<N; i++)
      for (int j=0; j<M; j++)
        for (int d=0; d<D; d++)
          gX(i,j,d) = (2/ls(d))*(x1(i,d) - x2(j,d));
    """
    try:
        scipy.weave.inline(code, ['x1', 'x2', 'gX', 'ls', 'M', 'N', 'D'],
                           type_converters=scipy.weave.converters.blitz,
                           compiler='gcc')
    except Exception:
        # Best-effort fallback: weave missing, no compiler, or unsupported
        # argument types. ``Exception`` (rather than a bare ``except``) lets
        # KeyboardInterrupt and SystemExit propagate during a slow compile.
        for i in range(N):
            gX[i, :, :] = 2 * (x1[i, :] - x2[:, :]) * (1 / ls)
    return gX
def grad_dist3(ls, x1, x2=None):
    """Return the pairwise scaled differences using only a NumPy row loop.

    Both inputs are first divided by ``ls``. The result ``gX`` has shape
    ``(x1.shape[0], x2.shape[0], x1.shape[1])`` and holds
    ``gX[i, j, :] = 2 * (x1[i, :] - x2[j, :]) * (1 / ls)`` on the rescaled
    inputs.

    Args:
        ls: Length scale used both for the rescaling and the final factor.
        x1: 2-D array whose rows are compared against the rows of ``x2``.
        x2: Optional 2-D array; when ``None`` the rows of ``x1`` are
            compared against themselves.

    Returns:
        The 3-D array ``gX`` described above.
    """
    if x2 is None:
        x2 = x1

    # Rescale.
    x1 = x1 / ls
    x2 = x2 / ls

    gX = np.zeros((x1.shape[0], x2.shape[0], x1.shape[1]))

    # One broadcasted subtraction per row of x1; ``range`` works on both
    # Python 2 and 3, whereas ``xrange`` does not exist on Python 3.
    for i in range(x1.shape[0]):
        gX[i, :, :] = 2 * (x1[i, :] - x2[:, :]) * (1 / ls)
    return gX
import timeit

# Benchmark inputs: two random matrices sharing the same number of columns.
x1 = np.random.randn(400, 300)
x2 = np.random.randn(500, 300)
ls = 3.0

# Reference results: the function under test and the pure-broadcasting form.
gX = grad_dist2(ls, x1, x2)
gX2 = (x1 * 2 / ls ** 2)[:, np.newaxis, :] - (x2 * 2 / ls ** 2)[np.newaxis, :, :]

# Time each variant 10 times; the setup pulls the globals above into timeit's
# namespace. Underscore-prefixed names are skipped by the star import.
_SETUP = "from __main__ import *"
for _stmt in (
    "gX=grad_dist2(ls, x1, x2)",
    "gX=grad_dist3(ls, x1, x2)",
    "gX2=((x1*2/ls**2)[:,None,:]-(x2*2/ls**2)[None,:,:])",
):
    print(timeit.timeit(_stmt, _SETUP, number=10))

# Sanity check that both approaches agree.
print(np.allclose(gX, gX2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
output:
compile trial+fall back: 13.371799713
fall back only: 6.55610998377
numpy only: 2.2905820142
comparison: True