-
-
Save manodeep/7bbd987762fd4c924413e6a571e9a8c3 to your computer and use it in GitHub Desktop.
An attempt to create a faster Euclidean pairwise-distance routine in Cython using BLAS. The strategy for calling BLAS from Cython is taken from the Tokyo project.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!python | |
#cython: boundscheck=False | |
#cython: wraparound=False | |
#cython: cdivision=True | |
import numpy as np | |
cimport numpy as np | |
from libc.math cimport sqrt | |
# C-level declarations for the CBLAS pieces this module calls.  Every
# function declared in this block may be called without holding the GIL.
cdef extern from "cblas_openblas.h" nogil:
    enum CBLAS_ORDER:
        CblasRowMajor
        CblasColMajor

    enum CBLAS_TRANSPOSE:
        CblasNoTrans
        CblasTrans
        CblasConjTrans

    # The C symbol ``cblas_dgemm`` is exposed to Cython code under the
    # name ``lib_dgemm``.
    void lib_dgemm "cblas_dgemm"(
        CBLAS_ORDER Order,
        CBLAS_TRANSPOSE TransA,
        CBLAS_TRANSPOSE TransB,
        int M, int N, int K,
        double alpha,
        double *A, int lda,
        double *B, int ldb,
        double beta,
        double *C, int ldc)
def pairwise_cython_blas(double[:, ::1] X):
    """Return the pairwise Euclidean distances between the rows of ``X``.

    The cross terms ``-2 * X . X^T`` are produced by one ``cblas_dgemm``
    call.  The squared row norms are then added pair by pair inside the
    loop, using ``|x - y|^2 = |x|^2 + |y|^2 - 2 x.y``.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer, shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
        Symmetric distance matrix whose diagonal is exactly 0.0.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        Py_ssize_t i, j, k
        double sqr_dist
        np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
        double[:, ::1] C

    # Bounds checking is disabled for this module, so an empty input must
    # never reach ``&X[0, 0]``.  With N == 0 the leading dimension passed
    # to dgemm would be 0, which BLAS rejects.  The all-zero matrix that is
    # already allocated is the correct answer in both cases.
    if M == 0 or N == 0:
        return _C

    C = _C
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, -2.0, &X[0, 0], N,
              &X[0, 0], N, 1.0, &C[0, 0], M)

    for i in range(M):
        C[i, i] = 0.0
        for j in range(i + 1, M):
            sqr_dist = C[i, j]
            for k in range(N):
                sqr_dist += (X[i, k] * X[i, k] + X[j, k] * X[j, k])
            # Round-off can leave a tiny negative value for (nearly)
            # coincident points; clamp so sqrt() does not return NaN.
            if sqr_dist < 0.0:
                sqr_dist = 0.0
            C[i, j] = sqrt(sqr_dist)
            C[j, i] = C[i, j]
    return _C
def pairwise_cython_blas2(double[:, ::1] X):
    """Return the pairwise Euclidean distances between the rows of ``X``.

    Same result as ``pairwise_cython_blas``, but the squared norm of each
    row is computed once up front instead of once per pair.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer, shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
        Symmetric matrix of distances (not squared distances) whose
        diagonal is exactly 0.0.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        Py_ssize_t i, j, k
        double acc
        np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
        double[:, ::1] C
        double[::1] sx = np.empty((M,), dtype=np.float64)

    # Bounds checking is disabled for this module, so an empty input must
    # never reach ``&X[0, 0]``.  With N == 0 the leading dimension passed
    # to dgemm would be 0, which BLAS rejects.  The all-zero matrix that is
    # already allocated is the correct answer in both cases.
    if M == 0 or N == 0:
        return _C

    C = _C
    # C <- -2 * X . X^T  (the cross terms of |x - y|^2)
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, -2.0, &X[0, 0], N,
              &X[0, 0], N, 1.0, &C[0, 0], M)

    # Squared norm of every row, computed once.
    for i in range(M):
        acc = 0.0
        for k in range(N):
            acc += X[i, k] * X[i, k]
        sx[i] = acc

    for i in range(M):
        C[i, i] = 0.0
        for j in range(i + 1, M):
            acc = C[i, j] + (sx[i] + sx[j])
            # Round-off can leave a tiny negative value for (nearly)
            # coincident points; clamp so sqrt() does not return NaN.
            if acc < 0.0:
                acc = 0.0
            C[i, j] = sqrt(acc)
            C[j, i] = C[i, j]
    return _C
## Added by MS to directly compute X.X -> useful for computing correlation functions
def pairwise_cython_blas2_dot(double[:, ::1] X):
    """Return the matrix of pairwise dot products between the rows of ``X``.

    Computes ``X . X^T`` with one ``cblas_dgemm`` call and then overwrites
    the diagonal (each row dotted with itself) with 0.0.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer, shape (M, N)
        One vector per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
        Pairwise dot products, with the diagonal set to 0.0.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        Py_ssize_t i
        np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
        double[:, ::1] C

    # Bounds checking is disabled for this module, so an empty input must
    # never reach ``&X[0, 0]``.  With N == 0 the leading dimension passed
    # to dgemm would be 0, which BLAS rejects.  The all-zero matrix that is
    # already allocated is the correct answer in both cases.
    if M == 0 or N == 0:
        return _C

    C = _C
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, 1.0, &X[0, 0], N,
              &X[0, 0], N, 1.0, &C[0, 0], M)

    # Discard the self-products on the diagonal.
    for i in range(M):
        C[i, i] = 0.0
    return _C
def pairwise_cython(double[:, ::1] X):
    """Return the pairwise Euclidean distances between the rows of ``X``.

    Plain triple-loop reference implementation with no BLAS call.  Every
    ``(row, col)`` entry of the output is computed independently, so both
    triangles and the diagonal are filled.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer, shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
    """
    cdef:
        int n_points = X.shape[0]
        int n_dims = X.shape[1]
        double diff, sqr_sum
        double[:, ::1] dist = np.empty((n_points, n_points), dtype=np.float64)

    for row in range(n_points):
        for col in range(n_points):
            # Accumulate the squared separation along each coordinate.
            sqr_sum = 0.0
            for axis in range(n_dims):
                diff = X[row, axis] - X[col, axis]
                sqr_sum += diff * diff
            dist[row, col] = sqrt(sqr_sum)
    return np.asarray(dist)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build script for the ``distlib`` Cython extension.
# NOTE(review): ``distutils`` was removed from the standard library in
# Python 3.12; this script needs an older interpreter (or a port to
# setuptools) to run.
from distutils.core import setup
from distutils.extension import Extension

import numpy as np
from Cython.Distutils import build_ext

# Compiler / linker settings passed straight through to ``Extension``.
ext_params = {
    'include_dirs': [
        '/opt/local/include',  # Changed to suit macports
        '/System/Library/Frameworks/vecLib.framework/Versions/A/Headers',
        np.get_include(),
    ],
    'extra_compile_args': ["-O2"],
    'extra_link_args': [],
    'libraries': ['blas'],
    'library_dirs': ['/opt/local/lib'],  # Changed to suit macports
}

ext_modules = [Extension("distlib", ["distlib.pyx"], **ext_params)]

setup(
    name='distlib',
    cmdclass={'build_ext': build_ext},
    ext_modules=ext_modules,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import numpy as np | |
import time | |
from math import pi | |
from os.path import dirname, abspath, join as pjoin | |
import Corrfunc | |
from Corrfunc.mocks.DDtheta_mocks import DDtheta_mocks | |
#binfile = pjoin(dirname(abspath(Corrfunc.__file__)), "../mocks/tests/", "angular_bins") | |
from distlib import pairwise_cython_blas2, pairwise_cython_blas2_dot | |
import numba | |
@numba.njit
def hist1d(X, bins):
    """Return only the per-bin counts of ``np.histogram(X, bins=bins)``.

    The bin edges that ``np.histogram`` also returns are discarded.
    """
    counts, _edges = np.histogram(X, bins=bins)
    return counts
# Benchmark: compare Corrfunc's DDtheta_mocks pair counts against counts
# obtained by histogramming a dense pairwise matrix.
N = 10000
nthreads = 1
seed = 42
np.random.seed(seed)

# Angular bin edges, in degrees.
bins = np.linspace(0.1, 90.0, 40)
print(bins)

# Random points: RA uniform in degrees, DEC derived from a uniform
# cos(theta) draw.
RA = np.random.uniform(0.0, 2.0*pi, N)*180.0/pi
cos_theta = np.random.uniform(-1.0, 1.0, N)
DEC = 90.0 - np.arccos(cos_theta)*180.0/pi

# Reference pair counts from Corrfunc.
autocorr = 1
t0 = time.perf_counter()
cf_results = DDtheta_mocks(autocorr, nthreads, bins, RA, DEC, output_thetaavg=False, verbose=True)
t1 = time.perf_counter()
print(cf_results['npairs'])
print("Time taken = %10.2lf seconds\n" % (t1 - t0))

# Unit vectors on the sphere for every (RA, DEC) pair.
RA1 = RA * pi/180.0
DEC1 = DEC * pi/180.0
pos = np.empty((N, 3), dtype=np.float64)
pos[:, 0] = np.cos(DEC1) * np.cos(RA1)
pos[:, 1] = np.cos(DEC1) * np.sin(RA1)
pos[:, 2] = np.sin(DEC1)

# The same bin edges expressed as cos(theta), sorted ascending.
cosbins = np.sort(np.cos(bins*pi/180.0))
# ... and as chord length between unit vectors.  For unit vectors
# |x - y|^2 = 2 - 2 cos(theta), and pairwise_cython_blas2 returns the
# distance itself (it takes the square root), so the edges need the
# square root as well.
chord_sep_bins = np.sort(np.sqrt(2.0 - 2.0*cosbins))

# --- Method 1: histogram of chord lengths ---------------------------------
t0 = time.perf_counter()
chord_sep = pairwise_cython_blas2(pos)
t00 = time.perf_counter()
print("Cython_blas2 time taken = %10.4lf seconds" % (t00 - t0))
x = hist1d(chord_sep, chord_sep_bins)
#print(x)
t1 = time.perf_counter()
np.testing.assert_array_equal(cf_results['npairs'], x)
print("Time taken = %10.2lf seconds (chord_sep bins)\n" % (t1 - t0))

# --- Method 2: histogram of cos(theta) from the BLAS dot product ----------
t0 = time.perf_counter()
costheta = pairwise_cython_blas2_dot(pos)
t00 = time.perf_counter()
print("Cython_blas2_dot time taken = %10.4lf seconds" % (t00 - t0))
# cos(theta) decreases as theta grows, so the counts come out in reverse
# bin order and are flipped back.
x = np.flip(hist1d(costheta, cosbins))
#print(x)
t1 = time.perf_counter()
np.testing.assert_array_equal(cf_results['npairs'], x)
print("Time taken = %10.2lf seconds (cosbins)\n" % (t1 - t0))

# --- Method 3: histogram of cos(theta) from np.dot ------------------------
t0 = time.perf_counter()
costheta = np.dot(pos, pos.T)
t00 = time.perf_counter()
print("np.dot time taken = %10.4lf seconds" % (t00 - t0))
x = np.flip(hist1d(costheta, cosbins))
#print(x)
t1 = time.perf_counter()
np.testing.assert_array_equal(cf_results['npairs'], x)
print("Time taken = %10.2lf seconds (cosbins)\n" % (t1 - t0))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.spatial.distance import cdist | |
from distlib import pairwise_cython_blas, pairwise_cython, pairwise_cython_blas2 | |
import timeit | |
# Time each pairwise-distance implementation on the same random input.
# ``a`` must stay a module-level name: the timed statements below refer to
# it and are evaluated against ``globals()``.
a = np.random.random(size=(10000,3))
n_loops = 1
n_repeats = 1
statements = [
    'cdist(a,a)',
    'pairwise_cython(a)',
    'pairwise_cython_blas(a)',
    'pairwise_cython_blas2(a)',
]

banner = "######################################################"
print(banner)
print("# Function Min. time (s) ")
print(banner)
for stmt in statements:
    timings = timeit.repeat(stmt, globals=globals(), number=n_loops, repeat=n_repeats)
    fastest = min(timings)
    print("{0:30s} {1:10.4g}s".format(stmt, fastest))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment