Skip to content

Instantly share code, notes, and snippets.

@RustyNail
Created December 12, 2017 22:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RustyNail/4064ae24f864530258447a37eca9915a to your computer and use it in GitHub Desktop.
Save RustyNail/4064ae24f864530258447a37eca9915a to your computer and use it in GitHub Desktop.
mnist_tutorial.py
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # TensorFlowをインポート
from tensorflow.examples.tutorials.mnist import input_data # MNISTデータのロード
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
train_image = tf.placeholder(tf.float32, [None, 784]) # 28x28pxの訓練画像を格納する変数
train_label = tf.placeholder(tf.float32, [None, 10]) # 正解データのラベルを格納する変数
W = tf.Variable(tf.zeros([784, 10])) # 重み(初期値0)
b = tf.Variable(tf.zeros([10])) # バイアス(初期値0)
y = tf.nn.softmax(tf.matmul(train_image, W) + b) # ソフトマックス回帰を実行
learning_rate = 0.01 # 学習率
learning_count = 1000 # 学習回数
cross_entropy = -tf.reduce_sum(train_label * tf.log(y)) # 交差エントロピー誤差
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy) # 勾配降下法で交差エントロピー誤差が最小となるようyを最適化
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
for i in range(learning_count):
batch_image, batch_label = mnist.train.next_batch(100) # ランダムに抽出した100個の訓練データ(画像と対応するラベル)を選択
sess.run(train_step, feed_dict = { train_image: batch_image, train_label: batch_label }) # 学習(train_stepを実行)
# 正答率を算出する処理
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(train_label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
test_images = mnist.test.images
test_labels = mnist.test.labels
print(sess.run(accuracy, feed_dict = { train_image: test_images, train_label: test_labels }))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment