Skip to content
Create a gist now

Instantly share code, notes, and snippets.

@zed /cydot.pyx

Naive O(N**3) 2D multithreaded implementation (CPython extension in Cython)
#cython: boundscheck=False, wraparound=False
import numpy as np
cimport numpy as np
from cython.parallel cimport prange
def dot(np.ndarray[np.float32_t, ndim=2] a not None,
        np.ndarray[np.float32_t, ndim=2] b not None,
        np.ndarray[np.float32_t, ndim=2] out=None):
    """Naive O(N**3) 2D matrix product ``a @ b`` for float32 arrays.

    Rows of the result are distributed over threads with ``prange``; the
    loop only actually runs in parallel when the module is compiled with
    OpenMP support.

    Parameters
    ----------
    a : float32 array of shape (n, m)
    b : float32 array of shape (m, p)
    out : float32 array of shape (n, p), optional
        Buffer that receives the result; a new one is allocated when omitted.

    Returns
    -------
    The result array (the caller's ``out`` buffer when one is given).

    Raises
    ------
    ValueError
        If the inner dimensions of ``a`` and ``b`` disagree, or ``out`` does
        not have shape (a.shape[0], b.shape[1]).
    """
    if out is None:
        out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)
    if (a.shape[1] != b.shape[0] or
            out.shape[0] != a.shape[0] or out.shape[1] != b.shape[1]):
        raise ValueError("wrong shape")

    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t n_rows = a.shape[0], n_inner = a.shape[1], n_cols = b.shape[1]
    cdef np.float32_t acc
    with nogil:
        for i in prange(n_rows):
            for j in range(n_cols):
                # Sum into a thread-private local and store once, instead of a
                # read-modify-write on the (possibly strided) out[i, j] buffer
                # element on every inner iteration.
                # Written as ``acc = acc + ...`` rather than ``+=``: inside a
                # prange body Cython infers an in-place operator as a reduction,
                # which would conflict with the per-element reset ``acc = 0``.
                acc = 0
                for k in range(n_inner):
                    acc = acc + a[i, k] * b[k, j]
                out[i, j] = acc
    return out
from distutils.extension import Extension
def make_ext(modname, pyxfilename):
    """Return the ``Extension`` object used to build the module *modname*.

    Presumably this is the pyximport ``.pyxbld`` hook for ``cydot.pyx``
    (pyximport calls ``make_ext(modname, pyxfilename)``); confirm against the
    original source.

    NOTE(review): the ``Extension(...)`` call below is truncated in this copy.
    The remaining arguments and the closing parenthesis are missing, so the
    function is not syntactically complete. ``pyxfilename`` is presumably
    meant for ``sources=[pyxfilename]``. The ``prange`` loop in cydot.pyx only
    runs in parallel if OpenMP compile/link flags are passed here; restore the
    arguments from the original source.
    NOTE(review): ``distutils`` (imported just above) was removed in Python 3.12;
    ``setuptools.Extension`` is the drop-in replacement.
    """
    return Extension(name=modname,

Without prange() (single-threaded):

python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 119 msec per loop

With prange() (number of threads == number of cores):

python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 69.9 msec per loop

NumPy `np.dot()` version for comparison:

python -mtimeit -s'from test_cydot import a,b,out,np' 'np.dot(a,b,out)'
100 loops, best of 3: 9.97 msec per loop
import pyximport; pyximport.install() # pip install cython
import numpy as np
import cydot
# Shared fixtures for test() and the timeit benchmarks: random float32
# operands of shape (50, 10000) and (10000, 60), plus an uninitialised
# float32 output buffer of the product's shape (50, 60).
a = np.asarray(np.random.rand(50, 10000), dtype=np.float32)
b = np.asarray(np.random.rand(10000, 60), dtype=np.float32)
out = np.empty((len(a), b.shape[-1]), dtype=a.dtype)
def test():
    """Check ``cydot.dot`` against ``np.dot``, with and without an ``out`` buffer.

    Uses the module-level fixtures ``a``, ``b`` and ``out``. As a side effect,
    ``out`` is overwritten with the product.
    """
    # Allocating path: compare freshly returned results.
    assert np.allclose(cydot.dot(a, b), np.dot(a, b))

    # In-place path: prefill the two buffers with different sentinel values,
    # so the final comparison can only pass if both were fully overwritten.
    out2 = out.copy()
    out[:] = -1
    out2[:] = -2
    assert np.allclose(out, -1) and np.allclose(out2, -2)
    assert cydot.dot(a, b, out) is out  # result is written into the caller's buffer
    np.dot(a, b, out2)
    assert np.allclose(out, out2), (out, out2)


if __name__ == "__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Something went wrong with that request. Please try again.