Skip to content

Instantly share code, notes, and snippets.

@TomDLT
Created February 7, 2018 19:45
Show Gist options
  • Save TomDLT/4863afaf2554903dd13f476c141d8a38 to your computer and use it in GitHub Desktop.
Save TomDLT/4863afaf2554903dd13f476c141d8a38 to your computer and use it in GitHub Desktop.
from functools import wraps
import time
import numpy as np
import scipy.sparse as sp
from sklearn.utils.testing import assert_array_equal
class ContextDecorator(object):
    """Base class letting a context manager double as a function decorator.

    Subclasses supply ``__enter__`` and ``__exit__``. Calling an instance
    with a function returns a wrapper that runs that function inside
    ``with self:``.

    Attributes
    ----------
    f : callable or None
        The function most recently decorated by this instance, or ``None``
        when the instance has not decorated anything (for example when it
        is used directly in a ``with`` statement).
    """

    # Class-level default so that subclass hooks can always read ``self.f``,
    # even when the instance is used as a plain context manager.
    f = None

    def __call__(self, f):
        """Return ``f`` wrapped so that every call runs inside ``with self:``.

        Parameters
        ----------
        f : callable
            Function to wrap.

        Returns
        -------
        callable
            Wrapper carrying ``f``'s metadata (via ``functools.wraps``) and
            returning whatever ``f`` returns.
        """
        # Keep a reference so subclasses can inspect the wrapped function.
        self.f = f

        @wraps(f)
        def decorated(*args, **kwds):
            with self:
                return f(*args, **kwds)

        return decorated
class TimeIt(ContextDecorator):
    """Print the elapsed time of a decorated function or of a ``with`` block.

    On exit, one line is printed in the form ``"<name>:<padding> <seconds> sec"``
    where the name is padded with spaces to a width of 25 characters.
    When the instance decorates a function, ``<name>`` is that function's
    ``__name__``; when it is used directly as ``with TimeIt():`` the generic
    label ``<block>`` is printed instead.
    """

    # Prefer the monotonic, high-resolution clock; fall back to ``time.time``
    # on interpreters that do not provide ``time.perf_counter``.
    _clock = staticmethod(getattr(time, 'perf_counter', time.time))

    def __enter__(self):
        """Record the start time."""
        self.start = self._clock()

    def __exit__(self, type, value, traceback):
        """Print the elapsed time; any in-flight exception is not suppressed."""
        elapsed = self._clock() - self.start
        # ``f`` is only meaningful once this instance has decorated a
        # function, so look it up defensively instead of assuming it exists.
        f_name = getattr(getattr(self, 'f', None), '__name__', '<block>')
        print("%s:%s %.6f sec" % (f_name, ' ' * (25 - len(f_name)),
                                  elapsed))
@TimeIt()
def sort_1(graph, debug=False):
    """Sort the stored entries of every row by increasing distance.

    Vectorized variant: it requires every row to hold the same number of
    stored entries, so that ``graph.data`` can be viewed as a 2D array with
    one row per sample.

    Parameters
    ----------
    graph : object exposing ``indptr``, ``indices``, ``data`` and ``shape``
        CSR-style graph. Its ``data`` and ``indices`` attributes are
        replaced by their reordered versions.
    debug : bool, default False
        If True, assert that the result is sorted before returning.

    Returns
    -------
    graph
        The same object that was passed in.

    Raises
    ------
    ValueError
        If the rows do not all hold the same number of stored entries.
    """
    row_nnz = np.diff(graph.indptr)
    if row_nnz.size == 0:
        # No rows at all: nothing to sort.
        return graph

    n_neighbors = row_nnz[0]
    # Explicit check rather than ``assert``: under ``python -O`` a skipped
    # assertion could let the reshape below silently mis-group entries.
    if (row_nnz != n_neighbors).any():
        raise ValueError("sort_1 requires the same number of stored entries "
                         "in every row")
    if n_neighbors == 0:
        # Every row is empty: nothing to sort.
        return graph

    n_samples = graph.shape[0]
    distances = graph.data.reshape(n_samples, -1)
    # Stable sort so that equal distances keep their original order, which
    # matches the tie-breaking of sort_2.
    perm = np.argsort(distances, axis=1, kind='mergesort')
    # Turn per-row positions into positions in the flat data array.
    perm += np.arange(n_samples)[:, None] * n_neighbors
    perm = perm.ravel()
    graph.data = graph.data[perm]
    graph.indices = graph.indices[perm]
    if debug:
        assert _is_sorted(graph)
    return graph
@TimeIt()
def sort_2(graph, debug=False):
    """Sort the stored entries of every row by increasing distance.

    Row-by-row variant: each row slice of ``graph.data`` is argsorted with
    a mergesort and the resulting order is written back in place to both
    ``graph.data`` and ``graph.indices``.

    Parameters
    ----------
    graph : object exposing ``indptr``, ``indices`` and ``data``
        CSR-style graph, modified in place.
    debug : bool, default False
        If True, assert that the result is sorted before returning.

    Returns
    -------
    graph
        The same object that was passed in.
    """
    row_starts = graph.indptr[:-1]
    row_stops = graph.indptr[1:]
    for lo, hi in zip(row_starts, row_stops):
        row = slice(lo, hi)
        ranking = graph.data[row].argsort(kind='mergesort')
        # Fancy indexing on the right-hand side yields a copy, so writing
        # back into the same slice is safe.
        graph.data[row] = graph.data[row][ranking]
        graph.indices[row] = graph.indices[row][ranking]
    if debug:
        assert _is_sorted(graph)
    return graph
def _is_sorted(graph):
    """Return True if every row's stored values are in non-decreasing order.

    Parameters
    ----------
    graph : object exposing ``data`` and ``indptr``
        CSR-style graph; ``indptr[k]`` is the position in ``data`` where
        row ``k`` starts.

    Returns
    -------
    bool
        True when no value is followed, within the same row, by a strictly
        smaller one. Graphs with fewer than two stored values are sorted.
    """
    data = graph.data
    if data.size < 2:
        return True

    # descents[i] is True when data[i] > data[i + 1].
    descents = data[:-1] > data[1:]

    # A descent between positions i and i + 1 is irrelevant when a new row
    # starts at i + 1. Row starts equal to 0 or to data.size come from empty
    # leading/trailing rows and do not separate two stored values.
    row_starts = np.asarray(graph.indptr[1:-1])
    inside = (row_starts > 0) & (row_starts < data.size)
    descents[row_starts[inside] - 1] = False

    return not descents.any()
def bench(graph):
    """Run both sorting implementations on copies of ``graph`` and compare.

    The input is expected to be unsorted. Each implementation receives its
    own copy, so ``graph`` itself is left untouched. The two results must
    be sorted and must agree on both ``data`` and ``indices``.

    Parameters
    ----------
    graph : object with a ``copy()`` method, accepted by sort_1 and sort_2
        Graph whose rows are not yet sorted by distance.

    Raises
    ------
    AssertionError
        If the input is already sorted, if either result is not sorted, or
        if the two results differ.
    """
    assert not _is_sorted(graph)
    graph_1 = sort_1(graph.copy())
    assert _is_sorted(graph_1)
    graph_2 = sort_2(graph.copy())
    assert _is_sorted(graph_2)
    # Use NumPy's own helper: ``sklearn.utils.testing`` is a private module
    # that is absent from later scikit-learn releases.
    np.testing.assert_array_equal(graph_1.data, graph_2.data)
    np.testing.assert_array_equal(graph_1.indices, graph_2.indices)
# Only run the benchmark when executed as a script, so that importing this
# module (e.g. to reuse sort_1 / sort_2) does not trigger it.
if __name__ == "__main__":
    # Define dataset
    n_samples = 100000
    n_neighbors = 10
    # Seeded generator so that successive runs time the exact same input.
    rng = np.random.RandomState(0)
    # Create a graph laid out like a kneighbors_graph: exactly n_neighbors
    # stored entries per row, with random distances and column indices.
    data = rng.rand(n_samples * n_neighbors)
    indices = rng.randint(0, n_samples, size=n_samples * n_neighbors)
    indptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)
    graph = sp.csr_matrix((data, indices, indptr),
                          shape=(n_samples, n_samples))
    bench(graph)
@TomDLT
Copy link
Author

TomDLT commented Feb 7, 2018

# n_samples, n_neighbors = 10000, 10
sort_1:                    0.002019 sec
sort_2:                    0.036385 sec
# n_samples, n_neighbors = 10000, 100
sort_1:                    0.039658 sec
sort_2:                    0.082174 sec
# n_samples, n_neighbors = 100000, 10
sort_1:                    0.019448 sec
sort_2:                    0.344099 sec

@TomDLT
Copy link
Author

TomDLT commented Feb 7, 2018

# n_samples, n_neighbors = 1000, 10000
sort_1:                    0.661219 sec
sort_2:                    0.756222 sec
# n_samples, n_neighbors = 100, 100000
sort_1:                    0.875212 sec
sort_2:                    1.039383 sec

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment