Skip to content

@zed /cymatmul.pyx
Created

Embed URL

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
pip install cython && python test_matmul.py
# see http://conference.scipy.org/proceedings/SciPy2009/paper_2/
import threading
cimport cython
import numpy as np
cimport numpy as np
ctypedef np.float64_t dtype_t
def matmul(A, B, out=None):
    """Multiply two 2-D arrays with a naive triple loop (untyped baseline).

    No C types are declared here, so every element access goes through
    ordinary Python object indexing.

    Parameters
    ----------
    A : array exposing ``.shape``, ``.dtype`` and ``A[i, k]`` indexing
        Left operand.
    B : array exposing ``.shape`` and ``B[k, j]`` indexing
        Right operand; its row count must equal the column count of ``A``.
    out : array, optional
        Destination for the product.  When ``None`` a zero-filled array of
        shape ``(A.shape[0], B.shape[1])`` and dtype ``A.dtype`` is created.

    Returns
    -------
    The array holding the product (``out`` itself when it was supplied).

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ.
    """
    n_rows, n_inner, n_cols = A.shape[0], A.shape[1], B.shape[1]
    # Without this check a B with extra rows yields a silently wrong product
    # and a B with too few rows fails with an IndexError deep in the loop.
    if B.shape[0] != n_inner:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (n_inner, B.shape[0]))
    if out is None:
        out = np.zeros((n_rows, n_cols), dtype=A.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            s = 0
            for k in range(n_inner):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
def matmul_dtype(np.ndarray[dtype_t, ndim=2] A not None,
                 np.ndarray[dtype_t, ndim=2] B not None,
                 np.ndarray[dtype_t, ndim=2] out=None):
    """Multiply two 2-D ``dtype_t`` arrays using typed buffers and C loop indices.

    Bounds checking and negative-index wraparound are left at their
    defaults for this variant.

    Parameters
    ----------
    A, B : 2-D ndarrays of ``dtype_t``
        Operands; ``None`` is rejected by the ``not None`` clauses.
    out : 2-D ndarray of ``dtype_t``, optional
        Destination for the product.  When ``None`` a zero-filled array of
        shape ``(A.shape[0], B.shape[1])`` is allocated.

    Returns
    -------
    The array holding the product (``out`` itself when it was supplied).

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ.
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    # A B with extra rows would otherwise give a silently wrong product.
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (A.shape[1], B.shape[0]))
    if out is None:
        out = np.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck(np.ndarray[dtype_t, ndim=2] A not None,
                         np.ndarray[dtype_t, ndim=2] B not None,
                         np.ndarray[dtype_t, ndim=2] out not None):
    """Multiply two 2-D ``dtype_t`` arrays with per-element checks disabled.

    The decorators turn off bounds checking and negative-index wraparound,
    so the shapes are validated once up front instead.

    Parameters
    ----------
    A, B : 2-D ndarrays of ``dtype_t``
        Operands.
    out : 2-D ndarray of ``dtype_t``
        Destination; must be at least ``(A.shape[0], B.shape[1])`` in size.

    Returns
    -------
    ``out``, with its top-left ``A.shape[0] x B.shape[1]`` block holding the
    product.

    Raises
    ------
    ValueError
        If the inner dimensions differ or ``out`` is too small.
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    # With boundscheck off, a shape mismatch would read or write outside
    # the buffers instead of raising, so it has to be caught here.
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] < A.shape[0] or out.shape[1] < B.shape[1]:
        raise ValueError(
            "out is too small: need at least %d x %d"
            % (A.shape[0], B.shape[1]))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_mem(
        np.ndarray[dtype_t, ndim=2, mode="c"] A not None,
        np.ndarray[dtype_t, ndim=2, mode="fortran"] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Multiply a C-ordered ``A`` by a Fortran-ordered ``B`` without checks.

    The buffer declarations require ``A`` in ``mode="c"`` and ``B`` in
    ``mode="fortran"``.  Bounds checking and wraparound are disabled by the
    decorators, so the shapes are validated once up front instead.

    Parameters
    ----------
    A : 2-D ndarray of ``dtype_t``, declared ``mode="c"``
    B : 2-D ndarray of ``dtype_t``, declared ``mode="fortran"``
    out : 2-D ndarray of ``dtype_t``
        Destination; must be at least ``(A.shape[0], B.shape[1])`` in size.

    Returns
    -------
    ``out``, with its top-left ``A.shape[0] x B.shape[1]`` block holding the
    product.

    Raises
    ------
    ValueError
        If the inner dimensions differ or ``out`` is too small.
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    # With boundscheck off, a shape mismatch would read or write outside
    # the buffers instead of raising, so it has to be caught here.
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] < A.shape[0] or out.shape[1] < B.shape[1]:
        raise ValueError(
            "out is too small: need at least %d x %d"
            % (A.shape[0], B.shape[1]))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_fort(
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Multiply two 2-D ``dtype_t`` arrays, converting ``B`` to Fortran order.

    ``B`` is passed through ``np.asfortranarray`` when ``np.isfortran``
    reports it is not already Fortran-ordered.  Bounds checking and
    wraparound are disabled by the decorators, so the shapes are validated
    once up front instead.

    Parameters
    ----------
    A, B : 2-D ndarrays of ``dtype_t``
        Operands.
    out : 2-D ndarray of ``dtype_t``
        Destination; must be at least ``(A.shape[0], B.shape[1])`` in size.

    Returns
    -------
    ``out``, with its top-left ``A.shape[0] x B.shape[1]`` block holding the
    product.

    Raises
    ------
    ValueError
        If the inner dimensions differ or ``out`` is too small.
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    # With boundscheck off, a shape mismatch would read or write outside
    # the buffers instead of raising, so it has to be caught here.
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] < A.shape[0] or out.shape[1] < B.shape[1]:
        raise ValueError(
            "out is too small: need at least %d x %d"
            % (A.shape[0], B.shape[1]))
    if not np.isfortran(B):
        B = np.asfortranarray(B)
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_thread(
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None,
        Py_ssize_t nthreads=8):
    """Multiply two 2-D ``dtype_t`` arrays using ``nthreads`` Python threads.

    One ``threading.Thread`` per thread id is started with ``_matmul_slice``
    as its target, and all of them are joined before returning.  ``B`` is
    converted with ``np.asfortranarray`` when ``np.isfortran`` reports it is
    not already Fortran-ordered.

    Parameters
    ----------
    A, B : 2-D ndarrays of ``dtype_t``
        Operands.
    out : 2-D ndarray of ``dtype_t``
        Destination; must be at least ``(A.shape[0], B.shape[1])`` in size.
    nthreads : int, default 8
        Number of worker threads to start; must be at least 1.

    Returns
    -------
    ``out``.

    Raises
    ------
    ValueError
        If ``nthreads < 1``, the inner dimensions differ, or ``out`` is too
        small.
    """
    cdef Py_ssize_t tid
    # With nthreads <= 0 no worker would be started and `out` would be
    # returned untouched, looking like a valid (but wrong) result.
    if nthreads < 1:
        raise ValueError("nthreads must be >= 1, got %d" % nthreads)
    # Validate before starting workers: an exception raised inside a worker
    # thread would not propagate to this caller.
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] < A.shape[0] or out.shape[1] < B.shape[1]:
        raise ValueError(
            "out is too small: need at least %d x %d"
            % (A.shape[0], B.shape[1]))
    if not np.isfortran(B):
        B = np.asfortranarray(B)
    ts = [threading.Thread(target=_matmul_slice,
                           args=(tid, nthreads, A, B, out))
          for tid in range(nthreads)]
    for t in ts:
        t.start()
    for t in ts:
        t.join()
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def _matmul_slice(
        Py_ssize_t tid,
        Py_ssize_t nthreads,
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Compute rows ``tid, tid + nthreads, tid + 2*nthreads, ...`` of ``A * B``.

    The loops run inside a ``with nogil`` block.  Bounds checking and
    wraparound are disabled by the decorators, so the arguments are
    validated before the GIL is released.

    Parameters
    ----------
    tid : int
        Index of the first row handled by this call; must be >= 0.
    nthreads : int
        Row stride; must be >= 1.
    A, B : 2-D ndarrays of ``dtype_t``
        Operands.
    out : 2-D ndarray of ``dtype_t``
        Destination; must be at least ``(A.shape[0], B.shape[1])`` in size.

    Returns
    -------
    None; the selected rows of ``out`` are written in place.

    Raises
    ------
    ValueError
        If ``tid < 0``, ``nthreads < 1``, the inner dimensions differ, or
        ``out`` is too small.
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    # A stride of 0 would make the row loop below never terminate.
    if nthreads < 1:
        raise ValueError("nthreads must be >= 1, got %d" % nthreads)
    # wraparound is off, so a negative row index would address memory
    # before the start of the buffers.
    if tid < 0:
        raise ValueError("tid must be >= 0, got %d" % tid)
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] < A.shape[0] or out.shape[1] < B.shape[1]:
        raise ValueError(
            "out is too small: need at least %d x %d"
            % (A.shape[0], B.shape[1]))
    with nogil:  # use special for-loop syntax due to nogil
        for i from tid <= i < A.shape[0] by nthreads:
            for j in range(B.shape[1]):
                s = 0
                for k in range(A.shape[1]):
                    s += A[i, k] * B[k, j]
                out[i, j] = s
import pyximport; pyximport.install() # $ pip install cython
from timeit import default_timer as timer
import numpy as np
import cymatmul
# pure python
def matmul(A, B, out=None):
    """Multiply two 2-D arrays with a naive triple loop (pure-Python baseline).

    Parameters
    ----------
    A : array exposing ``.shape``, ``.dtype`` and ``A[i, k]`` indexing
        Left operand.
    B : array exposing ``.shape`` and ``B[k, j]`` indexing
        Right operand; its row count must equal the column count of ``A``.
    out : array, optional
        Destination for the product.  When ``None`` a zero-filled array of
        shape ``(A.shape[0], B.shape[1])`` and dtype ``A.dtype`` is created.

    Returns
    -------
    The array holding the product (``out`` itself when it was supplied).

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ.
    """
    n_rows, n_inner, n_cols = A.shape[0], A.shape[1], B.shape[1]
    # Without this check a B with extra rows yields a silently wrong product
    # and a B with too few rows fails with an IndexError deep in the loop.
    if B.shape[0] != n_inner:
        raise ValueError(
            "shapes not aligned: A has %d columns but B has %d rows"
            % (n_inner, B.shape[0]))
    if out is None:
        out = np.zeros((n_rows, n_cols), dtype=A.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            s = 0
            for k in range(n_inner):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
def np_dot(a, b, c):
    """Adapt ``np.dot`` to the ``fun(a, b, c)`` signature used by ``test_fun``.

    The product is written into ``c`` through ``np.dot``'s ``out`` argument,
    so no result array is allocated per call and the caller's buffer is
    filled and returned.

    Parameters
    ----------
    a, b : arrays accepted by ``np.dot``
    c : ndarray or None
        Destination buffer.  ``None`` makes ``np.dot`` allocate the result.

    Returns
    -------
    ``c`` holding ``a . b``, or a freshly allocated array when ``c`` is None.
    """
    if c is None:
        return np.dot(a, b)
    # NOTE: np.dot places its own requirements on `out` (see the numpy.dot
    # reference) and raises ValueError when `c` does not meet them.
    return np.dot(a, b, out=c)
def test_fun(fun, size=(1000, 1000, 1000), nrepeat=3, ntimes=1):
    """Check ``fun`` against ``np.dot`` and return its best time per call.

    Parameters
    ----------
    fun : callable
        Called as ``fun(a, b, c)``; must return the product of ``a`` and ``b``.
    size : sequence of three ints
        ``(l, m, n)``: ``a`` is ``l x m``, ``b`` is ``m x n`` and ``c`` is
        ``l x n`` for the timed runs.
    nrepeat : int
        Number of timing rounds; the minimum over the rounds is returned.
    ntimes : int
        Calls per round; each round's elapsed time is divided by this.

    Returns
    -------
    float
        Best observed seconds per call (``+inf`` when ``nrepeat`` is 0).

    Raises
    ------
    AssertionError
        If ``fun`` disagrees with ``np.dot`` on a 128x128x128 problem; the
        message is ``fun.__name__``.
    """
    def init_arrays(l, m, n):
        a = np.random.uniform(size=(l, m))
        # NOTE: 'fortran' order provides the greatest benefit: 4-8 times faster
        b = np.asfortranarray(np.random.uniform(size=(m, n)))
        c = np.zeros((a.shape[0], b.shape[1]), dtype=a.dtype)
        return a, b, c

    # Test that the results are correct.  The check runs twice so the second
    # call sees a `c` that already holds data from the first.  An explicit
    # raise (rather than `assert`) keeps the check active under `python -O`.
    a, b, c = init_arrays(128, 128, 128)
    expected = np.dot(a, b)
    for _ in range(2):
        if not np.allclose(fun(a, b, c), expected):
            raise AssertionError(fun.__name__)

    # timeit
    a, b, c = init_arrays(*size)
    best = float("+inf")
    # `range` instead of `xrange` so the script also runs on Python 3.
    for _ in range(nrepeat):
        start = timer()
        for _ in range(ntimes):
            fun(a, b, c)
        best = min(best, (timer() - start) / ntimes)
    return best


# The `test_` prefix would otherwise make pytest/nose collect this benchmark
# helper as a test case and fail on its required `fun` argument.
test_fun.__test__ = False
def main():
    """Benchmark every matmul variant and print one timing line per function.

    The pure-Python version and the first group of variants are timed on
    128x128x128 problems; the second group of Cython variants is timed on
    1024x1024x1024 problems.
    """
    line = "%d: %s took us %.2g seconds"

    def report(n, funcs):
        # Time each function on an n x n x n problem and print the result.
        for func in funcs:
            print(line % (n, func.__name__, test_fun(func, [n] * 3)))

    print("Pure python:")
    report(128, [matmul])

    print("Cython enchanced versions:")
    report(128, [
        cymatmul.matmul,
        cymatmul.matmul_dtype,
        cymatmul.matmul_dtype_nocheck,
        np_dot,
        cymatmul.matmul_dtype_nocheck_mem,
    ])
    report(1024, [
        cymatmul.matmul_dtype_nocheck,
        cymatmul.matmul_dtype_nocheck_mem,
        cymatmul.matmul_dtype_nocheck_fort,
        cymatmul.matmul_dtype_nocheck_thread,
    ])


# Run the benchmarks only when executed as a script, not on import.
if __name__ == '__main__':
    main()
# Pure python:
# 128: matmul took us 1.9 seconds
# Cython enchanced versions:
# 128: matmul took us 1.6 seconds
# 128: matmul_dtype took us 0.0061 seconds
# 128: matmul_dtype_nocheck took us 0.0025 seconds
# 128: np_dot took us 0.0011 seconds
# 128: matmul_dtype_nocheck_mem took us 0.0024 seconds
# 1024: matmul_dtype_nocheck took us 1.3 seconds
# 1024: matmul_dtype_nocheck_mem took us 1.3 seconds
# 1024: matmul_dtype_nocheck_fort took us 1.3 seconds
# 1024: matmul_dtype_nocheck_thread took us 0.48 seconds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Something went wrong with that request. Please try again.