K-Nearest Neighbors Performance Benchmark
from sklearn.datasets import make_blobs
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import time
# Set random seed (for reproducibility)
np.random.seed(1000)
nb_samples = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 100000, 500000]
nb_features = [10, 25, 50, 75, 100, 150, 250, 300, 400, 500, 800, 1000, 2000]
algorithms = ['brute', 'kd_tree', 'ball_tree']
training_times = {'brute': [],
                  'kd_tree': [],
                  'ball_tree': []}

prediction_times = {'brute': [],
                    'kd_tree': [],
                    'ball_tree': []}
for i in range(len(nb_samples)):
    for algorithm in algorithms:
        # Generate an isotropic Gaussian blob dataset
        # (the feature count grows together with the sample count)
        X, _ = make_blobs(n_samples=nb_samples[i], n_features=nb_features[i],
                          centers=int(nb_features[i] / 10), random_state=1000)

        # Create k-Nearest Neighbors instance
        nn = NearestNeighbors(algorithm=algorithm, n_jobs=multiprocessing.cpu_count())

        # Training
        start_time = time.time()
        nn.fit(X)
        end_time = time.time()
        training_times[algorithm].append(end_time - start_time)

        # Prediction (a single 5-nearest-neighbors query)
        xs = np.random.uniform(-1.0, 1.0, size=nb_features[i])
        start_time = time.time()
        nn.kneighbors(xs.reshape(1, -1), n_neighbors=5)
        end_time = time.time()
        prediction_times[algorithm].append(end_time - start_time)
# Show the results
fig, ax = plt.subplots(6, 1, figsize=(12, 17))
# Training times
ax[0].set_title('Training time (Brute-force algorithm)')
ax[0].set_xlabel('Number of samples')
ax[0].set_ylabel('Time (seconds)')
ax[0].plot(nb_samples, training_times['brute'])
ax[0].grid()
ax[1].set_title('Training time (KD-Tree algorithm)')
ax[1].set_xlabel('Number of samples')
ax[1].set_ylabel('Time (seconds)')
ax[1].plot(nb_samples, training_times['kd_tree'])
ax[1].grid()
ax[2].set_title('Training time (Ball-Tree algorithm)')
ax[2].set_xlabel('Number of samples')
ax[2].set_ylabel('Time (seconds)')
ax[2].plot(nb_samples, training_times['ball_tree'])
ax[2].grid()
# Prediction times
ax[3].set_title('Prediction time (Brute-force algorithm)')
ax[3].set_xlabel('Number of samples')
ax[3].set_ylabel('Time (seconds)')
ax[3].plot(nb_samples, prediction_times['brute'])
ax[3].grid()
ax[4].set_title('Prediction time (KD-Tree algorithm)')
ax[4].set_xlabel('Number of samples')
ax[4].set_ylabel('Time (seconds)')
ax[4].plot(nb_samples, prediction_times['kd_tree'])
ax[4].grid()
ax[5].set_title('Prediction time (Ball-Tree algorithm)')
ax[5].set_xlabel('Number of samples')
ax[5].set_ylabel('Time (seconds)')
ax[5].plot(nb_samples, prediction_times['ball_tree'])
ax[5].grid()
plt.show()
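
A possible follow-up (not part of the original gist): the same measurements can also be dumped as a plain-text table for quick side-by-side comparison. The sketch below assumes the nb_samples, algorithms, training_times and prediction_times objects populated by the script above.

# Optional: print the measured timings as a plain-text table
# (assumes nb_samples, algorithms, training_times and prediction_times from above).
header = "{:>10} | ".format("samples") \
         + " | ".join("{:>16}".format(a + " fit") for a in algorithms) \
         + " | " + " | ".join("{:>16}".format(a + " query") for a in algorithms)
print(header)

for i, n in enumerate(nb_samples):
    fit_cols = " | ".join("{:16.6f}".format(training_times[a][i]) for a in algorithms)
    query_cols = " | ".join("{:16.6f}".format(prediction_times[a][i]) for a in algorithms)
    print("{:>10} | {} | {}".format(n, fit_cols, query_cols))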