Skip to content

Instantly share code, notes, and snippets.

@rth
Created January 20, 2017 16:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rth/2f904eaa165631cd0dba53640f47f652 to your computer and use it in GitHub Desktop.
Save rth/2f904eaa165631cd0dba53640f47f652 to your computer and use it in GitHub Desktop.
A benchmark of parallel pairwise_distances (default joblib backend vs. a memory-mapped variant) for scikit-learn issue https://github.com/scikit-learn/scikit-learn/issues/8216
from sklearn.metrics import pairwise_distances
from sklearn.externals.joblib import Parallel, delayed, dump, load
import numpy as np
import tempfile
import os
from sklearn.utils import gen_even_slices
import shutil
# Number of features (columns) in the randomly generated benchmark matrices.
n_dim = 1000
# Seed NumPy's global RNG so every run generates the same data.
np.random.seed(99999)
# Point TMPDIR at /dev/shm/ so that tempfile.mkdtemp() creates its scratch
# folder there -- presumably to keep the memmap files in shared memory
# rather than on disk (Linux-specific path; confirm on other platforms).
os.environ.update(TMPDIR='/dev/shm/')
def _mmap_pairwise(X, Y, sl, Z, metric, squared=True):
    """Write the distances between ``X[sl]`` and ``Y`` into ``Z[sl]``.

    Worker task for ``mmap_pairwise_distances``: it handles one slice of
    rows and stores the result in place, so nothing is returned.

    Parameters
    ----------
    X, Y : array-like
        Input sample matrices; only the rows ``X[sl]`` are used from ``X``.
    sl : slice
        Row range of ``X`` (and of ``Z``) handled by this task.
    Z : array-like
        Output buffer that receives the computed distances at ``Z[sl]``.
    metric : str or callable
        Passed positionally to ``pairwise_distances``.
    squared : bool, default=True
        Forwarded as a keyword argument to ``pairwise_distances``.
    """
    rows = X[sl]
    distances = pairwise_distances(rows, Y, metric, squared=squared)
    Z[sl] = distances
def mmap_pairwise_distances(X, Y, metric, n_jobs=1, squared=True):
    """Compute pairwise distances in parallel through memory-mapped files.

    ``X`` and ``Y`` are dumped to a temporary folder and reloaded as
    read-only memory maps, and the result is written into a memory-mapped
    output array. The rows of ``X`` are split into ``n_jobs`` even slices
    and each slice is processed by ``_mmap_pairwise`` in its own joblib task.

    Parameters
    ----------
    X : ndarray of shape (n_samples_X, n_features)
    Y : ndarray of shape (n_samples_Y, n_features)
    metric : str or callable
        Forwarded to ``pairwise_distances``.
    n_jobs : int, default=1
        Number of parallel tasks and of row slices; must be >= 1.
    squared : bool, default=True
        Forwarded as a keyword argument to ``pairwise_distances``.

    Returns
    -------
    Z : numpy.memmap of shape (n_samples_X, n_samples_Y)
        The distance matrix. Its backing file has already been removed when
        this function returns, so the data stays reachable only through the
        live mapping (relies on POSIX unlink semantics).

    Raises
    ------
    ValueError
        If ``n_jobs`` is smaller than 1.
    """
    # Guard against values that cannot be turned into row slices; some
    # scikit-learn versions may otherwise leave the output unfilled.
    if n_jobs < 1:
        raise ValueError(
            'n_jobs must be a positive integer, got {!r}'.format(n_jobs))

    n_samples = X.shape[0]
    folder = tempfile.mkdtemp()
    try:
        X_in_fname = os.path.join(folder, 'X_in')
        Y_in_fname = os.path.join(folder, 'Y_in')
        Z_fname = os.path.join(folder, 'Z_out')

        # Round-trip the inputs through disk so workers share one read-only
        # mapping instead of each receiving a pickled copy.
        dump(X, X_in_fname)
        dump(Y, Y_in_fname)
        X = load(X_in_fname, mmap_mode='r')
        Y = load(Y_in_fname, mmap_mode='r')

        # Distances are floating point: keep X's dtype when it is already a
        # float type, otherwise fall back to float64 so that integer inputs
        # do not silently truncate the results on assignment.
        if np.issubdtype(X.dtype, np.floating):
            out_dtype = X.dtype
        else:
            out_dtype = np.float64

        Z = np.memmap(Z_fname, dtype=out_dtype,
                      shape=(X.shape[0], Y.shape[0]),
                      mode='w+')

        Parallel(n_jobs=n_jobs)(
            delayed(_mmap_pairwise)(X, Y, sl, Z, metric, squared=squared)
            for sl in gen_even_slices(n_samples, n_jobs))
    finally:
        # Always remove the scratch folder, even when a worker fails.
        shutil.rmtree(folder)
    return Z
n_train, n_test = (100000, 1000)
print('\n# n_train={}, n_test={}, n_dim={}\n'.format(
n_train, n_test, n_dim))
X_train = np.random.rand(n_train, n_dim)
X_test = np.random.rand(n_test, n_dim)
for func in [pairwise_distances, mmap_pairwise_distances]:
print('## Backend : ', func.__name__)
for n_jobs in [1, 2]:
print('n_jobs=', n_jobs, ' => ', end='')
%timeit func(X_train, X_test, 'euclidean', n_jobs=n_jobs, squared=False)
# On an Intel(R) Xeon(R) CPU E5-2676 v3 @ 2.40GHz, with the n_jobs list extended to [1, 2, 4, 8], this produces:
# # n_train=100000, n_test=1000, n_dim=1000
#
# ## Backend : pairwise_distances
# n_jobs= 1 => 1 loop, best of 3: 2.63 s per loop
# n_jobs= 2 => 1 loop, best of 3: 8.43 s per loop
# n_jobs= 4 => 1 loop, best of 3: 9.15 s per loop
# n_jobs= 8 => 1 loop, best of 3: 14.5 s per loop
# ## Backend : mmap_pairwise_distances
# n_jobs= 1 => 1 loop, best of 3: 3.7 s per loop
# n_jobs= 2 => 1 loop, best of 3: 3.43 s per loop
# n_jobs= 4 => 1 loop, best of 3: 2.9 s per loop
# n_jobs= 8 => 1 loop, best of 3: 2.83 s per loop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment