Skip to content

Instantly share code, notes, and snippets.

@thrasibule
Created March 5, 2017 23:55
Show Gist options
  • Save thrasibule/83ca3e8798b38faf4424ecad95de514c to your computer and use it in GitHub Desktop.
import numpy as np
import scipy.linalg.blas as blas
import timeit
import os
import ctypes
from ctypes.util import find_library
# Locate and load the OpenBLAS shared library so its runtime thread-count
# functions can be called directly through ctypes.
_openblas_path = find_library('openblas')
if _openblas_path is None:
    # find_library() returns None when no matching library is found. Passing
    # None on to LoadLibrary would not fail here (on POSIX it loads the running
    # program itself), so the problem would only surface later as a confusing
    # AttributeError on openblas_get_num_threads. Fail loudly instead.
    raise OSError("could not locate the OpenBLAS shared library "
                  "(ctypes.util.find_library('openblas') returned None)")
openblas_lib = ctypes.cdll.LoadLibrary(_openblas_path)
def get_num_threads():
    """Return the thread count reported by OpenBLAS's openblas_get_num_threads()."""
    num_threads = openblas_lib.openblas_get_num_threads()
    return num_threads
def set_num_threads(n):
    """Ask OpenBLAS to use ``n`` threads via openblas_set_num_threads().

    ``n`` is coerced with ``int()`` before it crosses the ctypes boundary.
    Returns None.
    """
    count = int(n)
    openblas_lib.openblas_set_num_threads(count)
# Benchmark fixture: a fixed-seed (N x p) Fortran-ordered random matrix X, plus
# an independent C-contiguous copy of its transpose, XT.
seed = 1234
np.random.seed(seed)
N, p = 10**6, 100
# Drawing a (p, N) C-ordered block and transposing it yields the same values,
# laid out the same way, as reshaping a flat N*p draw into (N, p) with order='F'.
X = np.random.random((p, N)).T
# .copy() gives XT its own buffer, so XT.dot(X) is not seen as X.T @ X
# (the script labels that case "dgemm", versus "dsyrk" for X.T.dot(X)).
XT = X.T.copy()
# Reference value of trace(X.T @ X) (= sum of X**2) for this seed, N and p.
true_value = 33334547.40257686
# Thread count OpenBLAS reported before this script changes it.
old_num_threads = get_num_threads()
def test_numpy():
    """Time-able check of numpy's ``X.T.dot(X)`` (labelled "dsyrk" by this script).

    Raises:
        ValueError: if trace(X.T @ X) is not close to ``true_value``; the message
            includes the computed and expected values.
    """
    value = np.trace(X.T.dot(X))
    if not np.isclose(value, true_value):
        raise ValueError("test_numpy: trace {!r} differs from expected {!r}"
                         .format(value, true_value))
def test_scipy():
    """Time-able check of scipy's BLAS ``dsyrk`` computing X.T @ X directly.

    ``trans=1`` asks dsyrk for alpha * X.T @ X. dsyrk fills only one triangle
    of its result, which does not affect the trace (diagonal only).

    Raises:
        ValueError: if the trace is not close to ``true_value``; the message
            includes the computed and expected values.
    """
    value = np.trace(blas.dsyrk(1., X, trans=1))
    if not np.isclose(value, true_value):
        raise ValueError("test_scipy: trace {!r} differs from expected {!r}"
                         .format(value, true_value))
def test_numpy_dgemm():
    """Time-able check of ``XT.dot(X)``, where XT is a separate copy of X.T.

    Because XT does not share X's buffer, this product is labelled "dgemm"
    by the script rather than "dsyrk".

    Raises:
        ValueError: if the trace is not close to ``true_value``; the message
            includes the computed and expected values.
    """
    value = np.trace(XT.dot(X))
    if not np.isclose(value, true_value):
        raise ValueError("test_numpy_dgemm: trace {!r} differs from expected {!r}"
                         .format(value, true_value))
# Benchmark driver: time each check with OpenBLAS's original thread count, then
# again with a single thread. Each entry is
# (check function, multi-threaded message, single-threaded message); the
# multi-threaded message is formatted with (thread count, seconds), the
# single-threaded one with (seconds,).
_BENCHMARKS = [
    (test_numpy,
     "Multi threaded computation dsyrk with {} threads: {}",
     "Non multi-threaded dsyrk computation:{}"),
    (test_scipy,
     "Multi threaded computation dsyrk scipy with {} threads: {}",
     "Non multi-threaded dsyrk scipy computation:{}"),
    (test_numpy_dgemm,
     "Multi threaded computation dgemm with {} threads: {}",
     "Non multi-threaded dgemm computation:{}"),
]
n_repeats = 5  # timeit repetitions per benchmark

for func, multi_msg, _ in _BENCHMARKS:
    t = timeit.timeit(func, number=n_repeats)
    print(multi_msg.format(old_num_threads, t))

set_num_threads(1)
try:
    for func, _, single_msg in _BENCHMARKS:
        t = timeit.timeit(func, number=n_repeats)
        print(single_msg.format(t))
finally:
    # Restore OpenBLAS's original thread count (also when a check raises), so
    # anything running after this script is not left single-threaded.
    set_num_threads(old_num_threads)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment