Created February 7, 2018 19:45
Save TomDLT/4863afaf2554903dd13f476c141d8a38 to your computer and use it in GitHub Desktop.
Benchmark for https://github.com/scikit-learn/scikit-learn/pull/10482
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import wraps | |
import time | |
import numpy as np | |
import scipy.sparse as sp | |
from sklearn.utils.testing import assert_array_equal | |
class ContextDecorator(object):
    """Mixin that lets a context manager also be used as a function decorator.

    ``instance(f)`` records ``f`` as ``self.f`` and returns a wrapper that runs
    every call of ``f`` inside ``with self:``.
    """

    # Function most recently decorated by this instance; stays None when the
    # instance is only used directly as a ``with`` context manager.
    f = None

    def __call__(self, f):
        self.f = f

        @wraps(f)
        def decorated(*args, **kwds):
            with self:
                return f(*args, **kwds)

        return decorated


class TimeIt(ContextDecorator):
    """Print the elapsed time of a decorated function or of a ``with`` block.

    Each exit prints one line ``"<name>:<padding> <seconds> sec"``, where the
    padding brings ``<name>`` plus padding to 25 characters so lines align.

    Parameters
    ----------
    name : str or None
        Label to print. Defaults to the decorated function's ``__name__``, or
        to the class name when used as a plain ``with`` block.
    """

    # High-resolution monotonic clock when available (Python 3.3+); falls back
    # to wall-clock time.time() on older interpreters.
    _clock = staticmethod(getattr(time, 'perf_counter', time.time))

    def __init__(self, name=None):
        self.name = name
        # One start time per active entry, so a recursive decorated function
        # does not overwrite the start time of the call that is still running.
        self._starts = []

    def __enter__(self):
        # ``start`` is kept for backward compatibility: most recent entry.
        self.start = self._clock()
        self._starts.append(self.start)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed = self._clock() - self._starts.pop()
        if self.name is not None:
            f_name = self.name
        elif self.f is not None:
            f_name = self.f.__name__
        else:
            # Used as ``with TimeIt():`` without decorating anything.
            f_name = type(self).__name__
        print("%s:%s %.6f sec" % (f_name, ' ' * (25 - len(f_name)), elapsed))
        # Implicitly returns None, so exceptions from the block propagate.
@TimeIt()
def sort_1(graph, debug=False):
    """Sort each row of a CSR graph by increasing stored value, vectorized.

    Requires every row to hold the same number of stored entries (as a
    k-neighbors graph does), so ``graph.data`` can be viewed as a dense
    ``(n_samples, n_neighbors)`` array and argsorted in a single call.

    Parameters
    ----------
    graph : CSR sparse matrix
        Modified in place: ``graph.data`` and ``graph.indices`` are replaced
        by their row-wise sorted permutation; ``graph.indptr`` is unchanged.
    debug : bool
        If True, assert that the result passes ``_is_sorted``.

    Returns
    -------
    graph
        The same (mutated) object.

    Raises
    ------
    ValueError
        If rows have different numbers of stored entries.
    """
    row_nnz = np.diff(graph.indptr)
    # Explicit check instead of ``assert``: under ``python -O`` an assert
    # would vanish and the reshape below would silently misgroup entries.
    if row_nnz.size and row_nnz.min() != row_nnz.max():
        raise ValueError("sort_1 requires the same number of stored "
                         "entries in every row")
    n_samples = graph.shape[0]
    n_neighbors = int(row_nnz[0]) if row_nnz.size else 0
    # With no rows or no stored entries there is nothing to sort (and
    # reshaping a size-0 array with -1 would raise).
    if n_neighbors > 0:
        distances = graph.data.reshape(n_samples, n_neighbors)
        # Stable sort so tied values keep their original order, the same
        # ordering sort_2 produces with kind='mergesort'.
        perm = np.argsort(distances, axis=1, kind='mergesort')
        # Turn per-row positions into flat offsets into graph.data.
        perm += np.arange(n_samples)[:, None] * n_neighbors
        perm = perm.ravel()
        graph.data = graph.data[perm]
        graph.indices = graph.indices[perm]
    if debug:
        assert _is_sorted(graph)
    return graph
@TimeIt()
def sort_2(graph, debug=False):
    """Sort each row of a CSR graph by increasing stored value, row by row.

    Works for any number of stored entries per row. The stable sort
    (``kind='mergesort'``) keeps tied values in their original order.

    Parameters
    ----------
    graph : CSR sparse matrix
        Modified in place: within each row, ``graph.data`` and
        ``graph.indices`` are permuted together; ``graph.indptr`` is unchanged.
    debug : bool
        If True, assert that the result passes ``_is_sorted``.

    Returns
    -------
    graph
        The same (mutated) object.
    """
    values, columns, offsets = graph.data, graph.indices, graph.indptr
    for lo, hi in zip(offsets[:-1], offsets[1:]):
        row_values = values[lo:hi]
        order = np.argsort(row_values, kind='mergesort')
        # Fancy indexing returns a copy, so writing it back into the same
        # slice is safe.
        values[lo:hi] = row_values[order]
        columns[lo:hi] = columns[lo:hi][order]
    if debug:
        assert _is_sorted(graph)
    return graph
def _is_sorted(graph): | |
out_of_order = graph.data[:-1] > graph.data[1:] | |
return (out_of_order.sum() == out_of_order.take(graph.indptr[1:-1] - 1, | |
mode='clip').sum()) | |
def bench(graph):
    """Time sort_1 and sort_2 on copies of ``graph`` and check they agree.

    Each sorter prints its own timing through its ``TimeIt`` decorator.
    ``graph`` itself is not modified (the sorters receive copies). It must
    start out unsorted, so that the "sorted afterwards" checks are meaningful.
    """
    assert not _is_sorted(graph)
    graph_1 = sort_1(graph.copy())
    assert _is_sorted(graph_1)
    graph_2 = sort_2(graph.copy())
    assert _is_sorted(graph_2)
    # Use numpy.testing directly: sklearn.utils.testing only re-exported this
    # helper, and that module was deprecated in 0.22 and removed in 0.24.
    np.testing.assert_array_equal(graph_1.data, graph_2.data)
    np.testing.assert_array_equal(graph_1.indices, graph_2.indices)
# Define dataset: a CSR graph shaped like a kneighbors_graph output, with
# exactly n_neighbors stored entries per row (sort_1 requires this).
n_samples = 100000
n_neighbors = 10
# Fixed seed so every run benchmarks the same random graph.
rng = np.random.RandomState(0)
# Create graph as kneighbors_graph
data = rng.rand(n_samples * n_neighbors)
indices = rng.randint(0, n_samples, size=n_samples * n_neighbors)
indptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)
graph = sp.csr_matrix((data, indices, indptr), shape=(n_samples, n_samples))
bench(graph)
TomDLT (author) commented on Feb 7, 2018:
# n_samples, n_neighbors = 1000, 10000
sort_1: 0.661219 sec
sort_2: 0.756222 sec
# n_samples, n_neighbors = 100, 100000
sort_1: 0.875212 sec
sort_2: 1.039383 sec
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.