Skip to content

Instantly share code, notes, and snippets.

@sl8r000
Created April 15, 2013 06:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sl8r000/5386112 to your computer and use it in GitHub Desktop.
Save sl8r000/5386112 to your computer and use it in GitHub Desktop.
class KNN(object):
    """K-nearest-neighbours predictor for one missing column of a row.

    Rows are plain sequences that may mix numeric and categorical values.
    Known rows are stored with ``learn_from_row``; ``predict_missing_column``
    then finds the ``k`` stored rows closest to a query row (ignoring the
    missing column) and combines their values for that column.
    """

    # Candidate odd values of k: 1, 3, 5, 7.
    POSSIBLE_K_VALUES = list(range(1, 9, 2))

    def __init__(self, k=None, weights=None):
        """Create an empty model.

        Args:
            k: number of neighbours consulted for a prediction.
            weights: per-feature weights, indexed by feature position
                *after* the missing column has been removed. ``None``
                means every feature has weight 1.
        """
        self._known_rows = []
        self.k = k
        self.weights = weights

    def learn_from_row(self, row):
        """Store ``row`` as a known example."""
        self._known_rows.append(row)

    @classmethod
    def general_square(cls, a, b):
        """Return the squared difference between two feature values.

        If ``a`` is an int or float the result is ``(a - b) ** 2``;
        otherwise it is 0 when the values are equal and 1 when they differ.
        """
        if isinstance(a, (int, float)):
            return (a - b) ** 2
        return 0 if a == b else 1

    @classmethod
    def find_k_nearest(cls, weights, known_rows, new_row,
                       missing_column_index, k):
        """Return up to ``k`` rows of ``known_rows`` closest to ``new_row``.

        The distance is the weighted sum of ``general_square`` over every
        column except ``missing_column_index``.

        Args:
            weights: per-feature weights indexed by position once the
                missing column is dropped, or ``None`` for unit weights.
            known_rows: iterable of candidate rows.
            new_row: the query row.
            missing_column_index: column excluded from the distance.
            k: maximum number of rows to return; must be at least 1.

        Returns:
            A list of rows ordered nearest first. Rows at equal distance
            keep the order in which they appear in ``known_rows``.

        Raises:
            ValueError: if ``k`` is ``None`` or smaller than 1.
        """
        import heapq

        if k is None or k < 1:
            raise ValueError("k must be a positive integer")

        # Max-heap on distance, emulated with negated keys. The negated
        # index makes every entry unique, so rows themselves are never
        # compared (mixed-type rows would raise TypeError), and on ties the
        # most recently seen row is the one evicted.
        heap = []
        for index, row in enumerate(known_rows):
            feature_columns = [i for i in range(len(row))
                               if i != missing_column_index]
            distance = sum(
                (1 if weights is None else weights[position])
                * cls.general_square(row[column], new_row[column])
                for position, column in enumerate(feature_columns))

            # Only prune once k candidates are held; until then every row
            # must be admitted regardless of its distance.
            if len(heap) == k and distance >= -heap[0][0]:
                continue
            heapq.heappush(heap, (-distance, -index, row))
            if len(heap) > k:
                heapq.heappop(heap)

        return [entry[2] for entry in sorted(heap, reverse=True)]

    @classmethod
    def get_column_consensus(cls, rows, column_index):
        """Combine the values found in ``column_index`` of ``rows``.

        NOTE(review): the original class called this method without
        defining it; this implementation is a default that should be
        confirmed against the intended behaviour.

        If every value is an int or float (bools excluded) the mean is
        returned; otherwise the most common value is returned. An empty
        ``rows`` yields ``None``.
        """
        from collections import Counter

        values = [row[column_index] for row in rows]
        if not values:
            return None
        if all(isinstance(value, (int, float))
               and not isinstance(value, bool) for value in values):
            return sum(values) / float(len(values))
        return Counter(values).most_common(1)[0][0]

    def predict_missing_column(self, row_with_missing_column,
                               missing_column_index):
        """Predict the value at ``missing_column_index`` of the given row.

        Uses ``self.k`` and ``self.weights`` to select neighbours from the
        rows stored via ``learn_from_row`` and returns their consensus for
        the missing column.
        """
        nearest_neighbors = KNN.find_k_nearest(self.weights,
                                               self._known_rows,
                                               row_with_missing_column,
                                               missing_column_index,
                                               self.k)
        return KNN.get_column_consensus(nearest_neighbors,
                                        missing_column_index)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment