Last active
November 14, 2017 08:59
-
-
Save tpgmartin/a49843b3f56c8c4e48574f84deda9d2e to your computer and use it in GitHub Desktop.
Implementation of a k-nearest-neighbours (KNN) classifier from scratch, using the Euclidean distance metric.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# main.py
import heapq
from collections import Counter

from scipy.spatial import distance
class KNN:
    """k-nearest-neighbours classifier using the Euclidean distance metric.

    ``fit`` only stores the training data; all work happens in ``predict``,
    where each query row is compared against every stored training row.
    """

    def __init__(self, n_neighbors=1):
        """Create a classifier that votes over the ``n_neighbors`` closest samples.

        Args:
            n_neighbors: Number of nearest training samples that vote on each
                prediction. Must be at least 1.

        Raises:
            ValueError: If ``n_neighbors`` is less than 1.
        """
        # Fail fast: with k < 1 the vote would be over no neighbours, and
        # predict() would otherwise crash later with an opaque IndexError.
        if n_neighbors < 1:
            raise ValueError(
                "n_neighbors must be >= 1, got {!r}".format(n_neighbors))
        self.n_neighbors = n_neighbors

    def fit(self, X_train, y_train):
        """Store the training samples and their labels.

        Args:
            X_train: Sequence of feature vectors.
            y_train: Sequence of labels, one per row of ``X_train``.

        Raises:
            ValueError: If the two sequences differ in length, or if the
                training set is empty.
        """
        # len() rather than truthiness so array-like inputs (e.g. NumPy
        # arrays, whose truth value is ambiguous) are also accepted.
        if len(X_train) != len(y_train):
            raise ValueError(
                "X_train and y_train must have the same length "
                "({} != {})".format(len(X_train), len(y_train)))
        if len(X_train) == 0:
            raise ValueError("X_train must contain at least one sample")
        # Stored by reference, not copied.
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X_test):
        """Return a list with the predicted label for each row of ``X_test``."""
        return [self.__closest(row) for row in X_test]

    def __closest(self, row):
        """Return the majority label among the ``n_neighbors`` samples nearest ``row``."""
        labelled_distances = (
            (label, distance.euclidean(row, sample))
            for sample, label in zip(self.X_train, self.y_train)
        )
        # heapq.nsmallest is documented as equivalent to sorted(...)[:n]. Samples
        # at equal distance therefore keep their training order, exactly as a
        # full stable sort would. Only the k best candidates are retained,
        # rather than sorting every training row.
        nearest = heapq.nsmallest(
            self.n_neighbors, labelled_distances, key=lambda pair: pair[1])
        return self.__vote(nearest)

    def __vote(self, neighbours):
        """Return the most common label among ``neighbours``.

        ``neighbours`` is a list of ``(label, distance)`` pairs ordered from
        nearest to farthest. Counter.most_common orders equal counts by first
        appearance, so a tie goes to the label with the nearest member.
        """
        votes = Counter(label for label, _ in neighbours)
        return votes.most_common(1)[0][0]
# test.py
import pytest | |
from main import KNN | |
# Training fixture: seven 4-feature samples where every feature of a sample
# equals its label (one sample labelled 0, two labelled 1, four labelled 2).
y_train = [0, 1, 1, 2, 2, 2, 2]
X_train = [[label] * 4 for label in y_train]
@pytest.mark.parametrize("n_neighbors", [1, 3, 5])
def test_KNN_should_be_initialised_with_n_neighbors(n_neighbors):
    """The classifier keeps the neighbour count it was constructed with."""
    classifier = KNN(n_neighbors=n_neighbors)
    classifier.fit(X_train, y_train)
    stored = classifier.n_neighbors
    assert stored == n_neighbors
@pytest.mark.parametrize("n_neighbors", [1, 3, 5])
def test_should_be_able_to_pass_training_data_to_classifier(n_neighbors):
    """fit() stores the training samples and labels it was given."""
    classifier = KNN(n_neighbors=n_neighbors)
    classifier.fit(X_train, y_train)
    stored_samples, stored_labels = classifier.X_train, classifier.y_train
    assert stored_samples == X_train
    assert stored_labels == y_train
X_test = [[0] * 4]  # a single 4-feature query point at the origin
@pytest.mark.parametrize(
    ("n_neighbors", "y_test"),
    [
        (1, [0]),  # nearest sample is the exact match labelled 0
        (3, [1]),  # three nearest labels are 0, 1, 1 -> 1 wins
        (7, [2]),  # every sample votes; label 2 holds four of seven
    ],
)
def test_predict_should_return_label_for_test_data(n_neighbors, y_test):
    """predict() returns the majority label of the k nearest training samples."""
    classifier = KNN(n_neighbors=n_neighbors)
    classifier.fit(X_train, y_train)
    assert classifier.predict(X_test) == y_test
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment