Skip to content

Instantly share code, notes, and snippets.

@oishimilk
Last active June 15, 2021 15:45
Show Gist options
  • Save oishimilk/abe2b0dc79266034acdf64da6a805e06 to your computer and use it in GitHub Desktop.
Save oishimilk/abe2b0dc79266034acdf64da6a805e06 to your computer and use it in GitHub Desktop.
史上最悪レベルの畳み込み
#!/usr/bin/env python
"""
生み出してしまった史上最悪レベルの畳み込みです。
できるだけスカラーを扱うようにするという制約の下で作成しています。
高速化の真逆のことをやっています。
"""
import numpy as np
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
def conv2d(input_tensor: np.ndarray, kernel: np.ndarray, strides: int = 1, padding: int = 0) -> np.ndarray:
    """
    Compute a 2-D convolution one scalar multiply-add at a time.

    The kernel is applied without flipping (i.e. as a cross-correlation).
    The loops are intentionally kept scalar; this routine is written for
    clarity of the arithmetic, not for speed.

    Args:
        input_tensor: 4-D array indexed as (batch, height, width, in_channels).
        kernel: 4-D array indexed as
            (kernel_height, kernel_width, in_channels, out_channels).
            Assumed to have the same in_channels count as ``input_tensor``;
            this is not verified here.
        strides: Step between consecutive windows, used for both spatial axes.
        padding: Number of zero rows/columns added on every spatial side.

    Returns:
        Array of shape (batch, out_height, out_width, out_channels) created
        with ``np.zeros`` (NumPy's default float dtype).

    Raises:
        ValueError: If the unpadded input height/width is not a multiple of
            the kernel height/width, or if a kernel spatial size is even.
    """
    # NOTE(review): the divisibility requirement is stricter than the
    # output-size formula below needs; it is kept so callers keep seeing the
    # same exceptions as before.
    if input_tensor.shape[1] % kernel.shape[0] or input_tensor.shape[2] % kernel.shape[1]:
        raise ValueError("割れません。")

    if not kernel.shape[0] % 2 or not kernel.shape[1] % 2:
        raise ValueError("カーネルの大きさが偶数です。")

    # Output size
    oh = (input_tensor.shape[1] + 2 * padding - kernel.shape[0]) // strides + 1
    ow = (input_tensor.shape[2] + 2 * padding - kernel.shape[1]) // strides + 1
    result = np.zeros([input_tensor.shape[0], oh, ow, kernel.shape[3]])

    # Zero padding: copy the input into the centre of a larger zero array.
    if padding > 0:
        padded = np.zeros(
            [
                input_tensor.shape[0],
                input_tensor.shape[1] + 2 * padding,
                input_tensor.shape[2] + 2 * padding,
                input_tensor.shape[3],
            ],
            dtype=input_tensor.dtype,
        )
        padded[:, padding:-padding, padding:-padding, :] = input_tensor
        input_tensor = padded

    # Batch
    for b in range(result.shape[0]):
        # Kernel input channel
        for c in range(kernel.shape[2]):
            # Output row (height axis). The row index must run over
            # result.shape[1]; iterating it over the width axis breaks every
            # non-square output.
            for i in range(result.shape[1]):
                # Output column (width axis)
                for j in range(result.shape[2]):
                    # Kernel column
                    for w in range(kernel.shape[1]):
                        # Kernel row
                        for h in range(kernel.shape[0]):
                            # Kernel output channel
                            for m in range(kernel.shape[3]):
                                result[b, i, j, m] += (
                                    input_tensor[b, strides * i + h, strides * j + w, c]
                                    * kernel[h, w, c, m]
                                )

    return result
def maxpool2d(input_tensor: np.ndarray, pool_size: int = 2, strides: int = 2) -> np.ndarray:
    """
    Apply max pooling over the two spatial axes of an array.

    Args:
        input_tensor: Array indexed as (batch, height, width, channels).
        pool_size: Side length of the square pooling window.
        strides: Step between consecutive windows, used for both spatial axes.

    Returns:
        Array of shape (batch, out_height, out_width, channels) created with
        ``np.zeros``, where each entry is the maximum of the corresponding
        window.
    """
    shape = input_tensor.shape

    # Number of window positions that fit along each spatial axis.
    out_rows = (shape[1] - pool_size) // strides + 1
    out_cols = (shape[2] - pool_size) // strides + 1
    pooled = np.zeros([shape[0], out_rows, out_cols, shape[3]])

    # Visit every (batch, channel, output row, output column) combination;
    # the last index varies fastest, matching a nested-loop traversal.
    for n, ch, row, col in np.ndindex(shape[0], shape[3], out_rows, out_cols):
        top = row * strides
        left = col * strides
        window = input_tensor[n, top:top + pool_size, left:left + pool_size, ch]
        pooled[n, row, col, ch] = window.max()

    return pooled
def main() -> None:
    """
    Compare the hand-written conv2d / maxpool2d with TensorFlow.

    Random data is pushed through both implementations; for each operation
    the TensorFlow result, the hand-written result and their element-wise
    difference are printed.
    """
    # Random test data
    sample = np.random.rand(1, 6, 6, 3)  # batch, height, width, channels
    weights = np.random.rand(3, 3, 3, 3)  # filter_height, filter_width, in_channels, out_channels

    # Switch TensorFlow to graph mode; the calls are skipped if the loaded
    # module does not provide them (presumably an older TensorFlow release).
    try:
        tf.disable_eager_execution()
        tf.disable_v2_behavior()
    except AttributeError:
        pass

    # Placeholders that receive the arrays at run time
    input_ph = tf.placeholder(dtype=tf.float32)
    kernel_ph = tf.placeholder(dtype=tf.float32)

    # Computation graph
    conv_op = tf.nn.conv2d(input_ph, filter=kernel_ph, strides=1, padding="VALID", data_format='NHWC')
    pool_op = tf.nn.max_pool2d(input_ph, 2, 2, "VALID")  # input, ksize, strides, padding

    with tf.Session() as sess:
        tf_conv = sess.run(conv_op, feed_dict={input_ph: sample, kernel_ph: weights})
        tf_pool = sess.run(pool_op, feed_dict={input_ph: sample})

    # Hand-written implementations
    own_conv = conv2d(sample, weights, strides=1, padding=0)
    own_pool = maxpool2d(sample, pool_size=2, strides=2)

    # Report: heading, both answers with their shapes, then the difference.
    comparisons = (
        ("畳み込み", tf_conv, own_conv),
        ("プーリング", tf_pool, own_pool),
    )
    for heading, reference, candidate in comparisons:
        print(heading)
        print("TensorFlow の答え: ", reference.shape)
        print(reference)
        print("自前実装の答え: ", candidate.shape)
        print(candidate)
        print("誤差")
        print(reference - candidate)
# Run the comparison demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment