Created
June 16, 2017 13:58
-
-
Save IevaZarina/df6586f37521754604e97e803d769a1e to your computer and use it in GitHub Desktop.
Simple supervised KNN classification in Python 2.7
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# knn | |
import math | |
from collections import defaultdict | |
from operator import itemgetter | |
# Hand-made toy data set of 2-D points.
# Training points: the first five are labelled 1, the last five labelled 2
# (labels are aligned by position with y_train).
X_train = [
    [1, 1], [1, 2], [2, 4], [3, 5], [1, 0],        # label 1
    [0, 0], [1, -2], [-1, 0], [-1, -2], [-2, -2],  # label 2
]
y_train = [1, 1, 1, 1, 1,
           2, 2, 2, 2, 2]

# Held-out points and their expected labels (aligned by position).
X_test = [[5, 5], [0, -1], [-5, -5]]
y_test = [1, 2, 2]
def get_most_common_item(array):
    """Return the most frequent element of ``array``.

    Elements must be hashable (they are used as dictionary keys).

    Ties are broken deterministically in favour of the element that occurs
    first in ``array``. When ``array`` holds neighbour labels ordered
    nearest-first, a tied vote therefore goes to the nearest neighbour.

    Raises:
        ValueError: if ``array`` is empty.
    """
    # Materialise once so that one-shot iterators can be counted and then
    # scanned again below.
    items = list(array)
    count_dict = defaultdict(int)
    for key in items:
        count_dict[key] += 1
    # max() returns the first maximal element in iteration order, so ties
    # resolve by first occurrence rather than by arbitrary dict ordering.
    # It also raises ValueError for an empty sequence.
    return max(items, key=count_dict.__getitem__)
# https://en.wikipedia.org/wiki/Euclidean_distance
def euclidean_dist(A, B):
    """Return the Euclidean (L2) distance between points ``A`` and ``B``.

    Args:
        A, B: sequences of numbers with the same length (the point dimension).

    Returns:
        sqrt(sum((a_i - b_i) ** 2)) as a float; 0.0 for two empty points.

    Raises:
        ValueError: if ``A`` and ``B`` have different lengths. Without the
            check, extra coordinates of ``B`` would be silently ignored.
    """
    if len(A) != len(B):
        raise ValueError("points must have the same dimension: %d != %d"
                         % (len(A), len(B)))
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(A, B)))
def knn(X_train, y_train, X_test, k=1):
    """Classify each row of ``X_test`` by majority vote of its k nearest training rows.

    Distances are computed with ``euclidean_dist`` and the vote is taken with
    ``get_most_common_item`` over the neighbours' labels, ordered
    nearest-first.

    Args:
        X_train: training points, each a sequence of numbers.
        y_train: labels; ``y_train[i]`` is the label of ``X_train[i]``.
        X_test: points to classify.
        k: number of neighbours that vote. Values below 1 are treated as 1,
            matching the original behaviour.

    Returns:
        A list with one predicted label per row of ``X_test``, in order.

    Raises:
        ValueError: if there is at least one test row and ``k`` exceeds the
            number of training points (including an empty training set).
    """
    k = max(k, 1)
    n_train = len(X_train)
    predictions = []
    for test_row in X_test:
        # Checked per row so that an empty X_test still yields [] as before.
        if k > n_train:
            raise ValueError("k=%s exceeds the number of training points (%d)"
                             % (k, n_train))
        eucl_dist = [euclidean_dist(train_row, test_row) for train_row in X_train]
        # Sort training indices rather than distances, so tied distances map
        # to distinct training rows. sorted() is stable, so ties go to the
        # earlier training row.
        closest_knn = sorted(range(n_train), key=eucl_dist.__getitem__)[:k]
        closest_labels_knn = [y_train[x] for x in closest_knn]
        predictions.append(get_most_common_item(closest_labels_knn))
    return predictions
# Sanity check for euclidean_dist: the distance between [-1, 2, 3] and
# [4, 0, -3] is sqrt(25 + 4 + 36) = sqrt(65) ~= 8.062.

# Classify the hand-made test points with k=2 and print the predicted labels
# (compare with y_test). The print() call form runs on Python 2 and Python 3.
print(knn(X_train, y_train, X_test, k=2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment