Last active
December 15, 2015 08:29
-
-
Save sl8r000/5231189 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
class NearestNeighbor(object):
    """Impute a missing column by copying it from the nearest learned row.

    Rows are plain sequences indexed by column number. The distance between
    two rows is Euclidean over columns where both values are int/float; every
    other column contributes 0 when the two values are equal and 1 when they
    differ.
    """

    def __init__(self):
        # Every row passed to learn_from_row, in insertion order.
        self._known_rows = []

    def learn_from_row(self, row):
        """Remember *row* as a candidate neighbor for later predictions."""
        self._known_rows.append(row)

    def predict_missing_column(
            self, row_with_missing_column, missing_column_number):
        """Predict column *missing_column_number* of *row_with_missing_column*.

        The query is compared against every learned row over all columns
        except the missing one. The method returns the closest learned row's
        value in the missing column; ties go to the earliest learned row.

        The query is indexed with the learned rows' column numbers, so any
        value it holds at the missing index is ignored. It may be shorter
        than the learned rows only when the missing column is the last one;
        otherwise the column indices would misalign.

        Raises:
            ValueError: if no row has been learned yet, or no learned row is
                at a finite distance from the query (e.g. NaN/inf values).
        """
        nearest_neighbor = None
        nearest_distance_sq = float('inf')
        for row in self._known_rows:
            distance_sq = 0
            for i, a in enumerate(row):
                if i == missing_column_number:
                    continue
                b = row_with_missing_column[i]
                if isinstance(a, (int, float)) and isinstance(b, (int, float)):
                    distance_sq += (a - b) ** 2
                elif a != b:
                    # Fixed penalty for a categorical mismatch. Equal values
                    # must cost nothing, otherwise every row gets the same +1
                    # and categorical columns can never change the ranking.
                    distance_sq += 1
            # sqrt is monotonic, so comparing squared distances picks the
            # same nearest row without the per-row square root.
            if distance_sq < nearest_distance_sq:
                nearest_distance_sq = distance_sq
                nearest_neighbor = row
        if nearest_neighbor is None:
            raise ValueError(
                'no learned row is at a finite distance from the query row')
        return nearest_neighbor[missing_column_number]
if __name__ == '__main__':
    # Simple example: learn three rows, then predict the last column of a
    # query row. row_1 is the nearest learned row, so this prints 1.
    model = NearestNeighbor()
    row_0 = [27.2, 'red', 1992, 0]
    row_1 = [15.1, 'blue', 1999, 1]
    row_2 = [18.0, 'red', 1989, 0]
    for row in [row_0, row_1, row_2]:
        model.learn_from_row(row)
    # print() with a single argument behaves the same on Python 2 and 3;
    # the original print statement is a SyntaxError on Python 3.
    print(model.predict_missing_column([23.0, 'blue', 2001], 3))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment