public
Last active

Naive O(N**3) 2D np.dot() multithreaded implementation (CPython extension in Cython)

  • Download Gist
cydot.pyx
Cython
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
#cython: boundscheck=False, wraparound=False
import numpy as np
cimport numpy as np
 
from cython.parallel cimport prange
 
def dot(np.ndarray[np.float32_t, ndim=2] a not None,
        np.ndarray[np.float32_t, ndim=2] b not None,
        np.ndarray[np.float32_t, ndim=2] out=None):
    """Naive O(N**3) 2D np.dot() implementation, parallelized over rows.

    Computes the matrix product of float32 matrices ``a`` (M x K) and
    ``b`` (K x N). Rows of the result are split across threads with
    ``prange`` while the GIL is released.

    Parameters
    ----------
    a, b : float32 2-D arrays
        Operands; ``a.shape[1]`` must equal ``b.shape[0]``.
    out : float32 2-D array, optional
        Destination of shape ``(a.shape[0], b.shape[1])``; allocated when
        omitted. It may share memory with ``a`` or ``b``: the overlapping
        operand is copied first, so the product is still correct.

    Returns
    -------
    The array holding the product (the given ``out`` object, if any).

    Raises
    ------
    ValueError
        If the shapes of ``a``, ``b`` and ``out`` are incompatible.
    """
    cdef Py_ssize_t i, j, k
    cdef np.float32_t a_ik

    # Validate the operands before allocating anything.
    if a.shape[1] != b.shape[0]:
        raise ValueError("wrong shape")
    if out is None:
        out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)
    elif out.shape[0] != a.shape[0] or out.shape[1] != b.shape[1]:
        raise ValueError("wrong shape")

    # The loops below zero and then incrementally update ``out``. If it
    # aliases an input, that input would be clobbered mid-computation, so
    # work from a private copy instead (np.dot also handles this case).
    if np.may_share_memory(out, a):
        a = a.copy()
    if np.may_share_memory(out, b):
        b = b.copy()

    with nogil:
        for i in prange(a.shape[0]):
            for j in range(b.shape[1]):
                out[i, j] = 0
            # i-k-j order: the innermost loop walks a row of ``b`` and of
            # ``out`` instead of striding down a column of ``b``. Each
            # out[i, j] still accumulates its terms in increasing k, so the
            # per-element summation order is the same as the i-j-k form.
            for k in range(a.shape[1]):
                a_ik = a[i, k]
                for j in range(b.shape[1]):
                    out[i, j] += a_ik * b[k, j]
    return out
cydot.pyxbld
1 2 3 4 5 6 7
from distutils.extension import Extension
 
def make_ext(modname, pyxfilename):
    """Return the Extension that pyximport builds for ``pyxfilename``.

    The module is compiled and linked with ``-fopenmp`` so that
    ``cython.parallel.prange`` really runs on several threads. This is a
    GCC/Clang flag; MSVC would need ``/openmp`` instead.

    NumPy's C headers are put on the include path because the ``.pyx``
    source does ``cimport numpy``. pyximport does not add them by itself,
    so without this the build only works where the headers happen to be
    on the compiler's default search path.
    """
    import numpy  # local import keeps this build hook self-contained

    return Extension(
        name=modname,
        sources=[pyxfilename],
        include_dirs=[numpy.get_include()],
        extra_compile_args=['-fopenmp'],
        extra_link_args=['-fopenmp'],
    )
results.md
Markdown

Without prange() (single-threaded):

python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 119 msec per loop

With prange() (number of threads == number of cores):

python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 69.9 msec per loop

numpy.dot() version for comparison:

python -mtimeit -s'from test_cydot import a,b,out,np' 'np.dot(a,b,out)'
100 loops, best of 3: 9.97 msec per loop
test_cydot.py
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
import pyximport; pyximport.install() # pip install cython
import numpy as np
import cydot
 
# Fixed-seed generator so the fixtures (and any test failure) are
# reproducible from run to run.
_rng = np.random.RandomState(0)

# Operands for the correctness test and the timeit benchmarks:
# (50 x 10000) @ (10000 x 60) -> (50 x 60), all float32.
a = _rng.rand(50, 10000).astype(np.float32)
b = _rng.rand(10000, 60).astype(np.float32)
out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)
 
def test():
    """Check cydot.dot against np.dot using the module-level fixtures.

    The test covers three things:
    - allocating the result when ``out`` is omitted;
    - writing into a caller-supplied ``out``, which must be fully
      overwritten and returned as the same object;
    - the ValueError for incompatible operand or ``out`` shapes.

    Note: this overwrites the module-level ``out`` buffer.
    """
    assert np.allclose(np.dot(a, b), cydot.dot(a, b))

    # Pre-fill both destinations with distinct sentinels, so a result that
    # failed to overwrite its buffer cannot accidentally match.
    out2 = out.copy()
    out[:] = -1
    out2[:] = -2
    assert np.allclose(out, -1) and np.allclose(out2, -2)

    np.dot(a, b, out)
    result = cydot.dot(a, b, out2)
    assert result is out2, "dot() must return the out array it was given"
    assert np.allclose(out, out2), (out, out2)

    # Mismatched operands, and an out array of the wrong shape, are rejected.
    bad_out = np.empty((1, 1), dtype=np.float32)
    for args in ((a, a), (a, b, bad_out)):
        try:
            cydot.dot(*args)
        except ValueError:
            pass
        else:
            raise AssertionError(
                "expected ValueError for shapes %s" % ([x.shape for x in args],))


if __name__ == "__main__":
    test()

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.