Created November 6, 2017 17:22
-
-
Save 64lines/c0cd488725ae3c4ec157b90339acdb18 to your computer and use it in GitHub Desktop.
[PYTHON] Plotting K-Neighbors accuracy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
# Plot k-NN training vs. testing accuracy for a range of neighbor counts.
#
# NOTE(review): this snippet does not define X_train, y_train, X_test, y_test
# or KNeighborsClassifier; they must already be in scope when it runs.
# Presumably the four arrays come from an earlier train/test split and the
# classifier from `from sklearn.neighbors import KNeighborsClassifier` —
# confirm against the surrounding notebook/script.

# Candidate values of k: 1 through 8 inclusive (arange excludes the stop value).
neighbors = np.arange(1, 9)

# One accuracy slot per k; np.empty leaves garbage, but every slot is
# overwritten by the loop below before it is read.
train_accuracy = np.empty(len(neighbors))
test_accuracy = np.empty(len(neighbors))

# Loop over different values of k
for i, k in enumerate(neighbors):
    # Set up a fresh k-NN classifier with k neighbors and fit it to the
    # training data.
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)

    # Compute accuracy on the training set and on the testing set.
    train_accuracy[i] = knn.score(X_train, y_train)
    test_accuracy[i] = knn.score(X_test, y_test)

# Generate plot: both accuracy curves against k on shared axes.
plt.title('k-NN: Varying Number of Neighbors')
plt.plot(neighbors, test_accuracy, label='Testing Accuracy')
plt.plot(neighbors, train_accuracy, label='Training Accuracy')
plt.legend()
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment