Created
August 21, 2017 11:34
-
-
Save jeongukjae/2725e018a3d06cccf1a8d3bf8c14c477 to your computer and use it in GitHub Desktop.
image classifier using KNN algorithm and cifar 10 dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Image classifier using the k-nearest-neighbours (KNN) algorithm on the CIFAR-10 dataset
import pickle
from collections import Counter

import numpy as np
# sort dictionary by key
def sort_key(dictionary):
    """Return a new dict with the items of ``dictionary`` in ascending key order.

    The input mapping is not modified. The ordering of the result relies on
    dicts preserving insertion order.

    Args:
        dictionary: mapping whose keys are mutually comparable.

    Returns:
        A new ``dict`` with the same key/value pairs, inserted in sorted key
        order.
    """
    return {key: dictionary[key] for key in sorted(dictionary)}
# get batches from files
def get_batches(files, prefix=''):
    """Load pickled batch files and group their samples by label.

    Every file must unpickle to a mapping that has the byte-string keys
    ``b'labels'`` and ``b'data'``; entry ``i`` of ``b'data'`` is the sample
    that belongs to entry ``i`` of ``b'labels'``.

    Args:
        files: a single file name (``str``) or a list/tuple of file names.
        prefix: string prepended to every file name before it is opened.

    Returns:
        A dict mapping each label to the list of its samples, in the order
        they were read, with keys inserted in ascending label order.

    Raises:
        TypeError: if ``files`` is neither a string nor a list/tuple.
    """
    if isinstance(files, str):
        names = [files]
    elif isinstance(files, (list, tuple)):
        names = list(files)
    else:
        raise TypeError(
            "files must be a str or a list/tuple of str, got %s"
            % type(files).__name__
        )

    grouped = {}
    for name in names:
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this on batch files that come from a trusted source.
        with open(prefix + name, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        for label, sample in zip(batch[b'labels'], batch[b'data']):
            grouped.setdefault(label, []).append(sample)

    # Return the groups in ascending label order.
    return {label: grouped[label] for label in sorted(grouped)}
# Manhattan distance
def mdistance(x, y):
    """Return the Manhattan (L1) distance between ``x`` and ``y``.

    Both inputs are converted to float64 before subtracting, so unsigned
    integer arrays do not wrap around on negative differences. The absolute
    value is taken element-wise *before* summing, so positive and negative
    differences cannot cancel each other out.

    Args:
        x: array-like of numbers.
        y: array-like of numbers, broadcastable against ``x``.

    Returns:
        The sum of ``|x_i - y_i|`` as a NumPy float64 scalar.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sum(np.abs(diff))
# Euclidean distance
def edistance(x, y):
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    Both inputs are converted to float64 before subtracting, so unsigned
    8-bit arrays neither wrap around on subtraction nor overflow when the
    differences are squared.

    Args:
        x: array-like of numbers.
        y: array-like of numbers, broadcastable against ``x``.

    Returns:
        ``sqrt(sum((x_i - y_i) ** 2))`` as a NumPy float64 scalar.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sqrt(np.sum(diff ** 2))
# function that returns majority of a list
# Example : [1,2,1,1] -> 1
def get_majority(data):
    """Return the most frequent item of ``data``.

    Items are counted by equality (``==``), not by object identity, so equal
    values held in distinct objects (large ints, strings, NumPy scalars) are
    counted together.

    Args:
        data: iterable of hashable items.

    Returns:
        The item with the highest count, or ``None`` when ``data`` is empty.
        When several items share the highest count, the one met first while
        iterating over ``set(data)`` wins.
    """
    items = list(data)
    counts = Counter(items)
    # max() keeps the first maximal candidate it sees, which gives the
    # "first one in set order wins" tie-break described above.
    return max(set(items), key=counts.__getitem__, default=None)
# predict function
def predict(data, test, distance=mdistance, k=None):
    """Predict the label of ``test`` by a vote among its nearest samples.

    Args:
        data: dict mapping each label to an iterable of training samples.
        test: the sample to classify; passed as the second argument to
            ``distance``.
        distance: callable ``distance(sample, test)`` returning a comparable
            number; smaller means closer.
        k: number of nearest samples that take part in the vote. ``None``
            (the default) uses the module-level ``K``, looked up at call time.

    Returns:
        The label chosen by ``get_majority`` among the ``k`` nearest samples
        (``None`` when ``data`` holds no samples).

    Raises:
        ValueError: if the effective ``k`` is smaller than 1.
    """
    if k is None:
        k = K
    if k < 1:
        raise ValueError("k must be a positive integer, got %r" % (k,))

    nearest_labels = []
    nearest_distances = []
    for label, samples in data.items():
        for sample in samples:
            d = distance(sample, test)
            # Accept every sample until k candidates have been collected.
            if len(nearest_labels) < k:
                nearest_labels.append(label)
                nearest_distances.append(d)
                continue
            # Afterwards, replace the farthest candidate only when the new
            # sample is strictly closer.
            worst = int(np.argmax(nearest_distances))
            if nearest_distances[worst] > d:
                nearest_labels[worst] = label
                nearest_distances[worst] = d
    return get_majority(nearest_labels)
# Hyperparameters
K = 5  # number of nearest samples that vote in predict()
D = edistance  # distance function handed to predict() by the script below
if __name__ == "__main__":
    # Number of test images evaluated per label.
    # 10 labels (0 ~ 9) * 10 images = 100 predictions.
    # Edit this value to adjust the number of test images.
    images_per_label = 10

    # get batches == train
    files = ['1', '2', '3', '4', '5']
    images = get_batches(files, prefix='cifar-10-batches-py/data_batch_')
    # get test data
    test_images = get_batches('test_batch', prefix='cifar-10-batches-py/')

    result = []
    for answer, samples in test_images.items():
        for sample in samples[:images_per_label]:
            label = predict(images, sample, distance=D)
            # Compare by value: `is` checks object identity, which is not a
            # reliable way to compare integer labels.
            result.append(label == answer)
            print("predict : %d, answer : %d" % (label, answer))

    # print result
    result_np = np.array(result, dtype='float32')
    print("Average : %f" % np.mean(result_np))
# Results recorded by the author
# Number of test images : 100
#
# K = 3
# using the Manhattan distance function
# ---
# Average : 0.250000
# ---
#
# using the Euclidean distance function
# ---
# Average : 0.220000
# ---
#
# K = 5
# using the Manhattan distance function
# ---
# Average : 0.250000
# ---
#
# using the Euclidean distance function
# ---
# Average : 0.200000
# ---
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment