Skip to content

Instantly share code, notes, and snippets.

@albop
Last active September 10, 2018 22:59
Show Gist options
  • Save albop/98da8acb1d7a301b307e8a9cd63fe8ad to your computer and use it in GitHub Desktop.
Save albop/98da8acb1d7a301b307e8a9cd63fe8ad to your computer and use it in GitHub Desktop.
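Compares numba implementations of a batched 2x2 matrix product: numba's built-in `@` operator, a hand-unrolled 2x2 kernel, and a kernel dispatched on the matrix shape with `generated_jit`.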
from numpy import *
from numba import jit
# Problem size: number of matrices in each batch.
N = 10**6
# Two batches of N random 2x2 matrices with entries uniform on [0, 1).
# A is drawn before B from numpy's global random state.
A = random.random(size=(N, 2, 2))
B = random.random(size=(N, 2, 2))
@jit(nopython=True)
def mulvec(A, B):
    """Multiply two stacks of matrices pairwise: C[n] = A[n] @ B[n].

    Generic version that relies on numba's support for the ``@`` operator
    on 2-D arrays.

    Parameters
    ----------
    A : array of shape (N, p, q)
    B : array of shape (N, q, r)

    Returns
    -------
    C : float64 array of shape (N, p, r)

    Raises
    ------
    ValueError
        If A and B do not hold the same number of matrices, or if their
        inner matrix dimensions do not match.
    """
    N = A.shape[0]
    # numba does not bounds-check indexing by default, so a shorter B would
    # be read out of bounds inside the loop instead of raising.
    if B.shape[0] != N:
        raise ValueError("A and B must contain the same number of matrices")
    if A.shape[2] != B.shape[1]:
        raise ValueError("inner matrix dimensions of A and B do not match")
    # The output shape follows the operands instead of being fixed to 2x2.
    C = zeros((N, A.shape[1], B.shape[2]))
    for n in range(N):
        C[n, :, :] = A[n, :, :] @ B[n, :, :]
    return C
@jit(nopython=True)
def mul_22_22(A, B):
    """Return the 2x2 matrix product A @ B using unrolled formulas.

    Only the entries with indices 0 and 1 of A and B are read. The result
    is a newly allocated 2x2 array (numpy's default float64 dtype).
    """
    # Load each operand entry into a scalar once, then combine.
    a00 = A[0, 0]
    a01 = A[0, 1]
    a10 = A[1, 0]
    a11 = A[1, 1]
    b00 = B[0, 0]
    b01 = B[0, 1]
    b10 = B[1, 0]
    b11 = B[1, 1]
    out = zeros((2, 2))
    out[0, 0] = a00 * b00 + a01 * b10
    out[0, 1] = a00 * b01 + a01 * b11
    out[1, 0] = a10 * b00 + a11 * b10
    out[1, 1] = a10 * b01 + a11 * b11
    return out
@jit(nopython=True)
def mulvec_22_22(A, B):
    """Multiply two stacks of 2x2 matrices pairwise with the unrolled kernel.

    Computes C[n] = mul_22_22(A[n], B[n]) for every n.

    Parameters
    ----------
    A, B : arrays of shape (N, 2, 2)

    Returns
    -------
    C : float64 array of shape (N, 2, 2)

    Raises
    ------
    ValueError
        If A and B hold different numbers of matrices, or if either stack
        does not consist of 2x2 matrices.
    """
    N = A.shape[0]
    # numba does not bounds-check indexing by default, so a shorter B would
    # be read out of bounds inside the loop instead of raising.
    if B.shape[0] != N:
        raise ValueError("A and B must contain the same number of matrices")
    # mul_22_22 only reads indices 0 and 1: larger matrices would be silently
    # truncated, and smaller ones read out of bounds.
    if A.shape[1] != 2 or A.shape[2] != 2 or B.shape[1] != 2 or B.shape[2] != 2:
        raise ValueError("mulvec_22_22 only supports stacks of 2x2 matrices")
    C = zeros((N, 2, 2))
    for n in range(N):
        C[n, :, :] = mul_22_22(A[n, :, :], B[n, :, :])
    return C
%time C = mulvec(A,B)
%time CC = mulvec_22_22(A,B)
abs(CC - C).max()
# Val[k] is a tuple of k zeros, for k = 0..9. Tuples of different lengths
# have different numba types, so passing Val[k] encodes the integer k in a
# type that a generated_jit function can dispatch on.
Val = tuple((0,) * k for k in range(10))
from numba import generated_jit
@generated_jit
def mul_(A, B, sh_A, sh_B):
    """Shape-dispatched matrix product A @ B.

    At typing time numba calls this function with the numba *types* of the
    arguments and compiles the implementation it returns. sh_A and sh_B are
    expected to be (Val[rows], Val[cols]) tuples, so their types encode the
    matrix shapes.

    Only the 2x2 @ 2x2 combination is implemented. Any other combination
    raises NotImplementedError when numba compiles the call.

    Note: ``generated_jit`` is deprecated since numba 0.57 and was removed in
    numba 0.59. ``numba.extending.overload`` is the documented replacement.
    """
    # The file only imports individual names from numba, not the module
    # itself, so ``numba.typeof`` would raise NameError here.
    from numba import typeof

    two_by_two = typeof((Val[2], Val[2]))
    if sh_A == two_by_two and sh_B == two_by_two:
        def fun(A, B, sh_A, sh_B):
            C = zeros((2, 2))
            C[0, 0] = A[0, 0] * B[0, 0] + A[0, 1] * B[1, 0]
            C[0, 1] = A[0, 0] * B[0, 1] + A[0, 1] * B[1, 1]
            C[1, 0] = A[1, 0] * B[0, 0] + A[1, 1] * B[1, 0]
            C[1, 1] = A[1, 0] * B[0, 1] + A[1, 1] * B[1, 1]
            return C
        return fun
    # Without this, the generator would fall through and return None, leaving
    # numba with no implementation to compile.
    raise NotImplementedError(
        "mul_ is only implemented for 2x2 @ 2x2 (got sh_A=%s, sh_B=%s)" % (sh_A, sh_B)
    )
### A complete version would generate specialized code just in time for every other shape combination; only 2x2 @ 2x2 is implemented above.
@jit(nopython=True)
def _mulvec2_kernel(A, B, sh_A, sh_B):
    """Compiled loop behind mulvec2: C[n] = mul_(A[n], B[n], sh_A, sh_B).

    sh_A and sh_B arrive already encoded as (Val[rows], Val[cols]) tuples,
    so their numba types are fixed on entry and the loop compiles in
    nopython mode.
    """
    N = A.shape[0]
    # mul_ only implements 2x2 @ 2x2, so each output matrix is 2x2.
    C = zeros((N, 2, 2))
    for n in range(N):
        C[n, :, :] = mul_(A[n, :, :], B[n, :, :], sh_A, sh_B)
    return C


def mulvec2(A, B):
    """Batched matrix product dispatched through the shape-specialized mul_.

    Mapping a shape value to Val[...] depends on the runtime values in
    A.shape and B.shape, which numba cannot type inside compiled code.
    Under a plain ``@jit`` this forced object-mode fallback, and numba 0.59
    removed that fallback. Doing the encoding here in Python and passing the
    encoded tuples to a nopython kernel keeps the loop compiled.

    Raises
    ------
    ValueError
        If A and B hold different numbers of matrices.
    """
    if A.shape[0] != B.shape[0]:
        raise ValueError("A and B must contain the same number of matrices")
    sh_A = (Val[A.shape[1]], Val[A.shape[2]])
    sh_B = (Val[B.shape[1]], Val[B.shape[2]])
    return _mulvec2_kernel(A, B, sh_A, sh_B)
%time C2 = mulvec2(A,B)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment