Skip to content

Instantly share code, notes, and snippets.

@matsuken92
Last active August 29, 2015 14:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matsuken92/ad6d48574b92b54cefb6 to your computer and use it in GitHub Desktop.
Save matsuken92/ad6d48574b92b54cefb6 to your computer and use it in GitHub Desktop.
手書き数字をpythonでもてあそぶ その2(識別する) ref: http://qiita.com/kenmatsu4/items/2d21466078917c200033
C = \{0, 1, 2, 3, 4, 5, 6, 7, 8, 9\}
y_{i,k} = (y_{i,k,1}, y_{i,k,2},...,y_{i,k,784})       (i=0,1,...,9;\ k=1,...,n_i)
\hat{y}_i = \frac{1}{n_i}\sum_{k=1}^{n_i} y_{i,k}
x_j = (x_{j,1}, x_{j,2},...,x_{j,784})
{\rm argmin}_i{({\rm distance}_i)} = {\rm argmin}_i{(\|\hat{y}_i - x_j\|)}
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict
# Running per-label sum of image vectors, filled in by add_image().
image_dict = dict()


def add_image(label, image_vector):
    """Add ``image_vector`` to the running sum kept for ``label``.

    The sum is accumulated in float64 regardless of the input dtype, so
    narrow integer pixel data (e.g. uint8) cannot wrap around, and mixing
    integer and float vectors for the same label does not raise an
    in-place casting error.

    Args:
        label: Hashable key identifying the class of the image.
        image_vector: Array-like of pixel values; every vector added for
            the same label must have the same shape.

    Returns:
        The module-level ``image_dict`` mapping each label to its sum.
    """
    # np.array() always copies, so the in-place ``+=`` below can never
    # modify the caller's data.
    vec = np.array(image_vector, dtype=np.float64)
    if label in image_dict:
        image_dict[label] += vec
    else:
        image_dict[label] = vec
    return image_dict
# Number of rows seen so far for each label, filled in by count_label().
label_dd = defaultdict(int)


def count_label(data):
    """Tally the first element of every row in ``data`` into ``label_dd``.

    Counts accumulate across calls because the tally lives in the
    module-level ``label_dd``.

    Args:
        data: Iterable of indexable rows; element 0 of each row is used
            as the label.

    Returns:
        The module-level ``label_dd`` mapping each label to its count.
    """
    for row in data:
        label = row[0]
        label_dd[label] = label_dd[label] + 1
    return label_dd
def plot_digits(X, Y, Z, size_x, size_y, counter, title):
    """Draw one image as a grayscale pcolor plot in a subplot grid.

    Args:
        X: X coordinates passed to ``plt.pcolor``.
        Y: Y coordinates passed to ``plt.pcolor``.
        Z: 2-D array of values to draw.
        size_x: Number of rows in the subplot grid.
        size_y: Number of columns in the subplot grid.
        counter: 1-based index of the subplot to draw into.
        title: Title shown above the subplot.
    """
    plt.subplot(size_x, size_y, counter)
    plt.title(title)
    # Axis limits are fixed to 0..27, matching the 28-wide meshgrid the
    # script passes in.
    plt.xlim(0, 27)
    plt.ylim(0, 27)
    plt.pcolor(X, Y, Z)
    plt.gray()
    # Hide tick labels. These must be real booleans: the legacy string
    # "off" is truthy, so current Matplotlib would still draw the labels.
    plt.tick_params(labelbottom=False, labelleft=False)
size = 28  # each image is reshaped to size x size below

# Skip the header row. Column 0 is used as the label and the following
# size * size columns as pixel values. ndmin=2 keeps the array 2-D even
# when the file holds a single data row.
raw_data = np.loadtxt('train_master.csv', delimiter=',', skiprows=1, ndmin=2)

# draw digit images
plt.figure(figsize=(11, 6))

# data aggregation: per-label pixel sums and per-label row counts
for row in raw_data:
    add_image(row[0], row[1:size * size + 1])
count_dict = count_label(raw_data)

# Dict holding the representative (mean) image of every label.
standardized_digit_dict = dict()

# The plotting grid is the same for every image, so build it once.
X, Y = np.meshgrid(range(size), range(size))

# Iterate labels in sorted order so the panels appear in ascending label
# order instead of the order in which labels first occur in the file.
# NOTE(review): the 2 x 5 grid only has room for ten labels.
for count, key in enumerate(sorted(image_dict), start=1):
    mean_image = image_dict[key].reshape(size, size) / count_dict[key]
    # Reverse the row order before storing; the test images are reversed
    # the same way before being compared against these means.
    Z = mean_image[::-1, :]
    standardized_digit_dict[int(key)] = Z
    plot_digits(X, Y, Z, 2, 5, count, "")
plt.show()
# Skip the header row; every remaining row is reshaped to size x size
# below. ndmin=2 keeps the array 2-D even for a single-row file.
test_data = np.loadtxt('test_small.csv', delimiter=',', skiprows=1, ndmin=2)

# compare 1 tested digit vs average digits with norm
plt.figure(figsize=(10, 9))

# The plotting grid is the same for every panel, so build it once.
X, Y = np.meshgrid(range(size), range(size))

for i in range(1):  # try only the first test sample
    # Reverse the row order exactly as was done for the mean images.
    Z = test_data[i].reshape(size, size)[::-1, :]
    flat_Z = Z.flatten()
    plot_digits(X, Y, Z, 3, 4, 1, "tested")

    # Panels 2.. show each mean image in ascending label order.
    # NOTE(review): the 3 x 4 grid has room for the sample plus at most
    # eleven labels.
    for count, key in enumerate(sorted(standardized_digit_dict), start=1):
        X1 = standardized_digit_dict[key]
        # Euclidean distance between this mean image and the sample
        # being classified.
        norm = np.linalg.norm(X1.flatten() - flat_Z)
        plot_digits(X, Y, X1, 3, 4, 1 + count, "d=%.3f" % norm)
plt.show()
# recognize digits
# Size the subplot grid from the number of test samples rather than
# hard-coding 40 rows, which fails as soon as there are more than 200
# samples. With exactly 200 samples this yields the original 40 x 5 grid
# and the original (15, 130) figure size.
n_cols = 5
n_rows = max(1, -(-len(test_data) // n_cols))  # ceiling division
plt.figure(figsize=(15, 3.25 * n_rows))

# The plotting grid is the same for every panel, so build it once.
X, Y = np.meshgrid(range(size), range(size))

for i, sample in enumerate(test_data):
    # Reverse the row order exactly as was done for the mean images.
    tested = sample.reshape(size, size)[::-1, :]
    flat_tested = tested.flatten()

    # (label, Euclidean distance to that label's mean image) pairs.
    norm_list = [
        (key, np.linalg.norm(sdd.flatten() - flat_tested))
        for key, sdd in standardized_digit_dict.items()
    ]
    # The predicted label is the one whose mean image is nearest; on a
    # tie the first pair wins.
    min_result = min(norm_list, key=lambda pair: pair[1])

    plot_digits(X, Y, tested, n_rows, n_cols, i + 1,
                "l=%d, n=%d" % (min_result[0], min_result[1]))
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment