Skip to content

Instantly share code, notes, and snippets.

@tpgmartin
Last active November 14, 2017 08:59
Show Gist options
  • Save tpgmartin/a49843b3f56c8c4e48574f84deda9d2e to your computer and use it in GitHub Desktop.
Save tpgmartin/a49843b3f56c8c4e48574f84deda9d2e to your computer and use it in GitHub Desktop.
Implementation of a k-nearest-neighbours (KNN) classifier from scratch, using the Euclidean distance metric
# main.py
from scipy.spatial import distance
from collections import Counter
class KNN:
    """k-nearest-neighbours classifier using the Euclidean distance metric.

    The classifier memorises the training samples in ``fit`` and, for every
    query row passed to ``predict``, returns the most common label among the
    ``n_neighbors`` training samples closest to that row.

    Ties between equally common labels go to the label that appears first
    when the neighbours are listed nearest-first (``Counter.most_common``
    keeps first-encountered order for equal counts).
    """

    def __init__(self, n_neighbors=1):
        """Create a classifier that votes over ``n_neighbors`` samples.

        Args:
            n_neighbors: Number of nearest training samples that take part
                in the vote. Must be at least 1.

        Raises:
            ValueError: If ``n_neighbors`` is smaller than 1. Zero would
                leave nothing to vote on, and a negative value would be
                interpreted as a from-the-end slice bound and silently vote
                over the wrong samples.
        """
        if n_neighbors < 1:
            raise ValueError(
                "n_neighbors must be at least 1, got %r" % (n_neighbors,)
            )
        self.n_neighbors = n_neighbors

    def fit(self, X_train, y_train):
        """Store the training samples and their labels.

        Args:
            X_train: Sequence of feature vectors, one per training sample.
            y_train: Sequence of labels, aligned index-by-index with
                ``X_train``.

        Raises:
            ValueError: If the two sequences differ in length or are empty.
        """
        if len(X_train) != len(y_train):
            raise ValueError(
                "X_train and y_train must have the same length, got %d and %d"
                % (len(X_train), len(y_train))
            )
        if len(X_train) == 0:
            raise ValueError("cannot fit on an empty training set")
        # Keep the caller's objects as-is; no copy or conversion is made.
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X_test):
        """Return a list with one predicted label per row of ``X_test``."""
        return [self.__closest(row) for row in X_test]

    def __closest(self, row):
        """Return the predicted label for a single feature vector ``row``."""
        distances = [
            (label, distance.euclidean(row, sample))
            for sample, label in zip(self.X_train, self.y_train)
        ]
        # sorted() is stable, so equally distant samples keep training order.
        sorted_distances = sorted(distances, key=lambda pair: pair[1])
        return self.__vote(sorted_distances)

    def __vote(self, distances):
        """Return the majority label among the first ``n_neighbors`` pairs.

        Args:
            distances: ``(label, distance)`` pairs ordered nearest-first.
        """
        nearest_labels = (label for label, _ in distances[:self.n_neighbors])
        return Counter(nearest_labels).most_common(1)[0][0]
# test.py
import pytest
from main import KNN
# Shared training fixture: one sample of class 0, two of class 1 and four of
# class 2. Every sample is its own label repeated across four features.
y_train = [0, 1, 1, 2, 2, 2, 2]
X_train = [[label] * 4 for label in y_train]
@pytest.mark.parametrize("n_neighbors", [1, 3, 5])
def test_KNN_should_be_initialised_with_n_neighbors(n_neighbors):
    """The neighbour count passed to the constructor is kept on the instance.

    The value is checked after ``fit`` has also been called.
    """
    classifier = KNN(n_neighbors)
    classifier.fit(X_train, y_train)
    stored_neighbors = classifier.n_neighbors
    assert stored_neighbors == n_neighbors
@pytest.mark.parametrize("n_neighbors", [1, 3, 5])
def test_should_be_able_to_pass_training_data_to_classifier(n_neighbors):
    """After ``fit`` the classifier exposes the samples and labels it was given."""
    classifier = KNN(n_neighbors)
    classifier.fit(X_train, y_train)
    stored_samples = classifier.X_train
    stored_labels = classifier.y_train
    assert stored_samples == X_train
    assert stored_labels == y_train
# Single query row, identical to the only class-0 training sample.
X_test = [[0, 0, 0, 0]]


@pytest.mark.parametrize(
    "n_neighbors, y_test",
    [
        (1, [0]),  # only the exact match votes
        (3, [1]),  # the two class-1 samples outvote the single class-0 one
        (7, [2]),  # every training sample votes; class 2 has the most
    ],
)
def test_predict_should_return_label_for_test_data(n_neighbors, y_test):
    """``predict`` returns the expected label list for each neighbour count."""
    classifier = KNN(n_neighbors)
    classifier.fit(X_train, y_train)
    predicted_labels = classifier.predict(X_test)
    assert predicted_labels == y_test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment