Skip to content

Instantly share code, notes, and snippets.

@YoiTaka
Last active December 29, 2016 02:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YoiTaka/0320d2f6a6dd30f37042af3ff719b557 to your computer and use it in GitHub Desktop.
Save YoiTaka/0320d2f6a6dd30f37042af3ff719b557 to your computer and use it in GitHub Desktop.
TensorFlowチュートリアル -MNIST For ML Beginners
#TensorFlowのインポート
import tensorflow as tf
#数値計算用ライブラリのNumpuのインポート
import numpy as np
#tensorflow.contrib.learn.python.learn.datasets.mnist.pyの関数read_data_setsをインポート
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
#MNISTデータの読み込み
mnist = read_data_sets("MNIST_data/", one_hot=True)
#入力用変数(データ数:Nne 入力データ:784個)
x = tf.placeholder("float", [None, 784])
#重み用変数(0で初期化)
W = tf.Variable(tf.zeros([784, 10]))
#バイアス用変数(0で初期化)
b = tf.Variable(tf.zeros([10]))
#出力(xとWの内積を計算後にバイアスを足した後にソフトマックス回帰を実行)
#ソフトマックス関数を用いることで出力の合計値が1になる(=確率として表せる)
y = tf.nn.softmax(tf.matmul(x, W) + b)
#正解ラベル用の変数(10個のうち1つの値が1で残りが0)
y_ = tf.placeholder("float", [None, 10])
#交差エントロピー誤差を計算(正解ラベルに対応するyの値が大きいほど0に近くなる)
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
#勾配降下法を用いて交差エントロピー誤差を最小化
#学習率は0.001(ハイパーパラメータ)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#セッションで作成したモデルを開始
sess = tf.Session()
#初期化
#tf.initialize_all_variables() -> tf.global_variables_initializer()
sess.run(tf.global_variables_initializer())
#1000回の学習開始
for i in range(1000):
#バッチ処理(trainから100個をランダムに取得)
batch_xs, batch_ys = mnist.train.next_batch(100)
#feet_dictではxとy_のplaceholderに値を置き換え
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#モデルを評価
#argmax(y,1) : yの中で一番大きな値のインデックスを計算
#tf.equal(,) : 値が一致していればTrue、していなければFalse
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
#tf.cast() : bool型をfloat型にキャスト(True =1, False =0)
#tf.reduce_mean() : 平均値を計算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#精度計算(=MnistのTestを用いる)
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment