
@64lines
Created November 6, 2017 17:22
[PYTHON] Plotting K-Neighbors accuracy
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Assumes X_train, X_test, y_train, y_test already exist,
# e.g. from sklearn.model_selection.train_test_split

# Set up arrays to store train and test accuracies
neighbors = np.arange(1, 9)
train_accuracy = np.empty(len(neighbors))
test_accuracy = np.empty(len(neighbors))

# Loop over different values of k
for i, k in enumerate(neighbors):
    # Set up a k-NN classifier with k neighbors: knn
    knn = KNeighborsClassifier(n_neighbors=k)
    # Fit the classifier to the training data
    knn.fit(X_train, y_train)
    # Compute accuracy on the training set
    train_accuracy[i] = knn.score(X_train, y_train)
    # Compute accuracy on the testing set
    test_accuracy[i] = knn.score(X_test, y_test)

# Generate plot
plt.title('k-NN: Varying Number of Neighbors')
plt.plot(neighbors, test_accuracy, label='Testing Accuracy')
plt.plot(neighbors, train_accuracy, label='Training Accuracy')
plt.legend()
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.show()
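The snippet above expects the train/test split to come from elsewhere. Below is a self-contained sketch of the same sweep that runs end to end. It is an illustration under stated assumptions, not the original author's setup: it uses scikit-learn's built-in digits dataset and an 80/20 stratified split as stand-in data, and `knn_accuracy_curve` is a hypothetical helper introduced here to keep the loop reusable.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier


def knn_accuracy_curve(X_train, X_test, y_train, y_test, neighbors):
    """Hypothetical helper: fit one k-NN model per k and return
    (train_accuracy, test_accuracy) as NumPy arrays."""
    train_accuracy = np.empty(len(neighbors))
    test_accuracy = np.empty(len(neighbors))
    for i, k in enumerate(neighbors):
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        # score() returns mean accuracy on the given data
        train_accuracy[i] = knn.score(X_train, y_train)
        test_accuracy[i] = knn.score(X_test, y_test)
    return train_accuracy, test_accuracy


# Assumption: the digits dataset stands in for the gist's unspecified data
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

neighbors = np.arange(1, 9)
train_accuracy, test_accuracy = knn_accuracy_curve(
    X_train, X_test, y_train, y_test, neighbors
)

# k with the highest accuracy on the held-out test set
best_k = neighbors[np.argmax(test_accuracy)]
print(f"Highest test accuracy at k={best_k}: {test_accuracy.max():.3f}")

plt.title('k-NN: Varying Number of Neighbors')
plt.plot(neighbors, test_accuracy, label='Testing Accuracy')
plt.plot(neighbors, train_accuracy, label='Training Accuracy')
plt.legend()
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.show()
```

Training accuracy is typically 1.0 at k=1, because each training point is its own nearest neighbor, and falls as k grows; the gap between the two curves is what the plot is meant to expose. Choosing k by maximum test accuracy, as `best_k` does here, tunes the model on the test set. Cross-validation on the training data (for example with `sklearn.model_selection.cross_val_score`) is the safer choice when the test score should remain an unbiased estimate.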