Skip to content

Instantly share code, notes, and snippets.

@manodeep
Forked from synapticarbors/Timings.py
Last active August 10, 2023 04:11
Show Gist options
  • Save manodeep/7bbd987762fd4c924413e6a571e9a8c3 to your computer and use it in GitHub Desktop.
Save manodeep/7bbd987762fd4c924413e6a571e9a8c3 to your computer and use it in GitHub Desktop.
Attempt to create a faster Euclidean pairwise-distance routine in Cython using BLAS. The strategy for calling BLAS is taken from the Tokyo project.
#!python
#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True
import numpy as np
cimport numpy as np
from libc.math cimport sqrt
# C-level declarations for the small part of the CBLAS interface used below.
cdef extern from "cblas_openblas.h":
    # Memory layout of the matrices handed to BLAS.
    enum CBLAS_ORDER:
        CblasRowMajor
        CblasColMajor

    # Whether an operand is used as-is, transposed or conjugate-transposed.
    enum CBLAS_TRANSPOSE:
        CblasNoTrans
        CblasTrans
        CblasConjTrans

    # Double-precision general matrix multiply; the C symbol ``cblas_dgemm``
    # is exposed to this module under the name ``lib_dgemm``.
    void lib_dgemm "cblas_dgemm"(
        CBLAS_ORDER Order,
        CBLAS_TRANSPOSE TransA,
        CBLAS_TRANSPOSE TransB,
        int M, int N, int K,
        double alpha,
        double *A, int lda,
        double *B, int ldb,
        double beta,
        double *C, int ldc) nogil
def pairwise_cython_blas(double[:, ::1] X):
    """Return the matrix of Euclidean distances between the rows of ``X``.

    Uses ``|x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 x_i.x_j``: a single
    ``cblas_dgemm`` call fills the output with ``-2 * X * X^T`` and the
    squared norms are then added pair by pair (recomputed for every pair).

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer of shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
        Symmetric distance matrix with a zero diagonal.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        # Signed indices: the directives above disable bounds checking and
        # wraparound, so mixing unsigned counters with signed bounds is unsafe.
        Py_ssize_t i, j, k
        double d
        np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
        double[:, ::1] C

    # With no rows or no columns there is no valid element to take the
    # address of and BLAS rejects zero leading dimensions; the all-zero
    # (possibly empty) matrix is already the right answer.
    if M == 0 or N == 0:
        return _C

    C = _C

    # C := -2 * X * X^T  (C starts at zero, so beta = 1 adds nothing).
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, -2.0, &X[0,0], N,
              &X[0,0], N, 1.0, &C[0,0], M)

    for i in range(M):
        C[i,i] = 0.0
        for j in range(i+1, M):
            d = C[i,j]
            for k in range(N):
                d += (X[i,k]**2 + X[j,k]**2)
            # Cancellation can leave a tiny negative value for (nearly)
            # identical rows; clamp so sqrt does not return NaN.
            if d < 0.0:
                d = 0.0
            d = sqrt(d)
            C[i,j] = d
            C[j,i] = d

    return _C
def pairwise_cython_blas2(double[:, ::1] X):
    """Return the matrix of Euclidean distances between the rows of ``X``.

    Same identity as ``pairwise_cython_blas``
    (``|x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 x_i.x_j``), but the squared row
    norms are computed once up front instead of once per pair.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer of shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
        Symmetric distance matrix with a zero diagonal.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        # Signed indices: the directives above disable bounds checking and
        # wraparound, so mixing unsigned counters with signed bounds is unsafe.
        Py_ssize_t i, j, k
        double d, s
        np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
        double[:, ::1] C
        double[::1] sx = np.empty((M,), dtype=np.float64)

    # With no rows or no columns there is no valid element to take the
    # address of and BLAS rejects zero leading dimensions; the all-zero
    # (possibly empty) matrix is already the right answer.
    if M == 0 or N == 0:
        return _C

    C = _C

    # C := -2 * X * X^T  (C starts at zero, so beta = 1 adds nothing).
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, -2.0, &X[0,0], N,
              &X[0,0], N, 1.0, &C[0,0], M)

    # Squared norm of every row.
    for i in range(M):
        s = 0.0
        for k in range(N):
            s += X[i,k]**2
        sx[i] = s

    for i in range(M):
        C[i,i] = 0.0
        for j in range(i+1, M):
            d = C[i,j] + (sx[i] + sx[j])
            # Cancellation can leave a tiny negative value for (nearly)
            # identical rows; clamp so sqrt does not return NaN.
            if d < 0.0:
                d = 0.0
            d = sqrt(d)
            C[i,j] = d
            C[j,i] = d

    return _C
## Added by MS to directly compute X.X -> useful for computing correlation functions
def pairwise_cython_blas2_dot(double[:, ::1] X):
    """Return the matrix of pairwise dot products between the rows of ``X``.

    Computes ``X * X^T`` with a single ``cblas_dgemm`` call and then sets the
    diagonal (each row dotted with itself) to zero.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer of shape (M, N)
        One vector per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
        ``out[i, j] = X[i] . X[j]`` for ``i != j`` and ``out[i, i] = 0``.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        Py_ssize_t i
        np.ndarray[double, ndim=2] _C = np.zeros((M, M), dtype=np.float64)
        double[:, ::1] C

    # With no rows or no columns there is no valid element to take the
    # address of and BLAS rejects zero leading dimensions; every dot product
    # is zero, which is what _C already holds.
    if M == 0 or N == 0:
        return _C

    C = _C

    # C := X * X^T  (C starts at zero, so beta = 1 adds nothing).
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, 1.0, &X[0,0], N,
              &X[0,0], N, 1.0, &C[0,0], M)

    # Self-pairs are excluded by zeroing the diagonal.
    for i in range(M):
        C[i,i] = 0.0

    return _C
def pairwise_cython(double[:, ::1] X):
    """Return the matrix of Euclidean distances between the rows of ``X``.

    Straightforward triple loop with no BLAS: every entry ``(i, j)`` is
    computed independently as ``sqrt(sum_k (X[i, k] - X[j, k])**2)``.

    Parameters
    ----------
    X : C-contiguous 2-D float64 buffer of shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray, shape (M, M), dtype float64
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        # Declared explicitly (as in the BLAS variants) so the loops are
        # plain C loops independent of the type-inference settings.
        Py_ssize_t i, j, k
        double tmp, d
        double[:, ::1] D = np.empty((M, M), dtype=np.float64)

    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = sqrt(d)

    return np.asarray(D)
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy as np
# Compiler/linker settings shared by every extension built here.
ext_params = {
    'include_dirs': [
        '/opt/local/include',  # Changed to suit macports
        '/System/Library/Frameworks/vecLib.framework/Versions/A/Headers',
        np.get_include(),
    ],
    'extra_compile_args': ["-O2"],
    'extra_link_args': [],
    'libraries': ['blas'],
    'library_dirs': ['/opt/local/lib'],  # Changed to suit macports
}

# Single Cython extension: distlib.pyx -> distlib
ext_modules = [Extension("distlib", ["distlib.pyx"], **ext_params)]

setup(
    name='distlib',
    cmdclass={'build_ext': build_ext},
    ext_modules=ext_modules,
)
from __future__ import print_function
import numpy as np
import time
from math import pi
from os.path import dirname, abspath, join as pjoin
import Corrfunc
from Corrfunc.mocks.DDtheta_mocks import DDtheta_mocks
#binfile = pjoin(dirname(abspath(Corrfunc.__file__)), "../mocks/tests/", "angular_bins")
from distlib import pairwise_cython_blas2, pairwise_cython_blas2_dot
import numba
@numba.njit
def hist1d(X, bins):
    """Return only the per-bin counts of ``np.histogram(X, bins)``."""
    counts, _edges = np.histogram(X, bins)
    return counts
# Compare Corrfunc's angular pair counts against brute-force pair counts
# obtained from the full N x N separation matrix, computed three ways.
N = 10000
nthreads = 1
seed = 42
np.random.seed(seed)

# Angular bin edges in degrees.
bins = np.linspace(0.1, 90.0, 40)
print(bins)

# Random points: RA uniform in degrees, DEC derived from a uniform cos(theta).
RA = np.random.uniform(0.0, 2.0*pi, N)*180.0/pi
cos_theta = np.random.uniform(-1.0, 1.0, N)
DEC = 90.0 - np.arccos(cos_theta)*180.0/pi

# Reference result from Corrfunc.
autocorr = 1
t0 = time.perf_counter()
cf_results = DDtheta_mocks(autocorr, nthreads, bins, RA, DEC, output_thetaavg=False, verbose=True)
t1 = time.perf_counter()
print(cf_results['npairs'])
print("Time taken = %10.2lf seconds\n" % (t1 - t0))

# Unit vectors on the sphere for every point.
RA1 = RA * pi/180.0
DEC1 = DEC * pi/180.0
pos = np.empty((N, 3), dtype=np.float64)
pos[:, 0] = np.cos(DEC1) * np.cos(RA1)
pos[:, 1] = np.cos(DEC1) * np.sin(RA1)
pos[:, 2] = np.sin(DEC1)

# The same bin edges expressed as cos(theta) and as chord length.
# For unit vectors |a - b|^2 = 2 - 2 cos(theta); pairwise_cython_blas2
# returns the distance itself (it applies sqrt), so the edges must be the
# square root of that quantity, not the squared chord.
cosbins = np.sort(np.cos(bins*pi/180.0))
chord_sep_bins = np.sort(np.sqrt(2.0 - 2.0*cosbins))

# 1) Chord lengths from the BLAS-based distance routine.
t0 = time.perf_counter()
chord_sep = pairwise_cython_blas2(pos)
t00 = time.perf_counter()
print("Cython_blas2 time taken = %10.4lf seconds" % (t00 - t0))
x = hist1d(chord_sep, chord_sep_bins)
#print(x)
t1 = time.perf_counter()
np.testing.assert_array_equal(cf_results['npairs'], x)
print("Time taken = %10.2lf seconds (chord_sep bins)\n" % (t1 - t0))

# 2) cos(theta) from the BLAS-based dot-product routine. The cos(theta)
#    edges are ascending while theta is descending, hence the flip.
t0 = time.perf_counter()
costheta = pairwise_cython_blas2_dot(pos)
t00 = time.perf_counter()
print("Cython_blas2_dot time taken = %10.4lf seconds" % (t00 - t0))
x = np.flip(hist1d(costheta, cosbins))
#print(x)
t1 = time.perf_counter()
np.testing.assert_array_equal(cf_results['npairs'], x)
print("Time taken = %10.2lf seconds (cosbins)\n" % (t1 - t0))

# 3) cos(theta) straight from numpy.
t0 = time.perf_counter()
costheta = np.dot(pos, pos.T)
t00 = time.perf_counter()
print("np.dot time taken = %10.4lf seconds" % (t00 - t0))
x = np.flip(hist1d(costheta, cosbins))
#print(x)
t1 = time.perf_counter()
np.testing.assert_array_equal(cf_results['npairs'], x)
print("Time taken = %10.2lf seconds (cosbins)\n" % (t1 - t0))
import numpy as np
from scipy.spatial.distance import cdist
from distlib import pairwise_cython_blas, pairwise_cython, pairwise_cython_blas2
import timeit
# Time each pairwise-distance implementation on the same random input.
# The statements below are evaluated against globals(), so the input array
# has to stay bound to the module-level name `a`.
a = np.random.random(size=(10000,3))
loop, repeat = 1, 1
funcs = [
    'cdist(a,a)',
    'pairwise_cython(a)',
    'pairwise_cython_blas(a)',
    'pairwise_cython_blas2(a)',
]

banner = "######################################################"
print(banner)
print("# Function Min. time (s) ")
print(banner)

for stmt in funcs:
    # Report the fastest of the repeated measurements.
    best = min(timeit.repeat(stmt, globals=globals(), number=loop, repeat=repeat))
    print("{0:30s} {1:10.4g}s".format(stmt, best))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment