Created
June 4, 2013 00:19
-
-
Save jiaaro/5702629 to your computer and use it in GitHub Desktop.
Solution to the homework in "Machine Learning for Humans: K Nearest-Neighbor" (http://www.jiaaro.com/KNN-for-humans/).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter

# Number of nearest neighbours consulted when voting on a classification.
K = 3

# Training data. Each row is [weight (g), colour code, seed count, label].
dataset = [
    [303, 3, 1, "banana"],
    [370, 1, 2, "apple"],
    [298, 3, 1, "banana"],
    [277, 3, 1, "banana"],
    [377, 4, 2, "apple"],
    [299, 3, 1, "banana"],
    [382, 1, 2, "apple"],
    [374, 4, 6, "apple"],
    [303, 4, 1, "banana"],
    [309, 3, 1, "banana"],
    [359, 1, 2, "apple"],
    [366, 1, 4, "apple"],
    [311, 3, 1, "banana"],
    [302, 3, 1, "banana"],
    [373, 4, 4, "apple"],
    [305, 3, 1, "banana"],
    [371, 3, 6, "apple"],
]
def normalize_dataset(dataset):
    """Rescale every column of `dataset` to the range [0, 1].

    Each column is shifted by its minimum and divided by its range so that
    features measured on different scales (grams vs. colour codes) become
    comparable in a Euclidean distance. Returns one tuple per original row.
    """
    data_columns = zip(*dataset)
    normalized_columns = []
    for column in data_columns:
        max_val = max(column)
        min_val = min(column)
        val_range = max_val - min_val
        if val_range == 0:
            # A constant column carries no information; map it to zeros
            # instead of dividing by zero (the original crashed here).
            normalized_columns.append([0.0 for _ in column])
        else:
            normalized_columns.append(
                [(float(val) - min_val) / val_range for val in column])
    # list() keeps the result indexable/re-iterable on Python 3, where zip
    # is lazy; on Python 2 it is a no-op copy.
    return list(zip(*normalized_columns))
def weight_dataset(dataset, weights):
    """Scale each column of `dataset` by the matching factor in `weights`.

    Returns a new list of rows; the input is not modified.
    """
    weighted = []
    for datapoint in dataset:
        weighted.append([w * v for (w, v) in zip(weights, datapoint)])
    return weighted
def distance(fruit1, fruit2):
    """Return the Euclidean distance between two 3-feature points.

    `fruit1` and `fruit2` are sequences of at least three numbers
    (weight, colour, seed count after normalization/weighting).
    """
    # Per-feature differences.
    a = fruit1[0] - fruit2[0]
    b = fruit1[1] - fruit2[1]
    # BUG FIX: the original computed fruit1[0] - fruit2[0] again here,
    # silently ignoring the third feature (seed count).
    c = fruit1[2] - fruit2[2]
    # Straight-line distance from point A (fruit1) to point B (fruit2).
    d = (a**2 + b**2 + c**2) ** 0.5
    return d
def strip_labels(ds):
    """Return a copy of `ds` with the trailing label dropped from each row."""
    stripped = []
    for row in ds:
        stripped.append(row[:-1])
    return stripped
def get_label_at_index(i):
    """Look up the label (last field) of row `i` in the global `dataset`."""
    row = dataset[i]
    return row[-1]
def readd_labels(ds):
    """Re-attach each row's original label from the global `dataset`.

    Assumes the rows of `ds` are in the same order as `dataset`.
    """
    labeled = []
    for i, row in enumerate(ds):
        labeled.append(row + [get_label_at_index(i)])
    return labeled
def prepare_dataset(ds):
    """Normalize `ds` to [0, 1], then apply per-feature importance weights."""
    normalized = normalize_dataset(ds)
    # A weight above 1.0 makes a feature count more toward the distance;
    # below 1.0, less. Order: weight, colour, seed count.
    return weight_dataset(normalized, [1.0, 0.5, 2.0])
def classify(uk):
    """Classify the unlabeled fruit `uk` by majority vote of its K nearest
    neighbours in the global `dataset`.
    """
    unlabeled = strip_labels(dataset)
    prepared = readd_labels(prepare_dataset(unlabeled))
    # Hack: normalize and weight the unknown together with the dataset so it
    # ends up on roughly the same scale as the prepared datapoints. This only
    # falls over when the unknown lies outside the dataset's min/max range.
    scaled_uk = prepare_dataset(unlabeled + [uk])[-1]
    by_closeness = sorted(prepared, key=lambda fruit: distance(fruit, scaled_uk))
    nearest = by_closeness[:K]
    votes = Counter(fruit[-1] for fruit in nearest)
    return votes.most_common(1)[0][0]
# Three unlabeled fruits to classify: [weight, colour, # seeds].
uk1 = [301, 4, 1]
uk2 = [346, 3, 4]
uk3 = [290, 1, 2]

# FIX: `print x` statements are Python-2-only syntax; `print(x)` with a
# single argument behaves identically on Python 2 and is valid Python 3.
# Expected: banana
print(classify(uk1))
# Expected: apple
print(classify(uk2))
# Expected: banana
print(classify(uk3))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment