Skip to content

Instantly share code, notes, and snippets.

@synapticarbors
Last active August 10, 2023 01:08
Show Gist options
  • Save synapticarbors/5790459 to your computer and use it in GitHub Desktop.
Save synapticarbors/5790459 to your computer and use it in GitHub Desktop.
Attempt to create a faster version of a Euclidean pairwise distance method in cython using BLAS. Strategy for including BLAS taken from the Tokyo project.
#!python
#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True
import numpy as np
cimport numpy as np
from libc.math cimport sqrt
cdef extern from "cblas.h":
enum CBLAS_ORDER: CblasRowMajor, CblasColMajor
enum CBLAS_TRANSPOSE: CblasNoTrans, CblasTrans, CblasConjTrans
void lib_dgemm "cblas_dgemm"(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB, int M, int N, int K,
double alpha, double *A, int lda, double *B, int ldb,
double beta, double *C, int ldc) nogil
def pairwise_cython_blas(double[:, ::1] X):
cdef:
int M = X.shape[0]
int N = X.shape[1]
unsigned int i, j, k
np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
double[:, ::1] C
C = _C
lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
C.shape[0], C.shape[0], X.shape[1], -2.0, &X[0,0], X.shape[1],
&X[0,0], X.shape[1], 1.0, &C[0,0], C.shape[1])
for i in range(M-1):
C[i,i] = 0.0
for j in range(i+1,M):
for k in range(N):
C[i,j] += (X[i,k]**2 + X[j,k]**2)
C[i,j] = sqrt(C[i,j])
C[j,i] = C[i,j]
C[M-1,M-1] = 0.0
return _C
def pairwise_cython_blas2(double[:, ::1] X):
cdef:
int M = X.shape[0]
int N = X.shape[1]
unsigned int i, j, k
np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
double[:, ::1] C
double[::1] sx = np.empty((M,), dtype=np.float64)
C = _C
lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
C.shape[0], C.shape[0], X.shape[1], -2.0, &X[0,0], X.shape[1],
&X[0,0], X.shape[1], 1.0, &C[0,0], C.shape[1])
for i in range(M):
sx[i] = 0.0
for k in range(N):
sx[i] += X[i,k]**2
for i in range(M-1):
C[i,i] = 0.0
for j in range(i+1,M):
C[i,j] += (sx[i] + sx[j])
C[i,j] = sqrt(C[i,j])
C[j,i] = C[i,j]
C[M-1,M-1] = 0.0
return _C
def pairwise_cython(double[:, ::1] X):
cdef int M = X.shape[0]
cdef int N = X.shape[1]
cdef double tmp, d
cdef double[:, ::1] D = np.empty((M, M), dtype=np.float64)
for i in range(M):
for j in range(M):
d = 0.0
for k in range(N):
tmp = X[i, k] - X[j, k]
d += tmp * tmp
D[i, j] = sqrt(d)
return np.asarray(D)
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy as np
ext_params = {}
ext_params['include_dirs'] = [
'/usr/include',
'/System/Library/Frameworks/vecLib.framework/Versions/A/Headers',
np.get_include()]
ext_params['extra_compile_args'] = ["-O2"]
ext_params['extra_link_args'] = []
ext_params['libraries'] = ['blas']
ext_params['library_dirs'] = ['/usr/lib']
ext_modules = [
Extension("distlib", ["distlib.pyx"], **ext_params),
]
setup(
name='distlib',
cmdclass={'build_ext': build_ext},
ext_modules=ext_modules,
)
In [1]: import numpy as np
In [2]: from scipy.spatial.distance import cdist
In [3]: from distlib import pairwise_cython_blas, pairwise_cython
In [4]: a = np.random.random(size=(1000,3))
In [5]: %timeit cdist(a,a)
100 loops, best of 3: 11.3 ms per loop
In [6]: %timeit pairwise_cython(a)
100 loops, best of 3: 9.54 ms per loop
In [7]: %timeit pairwise_cython_blas(a)
100 loops, best of 3: 13.6 ms per loop
In [8]: %timeit pairwise_cython_blas2(a)
100 loops, best of 3: 13.3 ms per loop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment