Created January 22, 2011 04:47
-
-
Save zed/790865 to your computer and use it in GitHub Desktop.
Usage: pip install cython && python test_matmul.py (pyximport compiles the cymatmul Cython module on import, then the benchmarks run).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
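# Cython matrix-multiplication kernels, from a plain untyped loop up to typed,
# bounds-check-free and multithreaded variants.
# See http://conference.scipy.org/proceedings/SciPy2009/paper_2/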
import threading | |
cimport cython | |
import numpy as np | |
cimport numpy as np | |
# Element type of every typed buffer in the kernels below: numpy float64
# (a C double), so only float64 arrays are accepted by those kernels.
ctypedef np.float64_t dtype_t
def matmul(A, B, out=None):
    """Return A @ B computed with untyped nested loops.

    Baseline kernel: no C type declarations are given, so the loops work on
    generic Python objects exactly as the pure-Python version would.

    A is (n, m) and B is (m, p).  The product is written into `out` (a new
    zero array with A's dtype when `out` is None) and `out` is returned.

    Raises:
        ValueError: if A.shape[1] != B.shape[0], or `out` is given but is
            not (n, p).  Without this check a B with extra rows silently
            produces truncated sums.
    """
    if A.shape[1] != B.shape[0]:
        raise ValueError("matmul: inner dimensions differ: %r vs %r"
                         % (A.shape, B.shape))
    expected = (A.shape[0], B.shape[1])
    if out is None:
        out = np.zeros(expected, dtype=A.dtype)
    elif tuple(out.shape) != expected:
        raise ValueError("matmul: out has shape %r, expected %r"
                         % (tuple(out.shape), expected))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
def matmul_dtype(np.ndarray[dtype_t, ndim=2] A not None,
                 np.ndarray[dtype_t, ndim=2] B not None,
                 np.ndarray[dtype_t, ndim=2] out=None):
    """Return A @ B using typed float64 buffers and C-typed loop indices.

    No compiler directives are applied, so Cython's default bounds checking
    and negative-index wraparound stay enabled for the buffer accesses.

    A is (n, m) and B is (m, p).  The product is written into `out` (a new
    zero array when `out` is None) and `out` is returned.

    Raises:
        ValueError: if A.shape[1] != B.shape[0], or `out` is given but is
            not (n, p).  Without this check a B with extra rows silently
            produces truncated sums.
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    if A.shape[1] != B.shape[0]:
        raise ValueError("matmul_dtype: inner dimensions differ (%d != %d)"
                         % (A.shape[1], B.shape[0]))
    if out is None:
        out = np.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
    elif out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError("matmul_dtype: out must have shape (%d, %d)"
                         % (A.shape[0], B.shape[1]))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck(np.ndarray[dtype_t, ndim=2] A not None,
                         np.ndarray[dtype_t, ndim=2] B not None,
                         np.ndarray[dtype_t, ndim=2] out not None):
    """Write A @ B into `out` and return it, with bounds checks disabled.

    Same loops as matmul_dtype, but boundscheck/wraparound are turned off.
    An out-of-range index would therefore read or write outside the
    buffers instead of raising.  To keep that safe, None arrays are
    rejected (TypeError from `not None`) and shapes are validated before
    any indexing.

    Raises:
        ValueError: if A.shape[1] != B.shape[0] or `out` is not
            (A.shape[0], B.shape[1]).
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "matmul_dtype_nocheck: inner dimensions differ (%d != %d)"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError("matmul_dtype_nocheck: out must have shape (%d, %d)"
                         % (A.shape[0], B.shape[1]))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_mem(
        np.ndarray[dtype_t, ndim=2, mode="c"] A not None,
        np.ndarray[dtype_t, ndim=2, mode="fortran"] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Write A @ B into `out` and return it; A must be C-ordered, B Fortran.

    With these layouts the inner k-loop walks row i of A and column j of B
    contiguously in memory, and the buffer `mode`s let Cython index them
    knowing the contiguity.  Arrays with a different layout are rejected
    when the arguments are unpacked.

    Bounds checks are disabled, so None arrays are rejected (`not None`)
    and shapes are validated before any indexing.

    Raises:
        ValueError: if A.shape[1] != B.shape[0] or `out` is not
            (A.shape[0], B.shape[1]).
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "matmul_dtype_nocheck_mem: inner dimensions differ (%d != %d)"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError(
            "matmul_dtype_nocheck_mem: out must have shape (%d, %d)"
            % (A.shape[0], B.shape[1]))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_fort(
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Write A @ B into `out` and return it, first making B Fortran-ordered.

    Any input layout is accepted.  A B that is not Fortran-ordered is
    converted with np.asfortranarray, so the inner k-loop reads column j of
    B contiguously.

    Bounds checks are disabled, so None arrays are rejected (`not None`)
    and shapes are validated before any indexing.

    Raises:
        ValueError: if A.shape[1] != B.shape[0] or `out` is not
            (A.shape[0], B.shape[1]).
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "matmul_dtype_nocheck_fort: inner dimensions differ (%d != %d)"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError(
            "matmul_dtype_nocheck_fort: out must have shape (%d, %d)"
            % (A.shape[0], B.shape[1]))
    if not np.isfortran(B):
        # The reordered copy has the same shape, so the checks above hold.
        B = np.asfortranarray(B)
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def matmul_dtype_nocheck_thread(
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None,
        Py_ssize_t nthreads=8):
    """Write A @ B into `out` using `nthreads` threads and return `out`.

    Each thread runs _matmul_slice, which releases the GIL for its loops.
    Worker `tid` computes rows tid, tid + nthreads, ..., so threads write
    disjoint rows of `out`.  B is made Fortran-ordered first, as in
    matmul_dtype_nocheck_fort.

    Raises:
        ValueError: if nthreads < 1 (previously this silently returned
            `out` untouched), if A.shape[1] != B.shape[0], or if `out` is
            not (A.shape[0], B.shape[1]).  The workers run without bounds
            checks, so shapes must be validated before they start.
    """
    cdef Py_ssize_t tid
    if nthreads < 1:
        raise ValueError("matmul_dtype_nocheck_thread: nthreads must be >= 1,"
                         " got %d" % nthreads)
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            "matmul_dtype_nocheck_thread: inner dimensions differ (%d != %d)"
            % (A.shape[1], B.shape[0]))
    if out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError(
            "matmul_dtype_nocheck_thread: out must have shape (%d, %d)"
            % (A.shape[0], B.shape[1]))
    if not np.isfortran(B):
        B = np.asfortranarray(B)
    # Explicit loop (rather than a comprehension) so `tid` is the C-typed
    # variable declared above under any Cython language level.
    threads = []
    for tid in range(nthreads):
        threads.append(threading.Thread(target=_matmul_slice,
                                        args=(tid, nthreads, A, B, out)))
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out
@cython.boundscheck(False)
@cython.wraparound(False)
def _matmul_slice(
        Py_ssize_t tid,
        Py_ssize_t nthreads,
        np.ndarray[dtype_t, ndim=2] A not None,
        np.ndarray[dtype_t, ndim=2] B not None,
        np.ndarray[dtype_t, ndim=2] out not None):
    """Compute rows tid, tid + nthreads, tid + 2*nthreads, ... of A @ B.

    Thread worker for matmul_dtype_nocheck_thread.  The results go into the
    matching rows of `out`.  The loops run inside `with nogil`, so several
    workers can execute them concurrently.

    Bounds checks are off, so the arguments are validated first:
      * nthreads <= 0 would give a zero or negative loop step (an endless
        or out-of-bounds loop).
      * a negative tid would index before the start of the buffers.
      * mismatched shapes would read or write past the buffers.

    Raises:
        ValueError: if nthreads < 1, tid < 0, A.shape[1] != B.shape[0], or
            `out` is not (A.shape[0], B.shape[1]).
    """
    cdef Py_ssize_t i, j, k
    cdef dtype_t s
    if nthreads < 1 or tid < 0:
        raise ValueError("_matmul_slice: need nthreads >= 1 and tid >= 0,"
                         " got tid=%d, nthreads=%d" % (tid, nthreads))
    if A.shape[1] != B.shape[0]:
        raise ValueError("_matmul_slice: inner dimensions differ (%d != %d)"
                         % (A.shape[1], B.shape[0]))
    if out.shape[0] != A.shape[0] or out.shape[1] != B.shape[1]:
        raise ValueError("_matmul_slice: out must have shape (%d, %d)"
                         % (A.shape[0], B.shape[1]))
    with nogil:
        # `for ... from ... by` gives a strided C loop that can run
        # inside `nogil`.
        for i from tid <= i < A.shape[0] by nthreads:
            for j in range(B.shape[1]):
                s = 0
                for k in range(A.shape[1]):
                    s += A[i, k] * B[k, j]
                out[i, j] = s
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyximport; pyximport.install() # $ pip install cython | |
from timeit import default_timer as timer | |
import numpy as np | |
import cymatmul | |
# pure python
def matmul(A, B, out=None):
    """Return A @ B computed with three nested pure-Python loops.

    Reference kernel for the benchmark.  A is (n, m) and B is (m, p).  The
    product is written into `out` (a new zero array with A's dtype when
    `out` is None) and `out` is returned.

    Raises:
        ValueError: if A.shape[1] != B.shape[0], or `out` is given but is
            not (n, p).  Without this check a B with extra rows silently
            produces truncated sums.
    """
    if A.shape[1] != B.shape[0]:
        raise ValueError("matmul: inner dimensions differ: %r vs %r"
                         % (A.shape, B.shape))
    expected = (A.shape[0], B.shape[1])
    if out is None:
        out = np.zeros(expected, dtype=A.dtype)
    elif tuple(out.shape) != expected:
        raise ValueError("matmul: out has shape %r, expected %r"
                         % (tuple(out.shape), expected))
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            s = 0
            for k in range(A.shape[1]):
                s += A[i, k] * B[k, j]
            out[i, j] = s
    return out
def np_dot(a, b, c=None):
    """Adapt np.dot to the ``fun(a, b, c)`` signature used by test_fun().

    When `c` is an acceptable output buffer for np.dot (right shape and
    dtype, C-contiguous), the product is written into it via ``out=`` and
    `c` is returned.  Like the Cython kernels, the timed call then
    allocates no result array, which keeps the comparison fair.
    Otherwise, or when `c` is None, a freshly allocated result is returned.
    """
    if c is not None:
        try:
            return np.dot(a, b, out=c)
        except ValueError:
            # np.dot rejects an unsuitable `out` with ValueError; retry
            # without it.  A genuine a/b shape mismatch re-raises below.
            pass
    return np.dot(a, b)
def test_fun(fun, size=(1000,) * 3, nrepeat=3, ntimes=1):
    """Check `fun` against np.dot, then time it; return seconds per call.

    `fun(a, b, c)` must return a @ b, where `c` is a preallocated zero
    output array of the right shape.

    Correctness is checked on 128x128 inputs, twice with the same `c`.  The
    second call therefore sees `c` already holding the first result, which
    catches a kernel that accumulates into `c` instead of overwriting it.

    Timing uses inputs of shape (l, m) @ (m, n) with ``(l, m, n) = size``.
    Each of `nrepeat` runs averages `ntimes` calls, and the best average
    (lowest value) is returned.

    Raises:
        AssertionError: with ``fun.__name__`` as the message, if a result
            disagrees with np.dot.  It is raised explicitly so the check
            still runs under ``python -O``.
    """
    def make_mat(n, m):
        return np.random.uniform(size=(n, m))

    def init_arrays(l, m, n):
        a = make_mat(l, m)
        # NOTE: 'fortran' order provides the greatest benefit: 4-8 times faster
        b = np.asfortranarray(make_mat(m, n))
        c = np.zeros((a.shape[0], b.shape[1]), dtype=a.dtype)
        return a, b, c

    # test that the results are correct
    a, b, c = init_arrays(*[128] * 3)
    for _ in range(2):
        if not np.allclose(fun(a, b, c), np.dot(a, b)):
            raise AssertionError(fun.__name__)

    # timeit; `range` rather than the Python-2-only `xrange`
    a, b, c = init_arrays(*size)
    best = float("+inf")
    for _ in range(nrepeat):
        start = timer()
        for _ in range(ntimes):
            fun(a, b, c)
        best = min(best, (timer() - start) / ntimes)
    return best


# This script is named test_matmul.py, so pytest would collect `test_fun` as
# a test and then fail to find a `fun` fixture; mark it as a plain helper.
test_fun.__test__ = False
def _report(n, funcs):
    """Benchmark each function in `funcs` on n x n x n inputs and print the time."""
    for f in funcs:
        print("%d: %s took us %.2g seconds" % (
            n, f.__name__, test_fun(f, [n] * 3)))


def main():
    """Run all benchmarks and print one timing line per kernel.

    Each kernel is checked for correctness by test_fun() before it is timed.
    """
    print("Pure python:")
    _report(128, [matmul])
    print("Cython enchanced versions:")
    _report(128, [
        cymatmul.matmul,
        cymatmul.matmul_dtype,
        cymatmul.matmul_dtype_nocheck,
        np_dot,
        cymatmul.matmul_dtype_nocheck_mem,
    ])
    # Only the no-bounds-check kernels are timed on the larger problem.
    _report(1024, [
        cymatmul.matmul_dtype_nocheck,
        cymatmul.matmul_dtype_nocheck_mem,
        cymatmul.matmul_dtype_nocheck_fort,
        cymatmul.matmul_dtype_nocheck_thread,
    ])
if __name__ == '__main__':
    # Benchmarks run only when this file is executed as a script; importing
    # it (e.g. to reuse test_fun) has no side effects beyond definitions.
    main()
# Sample output (timings are machine-dependent):
# Pure python:
# 128: matmul took us 1.9 seconds
# Cython enchanced versions:
# 128: matmul took us 1.6 seconds
# 128: matmul_dtype took us 0.0061 seconds
# 128: matmul_dtype_nocheck took us 0.0025 seconds
# 128: np_dot took us 0.0011 seconds
# 128: matmul_dtype_nocheck_mem took us 0.0024 seconds
# 1024: matmul_dtype_nocheck took us 1.3 seconds
# 1024: matmul_dtype_nocheck_mem took us 1.3 seconds
# 1024: matmul_dtype_nocheck_fort took us 1.3 seconds
# 1024: matmul_dtype_nocheck_thread took us 0.48 seconds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.