# Gist ravi07bec/7712b9958cd393b46bc53ac26b2f5fd5 (created November 26, 2021 01:16).
# Note: GitHub flagged this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Example of making predictions with k-nearest neighbors
from collections import Counter
from math import sqrt
# Calculate the Euclidean distance between two vectors
def euclidean_distance(row1, row2):
    """Return the Euclidean distance between the feature parts of two rows.

    The last element of ``row1`` is skipped (the demo data stores the class
    label there), so only the first ``len(row1) - 1`` positions of each row
    are compared.

    Args:
        row1: Sequence of numeric features followed by one trailing element
            that is ignored.
        row2: Sequence with at least ``len(row1) - 1`` numeric elements.

    Returns:
        The distance as a float; ``0.0`` when there are no features to compare.

    Raises:
        IndexError: If ``row2`` has fewer than ``len(row1) - 1`` elements.
    """
    # Everything except the final column is a feature.
    n_features = len(row1) - 1
    squared_sum = 0.0
    for i in range(n_features):
        squared_sum += (row1[i] - row2[i]) ** 2
    return sqrt(squared_sum)
# Locate the most similar neighbors
def get_neighbors(train, test_row, num_neighbors):
    """Return the ``num_neighbors`` rows of ``train`` closest to ``test_row``.

    Distances are computed with ``euclidean_distance(test_row, train_row)``.
    Rows are returned nearest-first; rows at equal distance keep their
    relative order from ``train`` because the sort is stable.

    Args:
        train: Iterable of training rows.
        test_row: The row to find neighbors for.
        num_neighbors: Maximum number of rows to return.

    Returns:
        A list of at most ``num_neighbors`` rows from ``train``. If fewer rows
        are available, all of them are returned; a non-positive
        ``num_neighbors`` yields an empty list.
    """
    distances = [
        (train_row, euclidean_distance(test_row, train_row))
        for train_row in train
    ]
    distances.sort(key=lambda pair: pair[1])
    # Clamp at zero so a negative count gives an empty result rather than
    # being interpreted as a "drop the last k" slice.
    wanted = max(num_neighbors, 0)
    return [train_row for train_row, _ in distances[:wanted]]
# Make a classification prediction with neighbors
def predict_classification(train, test_row, num_neighbors):
    """Predict the class of ``test_row`` by majority vote of its neighbors.

    The class label of each neighbor is read from its last element. When
    several labels share the highest vote count, the label that appears
    first in the neighbor list wins, which makes the result deterministic.

    Args:
        train: Iterable of training rows whose last element is the label.
        test_row: The row to classify.
        num_neighbors: Number of neighbors that take part in the vote.

    Returns:
        The label with the most votes among the selected neighbors.

    Raises:
        ValueError: If no neighbors were found (nothing to vote on).
    """
    neighbors = get_neighbors(train, test_row, num_neighbors)
    # Trace of the rows taking part in the vote (part of the script's output).
    print(neighbors)
    # Counter keeps first-seen order and max() returns the first maximal
    # key, so ties go to the label encountered earliest.
    votes = Counter(row[-1] for row in neighbors)
    return max(votes, key=votes.get)
# Test the classifier on a small two-class dataset.
# Each row is [feature_1, feature_2, class_label].
dataset = [
    [2.7810836, 2.550537003, 0],
    [1.465489372, 2.362125076, 0],
    [3.396561688, 4.400293529, 0],
    [1.38807019, 1.850220317, 0],
    [3.06407232, 3.005305973, 0],
    [7.627531214, 2.759262235, 1],
    [5.332441248, 2.088626775, 1],
    [6.922596716, 1.77106367, 1],
    [8.675418651, -0.242068655, 1],
    [7.673756466, 3.508563011, 1],
]
# The query row is taken from the training data itself, so it is always
# among its own neighbors (at distance zero).
prediction = predict_classification(dataset, dataset[7], 5)
print('Expected %d, Got %d.' % (dataset[7][-1], prediction))
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.