-
-
Save douglaspsteen/da7646e957bf1775ca8165c5a72e2b76 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def knn_predict(X_train, X_test, y_train, y_test, k, p):
    """Predict one label per test point by k-nearest-neighbour majority vote.

    For every element of ``X_test`` the distance to every element of
    ``X_train`` is computed with ``minkowski_distance(test_point,
    train_point, p=p)``. The labels of the ``k`` closest training points are
    counted and the most frequent one becomes the prediction.

    Args:
        X_train: Iterable of training points. Each element is handed to
            ``minkowski_distance`` unchanged.
        X_test: Iterable of points to classify.
        y_train: Labels for ``X_train``, matched to the training points by
            position. Any iterable works (list, array, pandas Series).
        y_test: Unused. Accepted only so existing callers keep working.
        k: Number of neighbours that vote. Must be at least 1. If it exceeds
            the number of training points, every training point votes.
        p: Forwarded to ``minkowski_distance`` as its ``p`` argument.

    Returns:
        A list with one predicted label per element of ``X_test``, in the
        same order.

    Raises:
        ValueError: If ``k`` is smaller than 1, ``X_train`` is empty, or
            ``X_train`` and ``y_train`` differ in length.
    """
    # Function-scope imports: the module's import block is not assumed to
    # provide these names.
    import heapq
    from collections import Counter

    if k < 1:
        raise ValueError("k must be at least 1, got {!r}".format(k))

    # Materialise once so both can be indexed by position and re-used for
    # every test point.
    train_points = list(X_train)
    labels = list(y_train)
    if not train_points:
        raise ValueError("X_train must contain at least one point")
    if len(train_points) != len(labels):
        raise ValueError(
            "X_train and y_train differ in length: {} != {}".format(
                len(train_points), len(labels)
            )
        )

    y_hat_test = []
    for test_point in X_test:
        distances = [
            minkowski_distance(test_point, train_point, p=p)
            for train_point in train_points
        ]
        # Positions of the k smallest distances. nsmallest behaves like a
        # stable sort followed by a slice, so equal distances are resolved
        # in training order and the result is deterministic.
        nearest = heapq.nsmallest(
            k, range(len(distances)), key=distances.__getitem__
        )
        # Labels are inserted nearest-first; on a tied vote most_common
        # returns the label that was encountered first.
        votes = Counter(labels[i] for i in nearest)
        y_hat_test.append(votes.most_common(1)[0][0])
    return y_hat_test
# Classify the held-out points: 5 voting neighbours, Minkowski order p=1.
y_hat_test = knn_predict(
    X_train,
    X_test,
    y_train,
    y_test,
    k=5,
    p=1,
)
# Show the predicted labels, one per test point.
print(y_hat_test)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment