Skip to content

Instantly share code, notes, and snippets.

@jiaaro
Created June 4, 2013 00:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jiaaro/5702629 to your computer and use it in GitHub Desktop.
Save jiaaro/5702629 to your computer and use it in GitHub Desktop.
Solution to the homework in "Machine Learning for Humans: K Nearest-Neighbor" (http://www.jiaaro.com/KNN-for-humans/).
from collections import Counter

# Number of nearest neighbours consulted for each classification.
K = 3

# Training data: one row per fruit.
# Columns: weight, color code, number of seeds, label.
dataset = [
    [303, 3, 1, "banana"],
    [370, 1, 2, "apple"],
    [298, 3, 1, "banana"],
    [277, 3, 1, "banana"],
    [377, 4, 2, "apple"],
    [299, 3, 1, "banana"],
    [382, 1, 2, "apple"],
    [374, 4, 6, "apple"],
    [303, 4, 1, "banana"],
    [309, 3, 1, "banana"],
    [359, 1, 2, "apple"],
    [366, 1, 4, "apple"],
    [311, 3, 1, "banana"],
    [302, 3, 1, "banana"],
    [373, 4, 4, "apple"],
    [305, 3, 1, "banana"],
    [371, 3, 6, "apple"],
]
def normalize_dataset(dataset):
    """Min-max rescale every column of `dataset` to the [0, 1] range.

    Each datapoint is a sequence of numeric features; the result is a
    list of tuples (one per datapoint) with each feature rescaled
    relative to that column's min and max.
    """
    data_columns = zip(*dataset)
    normalized_columns = []
    for column in data_columns:
        max_val = max(column)
        min_val = min(column)
        val_range = max_val - min_val
        if val_range == 0:
            # A constant column carries no information; map it to 0.0
            # instead of raising ZeroDivisionError.
            normalized_columns.append([0.0 for _ in column])
        else:
            normalized_columns.append(
                [(float(val) - min_val) / val_range for val in column])
    # list() so the result is a reusable sequence on Python 3 as well,
    # where zip() returns a one-shot iterator.
    return list(zip(*normalized_columns))
def weight_dataset(dataset, weights):
    """Scale each feature of every datapoint by the matching weight."""
    weighted = []
    for datapoint in dataset:
        weighted.append([w * v for (w, v) in zip(weights, datapoint)])
    return weighted
def distance(fruit1, fruit2):
    """Euclidean distance between two 3-feature datapoints.

    Labels (if any) beyond index 2 are ignored.
    """
    # per-feature differences
    a = fruit1[0] - fruit2[0]
    b = fruit1[1] - fruit2[1]
    # BUG FIX: this line previously read fruit1[0] - fruit2[0] again,
    # double-counting the first feature and ignoring the third.
    c = fruit1[2] - fruit2[2]
    # the distance from point A (fruit1) to point B (fruit2)
    d = (a ** 2 + b ** 2 + c ** 2) ** 0.5
    return d
def strip_labels(ds):
    """Drop the trailing label from every datapoint in `ds`."""
    stripped = []
    for datapoint in ds:
        stripped.append(datapoint[:-1])
    return stripped
def get_label_at_index(i, ds=None):
    """Return the label (last element) of datapoint `i` in `ds`.

    `ds` defaults to the module-level `dataset`, preserving the
    original call signature; passing `ds` explicitly removes the
    hard dependency on the global.
    """
    if ds is None:
        ds = dataset
    return ds[i][-1]
def readd_labels(ds):
    """Re-attach labels to `ds`, matching rows by position against the
    module-level `dataset` (so `ds` must preserve its original order)."""
    labeled = []
    for i, datapoint in enumerate(ds):
        labeled.append(datapoint + [get_label_at_index(i)])
    return labeled
def prepare_dataset(ds):
    """Normalize `ds` to [0, 1] per column, then apply the feature
    weights (weight=1.0, color=0.5, seeds=2.0)."""
    normalized = normalize_dataset(ds)
    weighted = weight_dataset(normalized, [1.0, 0.5, 2.0])
    return weighted
def classify(uk):
    """Classify the unknown datapoint `uk` by majority vote among its
    K nearest neighbours in the module-level `dataset`."""
    unlabeled = strip_labels(dataset)
    prepared = readd_labels(prepare_dataset(unlabeled))
    # this is a hack to normalize and weight the unknown (so it will be on
    # roughly the same scale as the dataset's datapoints) it only falls over if
    # the unknown is larger than the max or smaller than the min of the dataset.
    prepped_uk = prepare_dataset(unlabeled + [uk])[-1]
    by_closeness = sorted(prepared, key=lambda fruit: distance(fruit, prepped_uk))
    neighbors = by_closeness[:K]
    votes = Counter(datapoint[-1] for datapoint in neighbors)
    return max(votes, key=lambda cls: votes[cls])
# Three unknown fruits to classify: weight, color code, # seeds.
uk1 = [301, 4, 1]
uk2 = [346, 3, 4]
uk3 = [290, 1, 2]

# Expected: banana
# NOTE: `print x` is Python 2-only syntax; `print(x)` with a single
# argument behaves identically on both Python 2 and Python 3.
print(classify(uk1))
# Expected: apple
print(classify(uk2))
# Expected: banana
print(classify(uk3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment