Created
August 21, 2017 10:53
-
-
Save jeongukjae/d519eaac389abde660a894a25bae0d3e to your computer and use it in GitHub Desktop.
image classifier using nearest neighbor and cifar 10 dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Image classifier using nearest neighbor on the CIFAR-10 dataset
import pickle

import numpy as np
# Sort a dictionary by key.
def sort_key(dictionary):
    """Return a new dict with the entries of ``dictionary`` in ascending key order.

    The input dictionary is left untouched; the values are shared with it,
    not copied.
    """
    return {key: dictionary[key] for key in sorted(dictionary)}
# Get batches from files.
def get_batches(files, prefix=''):
    """Load pickled batch file(s) and group their data rows by label.

    Args:
        files: A single file name (``str``) or a list/tuple of file names.
        prefix: String prepended to every file name before it is opened.

    Returns:
        A dict mapping each label found under ``b'labels'`` to the list of
        matching entries from ``b'data'``, with keys in ascending order.
        Rows from several files are merged in the order the files are given.

    Raises:
        TypeError: If ``files`` is neither a string nor a list/tuple.
    """
    # Normalise both accepted input shapes to one list so the loading logic
    # exists only once.
    if isinstance(files, str):
        names = [files]
    elif isinstance(files, (list, tuple)):
        names = files
    else:
        raise TypeError(
            "files must be a str or a list/tuple of str, got %s"
            % type(files).__name__
        )

    grouped = {}
    for name in names:
        # NOTE: pickle.load can execute arbitrary code while loading; only
        # use this on batch files from a trusted source.
        with open(prefix + name, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        for label, row in zip(batch[b'labels'], batch[b'data']):
            grouped.setdefault(label, []).append(row)
    return sort_key(grouped)
# Manhattan distance
def mdistance(x, y):
    """Return the Manhattan (L1) distance between ``x`` and ``y``.

    This is the sum of the absolute element-wise differences. The absolute
    value is taken per element, before summing, so that differences of
    opposite sign do not cancel each other out.

    Args:
        x: Array-like of numbers.
        y: Array-like of numbers, broadcastable against ``x``.

    Returns:
        The distance as a non-negative NumPy float.
    """
    # Widen to float64 first: subtracting unsigned integer arrays (e.g. uint8
    # pixel values) would otherwise wrap around instead of going negative.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sum(np.abs(diff))
# Euclidean distance
def edistance(x, y):
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    This is the square root of the sum of squared element-wise differences.

    Args:
        x: Array-like of numbers.
        y: Array-like of numbers, broadcastable against ``x``.

    Returns:
        The distance as a non-negative NumPy float.
    """
    # Widen to float64 first: with small unsigned integer dtypes (e.g. uint8
    # pixel values) the subtraction wraps around and the square overflows.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sqrt(np.sum(diff ** 2))
# Predict function
def predict(data, test, distance=mdistance):
    """Return the label of the training sample closest to ``test``.

    Args:
        data: Mapping of label -> iterable of training samples.
        test: The sample to classify.
        distance: Callable ``distance(sample, test)`` returning a comparable
            value; smaller means closer.

    Returns:
        The label whose sample has the smallest distance to ``test``. On ties
        the first sample encountered wins. Returns ``0`` when ``data``
        contains no samples at all.
    """
    best_label = 0
    # None means "nothing compared yet". A numeric sentinel such as -1 could
    # collide with a real distance value and would need an identity check.
    best_distance = None
    for label in data:
        for sample in data[label]:
            d = distance(sample, test)
            if best_distance is None or d < best_distance:
                best_label = label
                best_distance = d
    return best_label
if __name__ == "__main__":
    # Training set: the five CIFAR-10 data batches, grouped by label.
    files = ['1', '2', '3', '4', '5']
    images = get_batches(files, prefix='cifar-10-batches-py/data_batch_')

    # Test set, grouped by label in the same way.
    test_images = get_batches('test_batch', prefix='cifar-10-batches-py/')

    # Number of test images checked per label; with 10 labels (0 ~ 9) this
    # gives 100 images in total. Edit this value to adjust the number of
    # test images.
    images_per_label = 10

    result = []
    for test_image_index in test_images:
        for batch in test_images[test_image_index][:images_per_label]:
            label = predict(images, batch, distance=edistance)
            # Compare labels by value; identity (`is`) on ints only works by
            # accident of the interpreter's small-integer caching.
            result.append(label == test_image_index)
            print("predict : %d, answer : %d" % (label, test_image_index))

    # Print result: the mean of the hit/miss flags is the accuracy.
    result_np = np.array(result, dtype='float32')
    print("Average : %f" % np.mean(result_np))
# Result
# Number of test images: 100
#
# Using the Manhattan distance function
# ---
# Average : 0.200000
# ---
#
# Using the Euclidean distance function
# ---
# Average : 0.270000
# ---
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment