Created

Embed URL

HTTPS clone URL

SSH clone URL

You can clone with HTTPS or SSH.

Download Gist

pip install cython && python test_matmul.py

View cymatmul.pyx
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
# see http://conference.scipy.org/proceedings/SciPy2009/paper_2/
import threading
cimport cython
 
import numpy as np
cimport numpy as np
 
# Element type of every statically typed kernel in this module: the
# np.ndarray[dtype_t, ...] signatures below only accept float64 arrays.
ctypedef np.float64_t dtype_t
 
def matmul(A, B, out=None):
    """Multiply 2-D arrays ``A`` (l x m) and ``B`` (m x n) with untyped loops.

    This is the naive baseline: the same triple loop as the pure-Python
    version, compiled by Cython without any static typing, so every element
    access still goes through generic Python indexing.

    Parameters
    ----------
    A, B : 2-D arrays supporting ``.shape`` and ``[i, j]`` indexing.
    out : optional preallocated array of shape ``(A.shape[0], B.shape[1])``;
        when omitted a zeroed array of dtype ``np.result_type(A, B)`` is
        allocated (so an int matrix times a float matrix is not truncated).

    Returns
    -------
    out, filled with ``A @ B``.

    Raises
    ------
    ValueError
        If the inner dimensions differ or ``out`` has the wrong shape.
    """
    # Without this check a B with more rows than A has columns would be
    # silently truncated and a wrong product returned.
    if A.shape[1] != B.shape[0]:
        raise ValueError("matmul: inner dimensions differ (%d != %d)"
                         % (A.shape[1], B.shape[0]))
    expected = (A.shape[0], B.shape[1])
    if out is None:
        out = np.zeros(expected, dtype=np.result_type(A, B))
    elif tuple(out.shape) != expected:
        raise ValueError("matmul: out must have shape %r, got %r"
                         % (expected, tuple(out.shape)))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
 
 
def matmul_dtype(np.ndarray[dtype_t, ndim=2] A not None,
                 np.ndarray[dtype_t, ndim=2] B not None,
                 np.ndarray[dtype_t, ndim=2] out=None):
    """Multiply two float64 matrices using statically typed buffers.

    Same triple loop as ``matmul`` but with typed ``np.ndarray`` buffers and
    C-typed loop indices, so Cython indexes the arrays directly. Bounds
    checking and negative-index wraparound stay at Cython's defaults
    (enabled) because no directives are applied here.

    Parameters
    ----------
    A : (l, m) float64 array.
    B : (m, n) float64 array.
    out : optional (l, n) float64 array receiving the result; a zeroed array
        is allocated when omitted.

    Returns
    -------
    out, holding ``A @ B``.

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ or ``out`` does not
        have shape ``(A.shape[0], B.shape[1])``.
    """
    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t nrows = A.shape[0]
    cdef Py_ssize_t ninner = A.shape[1]
    cdef Py_ssize_t ncols = B.shape[1]
    cdef dtype_t s

    # Bounds checking only catches a B with too *few* rows; one with too many
    # would otherwise be silently truncated.
    if B.shape[0] != ninner:
        raise ValueError("matmul_dtype: A is %dx%d but B is %dx%d; inner "
                         "dimensions must match"
                         % (nrows, ninner, B.shape[0], ncols))
    if out is None:
        out = np.zeros((nrows, ncols), dtype=A.dtype)
    elif out.shape[0] != nrows or out.shape[1] != ncols:
        raise ValueError("matmul_dtype: out must be %dx%d, got %dx%d"
                         % (nrows, ncols, out.shape[0], out.shape[1]))

    for i in range(nrows):
        for j in range(ncols):
            s = 0
            for k in range(ninner):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
 
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck(np.ndarray[dtype_t, ndim=2] A not None,
                         np.ndarray[dtype_t, ndim=2] B not None,
                         np.ndarray[dtype_t, ndim=2] out not None):
    """Multiply float64 matrices into ``out`` with per-element checks disabled.

    Identical to ``matmul_dtype`` except that bounds checking and negative
    index wraparound are turned off for the loops, and ``out`` is mandatory.
    Since nothing is checked per element, the shapes are validated once up
    front: a mismatch would otherwise read/write outside the arrays' memory.

    Parameters
    ----------
    A : (l, m) float64 array.
    B : (m, n) float64 array.
    out : (l, n) float64 array receiving the result.

    Returns
    -------
    out, holding ``A @ B``.

    Raises
    ------
    ValueError
        If the shapes of ``A``, ``B`` and ``out`` are not compatible.
    TypeError
        If any argument is ``None``.
    """
    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t nrows = A.shape[0]
    cdef Py_ssize_t ninner = A.shape[1]
    cdef Py_ssize_t ncols = B.shape[1]
    cdef dtype_t s

    if B.shape[0] != ninner:
        raise ValueError("matmul_dtype_nocheck: A is %dx%d but B is %dx%d; "
                         "inner dimensions must match"
                         % (nrows, ninner, B.shape[0], ncols))
    if out.shape[0] != nrows or out.shape[1] != ncols:
        raise ValueError("matmul_dtype_nocheck: out must be %dx%d, got %dx%d"
                         % (nrows, ncols, out.shape[0], out.shape[1]))

    for i in range(nrows):
        for j in range(ncols):
            s = 0
            for k in range(ninner):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
 
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_mem(
        np.ndarray[dtype_t, ndim=2, mode="c"] A not None,
        np.ndarray[dtype_t, ndim=2, mode="fortran"] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Unchecked float64 matmul that requires cache-friendly input layouts.

    ``A`` must be C-contiguous and ``B`` Fortran-contiguous (the buffer
    ``mode`` declarations make Cython reject other layouts when the arrays
    are passed in). With those layouts the inner ``k`` loop walks a row of
    ``A`` and a column of ``B`` through contiguous memory. Bounds checking is
    disabled for the loops, so shapes are validated once up front.

    Parameters
    ----------
    A : (l, m) C-contiguous float64 array.
    B : (m, n) Fortran-contiguous float64 array.
    out : (l, n) float64 array receiving the result.

    Returns
    -------
    out, holding ``A @ B``.

    Raises
    ------
    ValueError
        If the shapes are not compatible (or, from Cython, if the memory
        layouts do not match the declared modes).
    TypeError
        If any argument is ``None``.
    """
    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t nrows = A.shape[0]
    cdef Py_ssize_t ninner = A.shape[1]
    cdef Py_ssize_t ncols = B.shape[1]
    cdef dtype_t s

    if B.shape[0] != ninner:
        raise ValueError("matmul_dtype_nocheck_mem: A is %dx%d but B is "
                         "%dx%d; inner dimensions must match"
                         % (nrows, ninner, B.shape[0], ncols))
    if out.shape[0] != nrows or out.shape[1] != ncols:
        raise ValueError("matmul_dtype_nocheck_mem: out must be %dx%d, "
                         "got %dx%d"
                         % (nrows, ncols, out.shape[0], out.shape[1]))

    for i in range(nrows):
        for j in range(ncols):
            s = 0
            for k in range(ninner):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
 
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_fort(
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Unchecked float64 matmul that converts ``B`` to Fortran order itself.

    Unlike ``matmul_dtype_nocheck_mem`` any memory layout is accepted: ``B``
    is converted to Fortran (column-major) order first, which lets the inner
    ``k`` loop read a contiguous column of ``B``. ``np.asfortranarray``
    returns ``B`` itself when it already has that layout, so no copy is made
    in that case. Bounds checking is disabled for the loops, so shapes are
    validated before anything else (including the potential copy).

    Parameters
    ----------
    A : (l, m) float64 array.
    B : (m, n) float64 array.
    out : (l, n) float64 array receiving the result.

    Returns
    -------
    out, holding ``A @ B``.

    Raises
    ------
    ValueError
        If the shapes of ``A``, ``B`` and ``out`` are not compatible.
    TypeError
        If any argument is ``None``.
    """
    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t nrows = A.shape[0]
    cdef Py_ssize_t ninner = A.shape[1]
    cdef Py_ssize_t ncols = B.shape[1]
    cdef dtype_t s

    if B.shape[0] != ninner:
        raise ValueError("matmul_dtype_nocheck_fort: A is %dx%d but B is "
                         "%dx%d; inner dimensions must match"
                         % (nrows, ninner, B.shape[0], ncols))
    if out.shape[0] != nrows or out.shape[1] != ncols:
        raise ValueError("matmul_dtype_nocheck_fort: out must be %dx%d, "
                         "got %dx%d"
                         % (nrows, ncols, out.shape[0], out.shape[1]))

    B = np.asfortranarray(B)
    for i in range(nrows):
        for j in range(ncols):
            s = 0
            for k in range(ninner):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
 
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_thread(
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None,
        Py_ssize_t nthreads=8):
    """Multiply float64 matrices into ``out`` using ``nthreads`` workers.

    Rows of the result are dealt out round-robin: worker ``tid`` computes rows
    ``tid, tid + nthreads, ...`` through ``_matmul_slice``, whose loops run
    with the GIL released so the workers can execute concurrently. ``B`` is
    converted to Fortran order first (no copy if it already has that layout).
    Any exception raised inside a worker is re-raised here rather than being
    lost in a background thread.

    Parameters
    ----------
    A : (l, m) float64 array.
    B : (m, n) float64 array.
    out : (l, n) float64 array receiving the result.
    nthreads : number of worker tasks/threads (must be >= 1).

    Returns
    -------
    out, holding ``A @ B``.

    Raises
    ------
    ValueError
        If ``nthreads < 1`` or the shapes are not compatible.
    """
    # Function-scope import: only needed here. Unlike bare threading.Thread,
    # Future.result() hands worker exceptions back to the caller.
    from concurrent.futures import ThreadPoolExecutor

    if nthreads < 1:
        # With zero threads nothing would be computed and ``out`` would be
        # returned untouched.
        raise ValueError("matmul_dtype_nocheck_thread: nthreads must be "
                         ">= 1, got %d" % nthreads)
    # The workers index without bounds checks and with the GIL released, so
    # a shape mismatch must be rejected before any of them start.
    if B.shape[0] != A.shape[1]:
        raise ValueError("matmul_dtype_nocheck_thread: A is %dx%d but B is "
                         "%dx%d; inner dimensions must match"
                         % (A.shape[0], A.shape[1], B.shape[0], B.shape[1]))
    if out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError("matmul_dtype_nocheck_thread: out must be %dx%d, "
                         "got %dx%d"
                         % (A.shape[0], B.shape[1],
                            out.shape[0], out.shape[1]))

    B = np.asfortranarray(B)
    with ThreadPoolExecutor(max_workers=nthreads) as pool:
        futures = [pool.submit(_matmul_slice, tid, nthreads, A, B, out)
                   for tid in range(nthreads)]
        for future in futures:
            future.result()  # re-raises anything the worker raised
    return out
 
@cython.boundscheck(False)
@cython.wraparound(False)
def _matmul_slice(
        Py_ssize_t tid,
        Py_ssize_t nthreads,
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Compute rows ``tid, tid + nthreads, tid + 2*nthreads, ...`` of A @ B.

    Worker for ``matmul_dtype_nocheck_thread``. The array buffers are
    acquired when the function is called (with the GIL held), then the loops
    run inside ``nogil`` so several workers can compute concurrently. Workers
    with distinct ``tid`` values write disjoint rows of ``out``.

    Parameters
    ----------
    tid : index of this worker's first row (must be >= 0).
    nthreads : row stride between this worker's rows (must be >= 1).
    A : (l, m) float64 array.
    B : (m, n) float64 array.
    out : (l, n) float64 array; only this worker's rows are written.

    Raises
    ------
    ValueError
        If ``tid < 0``, ``nthreads < 1`` or the shapes are not compatible.
    """
    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t nrows = A.shape[0]
    cdef Py_ssize_t ninner = A.shape[1]
    cdef Py_ssize_t ncols = B.shape[1]
    cdef dtype_t s

    # A zero stride would loop forever and a negative tid would index before
    # the start of the arrays (wraparound is off).
    if nthreads < 1 or tid < 0:
        raise ValueError("_matmul_slice: need tid >= 0 and nthreads >= 1, "
                         "got tid=%d, nthreads=%d" % (tid, nthreads))
    if (B.shape[0] != ninner
            or out.shape[0] != nrows or out.shape[1] != ncols):
        raise ValueError("_matmul_slice: incompatible shapes")

    with nogil:
        # A while loop over a C-typed index compiles to a plain C loop under
        # nogil; it replaces the deprecated ``for i from ... by ...`` syntax.
        i = tid
        while i < nrows:
            for j in range(ncols):
                s = 0
                for k in range(ninner):
                    s += A[i, k] * B[k, j]
                out[i, j] = s
            i += nthreads
 
View test_matmul.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
import pyximport; pyximport.install() # $ pip install cython
 
from timeit import default_timer as timer
 
import numpy as np
 
import cymatmul
 
# pure python
def matmul(A, B, out=None):
    """Multiply 2-D arrays ``A`` (l x m) and ``B`` (m x n) in pure Python.

    Reference triple loop used as the slowest baseline in the benchmark.

    Parameters
    ----------
    A, B : 2-D arrays supporting ``.shape`` and ``[i, j]`` indexing.
    out : optional preallocated array of shape ``(A.shape[0], B.shape[1])``;
        when omitted a zeroed array of dtype ``np.result_type(A, B)`` is
        allocated (so an int matrix times a float matrix is not truncated).

    Returns
    -------
    out, filled with ``A @ B``.

    Raises
    ------
    ValueError
        If the inner dimensions differ or ``out`` has the wrong shape.
    """
    # Without this check a B with more rows than A has columns would be
    # silently truncated and a wrong product returned.
    if A.shape[1] != B.shape[0]:
        raise ValueError("matmul: inner dimensions differ (%d != %d)"
                         % (A.shape[1], B.shape[0]))
    expected = (A.shape[0], B.shape[1])
    if out is None:
        out = np.zeros(expected, dtype=np.result_type(A, B))
    elif tuple(out.shape) != expected:
        raise ValueError("matmul: out must have shape %r, got %r"
                         % (expected, tuple(out.shape)))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
 
def np_dot(a, b, c):
    """Adapt ``np.dot`` to the ``fun(a, b, c)`` signature used by test_fun().

    The product is written into the preallocated ``c`` whenever NumPy accepts
    it as an output buffer, so the timing compares like with like against
    the kernels that also fill a caller-supplied array. ``np.dot`` only
    accepts ``out`` when it is C-contiguous with exactly the result's shape
    and dtype; otherwise a freshly allocated result is returned instead.
    """
    try:
        return np.dot(a, b, out=c)
    except (ValueError, TypeError):
        # ``c`` is not an acceptable output buffer (wrong shape, dtype or
        # layout). Let NumPy allocate; a genuine a/b mismatch re-raises here.
        return np.dot(a, b)
 
 
def test_fun(fun, size=(1000, 1000, 1000), nrepeat=3, ntimes=1):
    """Check ``fun`` against ``np.dot``, then return its best time per call.

    ``fun`` is called as ``fun(a, b, c)`` where ``a`` is an (l, m) matrix,
    ``b`` an (m, n) Fortran-ordered matrix and ``c`` a zeroed (l, n) output
    buffer; it must return the product ``a @ b``.

    Parameters
    ----------
    fun : callable ``(a, b, c) -> ndarray``.
    size : ``(l, m, n)`` dimensions used for the timing run.
    nrepeat : number of timing rounds; the fastest round is reported.
    ntimes : calls per round; each round's time is averaged over them.

    Returns
    -------
    float
        Best average wall-clock seconds per call.

    Raises
    ------
    ValueError
        If ``nrepeat`` or ``ntimes`` is smaller than 1.
    AssertionError
        If ``fun`` disagrees with ``np.dot`` on a 128x128x128 problem.
    """
    if nrepeat < 1 or ntimes < 1:
        raise ValueError("nrepeat and ntimes must be >= 1")

    def make_mat(n, m):
        return np.random.uniform(size=(n, m))

    def init_arrays(rows, inner, cols):
        a = make_mat(rows, inner)
        # NOTE: 'fortran' order provides the greatest benefit: 4-8 times faster
        b = np.asfortranarray(make_mat(inner, cols))
        c = np.zeros((a.shape[0], b.shape[1]), dtype=a.dtype)
        return a, b, c

    # Correctness first. The second call reuses the now-filled ``c``, so it
    # also catches implementations that accumulate into it instead of
    # overwriting it. Explicit raise (not ``assert``) survives ``python -O``.
    a, b, c = init_arrays(128, 128, 128)
    expected = np.dot(a, b)
    for _ in range(2):
        if not np.allclose(fun(a, b, c), expected):
            raise AssertionError(fun.__name__)

    # Timing: best of ``nrepeat`` rounds, each averaged over ``ntimes`` calls.
    a, b, c = init_arrays(*size)
    best = float("inf")
    for _ in range(nrepeat):
        start = timer()
        for _ in range(ntimes):
            fun(a, b, c)
        best = min(best, (timer() - start) / ntimes)
    return best


# Despite its name this is a benchmark helper, not a pytest test: stop pytest
# from collecting it (it would fail looking for a ``fun`` fixture).
test_fun.__test__ = False


def main():
    """Run every benchmark group and print one timing line per function.

    Each group is ``(title, n, functions)``: ``title`` is printed first when
    it is not ``None``, then each function is checked and timed by
    ``test_fun`` on ``n x n`` matrices.
    """
    groups = [
        ("Pure python:", 128, [matmul]),
        ("Cython enhanced versions:", 128, [
            cymatmul.matmul,
            cymatmul.matmul_dtype,
            cymatmul.matmul_dtype_nocheck,
            np_dot,
            cymatmul.matmul_dtype_nocheck_mem,
        ]),
        # Larger problem for the fast kernels only; printed without a header.
        (None, 1024, [
            cymatmul.matmul_dtype_nocheck,
            cymatmul.matmul_dtype_nocheck_mem,
            cymatmul.matmul_dtype_nocheck_fort,
            cymatmul.matmul_dtype_nocheck_thread,
        ]),
    ]
    for title, n, funs in groups:
        if title is not None:
            print(title)
        for f in funs:
            print("%d: %s took us %.2g seconds" % (
                n, f.__name__, test_fun(f, (n, n, n))))


if __name__ == '__main__':
    main()
# Pure python:
# 128: matmul took us 1.9 seconds
# Cython enchanced versions:
# 128: matmul took us 1.6 seconds
# 128: matmul_dtype took us 0.0061 seconds
# 128: matmul_dtype_nocheck took us 0.0025 seconds
# 128: np_dot took us 0.0011 seconds
# 128: matmul_dtype_nocheck_mem took us 0.0024 seconds
# 1024: matmul_dtype_nocheck took us 1.3 seconds
# 1024: matmul_dtype_nocheck_mem took us 1.3 seconds
# 1024: matmul_dtype_nocheck_fort took us 1.3 seconds
# 1024: matmul_dtype_nocheck_thread took us 0.48 seconds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Something went wrong with that request. Please try again.