Skip to content

Instantly share code, notes, and snippets.

@koshian2
Last active June 1, 2018 23:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/30cdbefb878ee834b952803fc7501206 to your computer and use it in GitHub Desktop.
Save koshian2/30cdbefb878ee834b952803fc7501206 to your computer and use it in GitHub Desktop.
CIFAR-10
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def convolution_filter(img):
    """Return the per-channel Sobel gradient magnitude of an image.

    Every 3x3 window of the first three channels of ``img`` is multiplied
    element-wise with a vertical and a horizontal Sobel kernel. The two
    responses are summed over the window for each channel separately and
    combined as ``sqrt(v**2 + h**2)``. No padding is applied ("valid"
    convolution), so the output loses ``kernel_size - 1`` rows and columns.

    Args:
        img: Array indexed as ``[row, column, channel]``; only channels
            ``0:3`` are read.

    Returns:
        Float array of shape ``(rows - 2, columns - 2, 3)``.
    """
    # Vertical-direction filter
    sobel_v = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
    # Horizontal-direction filter
    sobel_h = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])
    # Kernel size
    kernel_size = 3
    # A trailing axis makes each (3, 3) kernel broadcast over the channel
    # axis of a (3, 3, channels) window instead of over (column, channel).
    kernel_v = sobel_v[:, :, np.newaxis]
    kernel_h = sobel_h[:, :, np.newaxis]
    # Output image
    out_rows = img.shape[0] - kernel_size + 1
    out_cols = img.shape[1] - kernel_size + 1
    out_img = np.zeros((out_rows, out_cols, 3))
    # Output pixel (i, j) comes from the window whose top-left corner is
    # (i, j), so every output position is filled and every window is used.
    for i in range(out_rows):
        for j in range(out_cols):
            # Slice
            img_slice = img[i:i + kernel_size, j:j + kernel_size, 0:3]
            # Convolution: reduce over the two spatial axes, keep channels
            conv_v = np.sum(img_slice * kernel_v, axis=(0, 1))
            conv_h = np.sum(img_slice * kernel_h, axis=(0, 1))
            # Assignment
            out_img[i, j, :] = np.sqrt(conv_v**2 + conv_h**2)
    return out_img
# Convolve the image three times
def conv_sample(img):
    """Plot the original image and three successive convolution results.

    The four stages are drawn in a 2x2 grid; each subplot title shows the
    stage name followed by the shape of the displayed array. The figure is
    shown with ``plt.show()``.

    Args:
        img: Image array passed to ``convolution_filter`` between stages.
    """
    stage_names = ("Original", "1st conv", "2nd conv", "3rd conv")
    figure = plt.figure(figsize=(8, 8))
    figure.subplots_adjust(hspace=0.2, wspace=0.2)
    current = img
    for position, stage_name in enumerate(stage_names, start=1):
        axes = figure.add_subplot(2, 2, position)
        # Scale adjustment for display: divide by the largest value.
        scaled = current / np.max(current)
        axes.imshow(scaled)
        axes.set_title(f"{stage_name} {scaled.shape}")
        # Convolve for the next stage; nothing follows the last subplot.
        if position < len(stage_names):
            current = convolution_filter(current)
    plt.show()
# Pool the image (padding is not implemented)
def pooling(img):
    """Max-pool an image with a 3x3 window and a stride of 3.

    Each non-overlapping 3x3 window of the first three channels is reduced
    to its per-channel maximum. Trailing rows and columns that do not fill a
    complete window are dropped because no padding is applied. The input and
    output shapes are printed before pooling.

    Args:
        img: Array indexed as ``[row, column, channel]``; only channels
            ``0:3`` are read.

    Returns:
        Float array of shape ``(rows // 3, columns // 3, 3)``.
    """
    # Pool with a 3x3 window; the stride is also 3
    kernel_size = 3
    out_rows = img.shape[0] // kernel_size
    out_cols = img.shape[1] // kernel_size
    out_img = np.zeros((out_rows, out_cols, 3))
    print(img.shape, out_img.shape)
    for i in range(out_rows):
        top = i * kernel_size
        for j in range(out_cols):
            left = j * kernel_size
            img_slice = img[top:top + kernel_size, left:left + kernel_size, 0:3]
            # Reduce over the two spatial axes so one maximum per channel
            # remains (axes (1, 2) would yield one maximum per window row).
            out_img[i, j, :] = np.max(img_slice, axis=(0, 1))
    return out_img
# Convolve and pool the image three times
def conv_pool_sample(img):
    """Plot the original image and three successive convolution+pooling results.

    The four stages are drawn in a 2x2 grid; each subplot title shows the
    stage name followed by the shape of the displayed array. Between stages
    the image goes through ``convolution_filter`` and then ``pooling``. The
    figure is shown with ``plt.show()``.

    Args:
        img: Image array passed to ``convolution_filter`` and ``pooling``.
    """
    stage_names = ("Original", "1st conv-pool", "2nd conv-pool", "3rd conv-pool")
    figure = plt.figure(figsize=(8, 8))
    figure.subplots_adjust(hspace=0.2, wspace=0.2)
    current = img
    for position, stage_name in enumerate(stage_names, start=1):
        axes = figure.add_subplot(2, 2, position)
        # Scale adjustment for display: divide by the largest value.
        scaled = current / np.max(current)
        axes.imshow(scaled)
        axes.set_title(f"{stage_name} {scaled.shape}")
        # Nothing follows the last subplot, so stop transforming there.
        if position == len(stage_names):
            continue
        # Convolution followed by pooling
        current = pooling(convolution_filter(current))
    plt.show()
if __name__ == "__main__":
    # Load the image; the context manager closes the file handle as soon as
    # the pixel data has been copied into the array.
    with Image.open("lenna.png") as image_file:
        img = np.array(image_file)
    # Convolution sample
    conv_sample(img)
    # Convolution + pooling sample (uncomment to run)
    #conv_pool_sample(img)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment