Last active
December 29, 2016 02:28
-
-
Save YoiTaka/0320d2f6a6dd30f37042af3ff719b557 to your computer and use it in GitHub Desktop.
TensorFlowチュートリアル -MNIST For ML Beginners
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#TensorFlowのインポート | |
import tensorflow as tf | |
#数値計算用ライブラリのNumpuのインポート | |
import numpy as np | |
#tensorflow.contrib.learn.python.learn.datasets.mnist.pyの関数read_data_setsをインポート | |
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets | |
#MNISTデータの読み込み | |
mnist = read_data_sets("MNIST_data/", one_hot=True) | |
#入力用変数(データ数:Nne 入力データ:784個) | |
x = tf.placeholder("float", [None, 784]) | |
#重み用変数(0で初期化) | |
W = tf.Variable(tf.zeros([784, 10])) | |
#バイアス用変数(0で初期化) | |
b = tf.Variable(tf.zeros([10])) | |
#出力(xとWの内積を計算後にバイアスを足した後にソフトマックス回帰を実行) | |
#ソフトマックス関数を用いることで出力の合計値が1になる(=確率として表せる) | |
y = tf.nn.softmax(tf.matmul(x, W) + b) | |
#正解ラベル用の変数(10個のうち1つの値が1で残りが0) | |
y_ = tf.placeholder("float", [None, 10]) | |
#交差エントロピー誤差を計算(正解ラベルに対応するyの値が大きいほど0に近くなる) | |
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) | |
#勾配降下法を用いて交差エントロピー誤差を最小化 | |
#学習率は0.001(ハイパーパラメータ) | |
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) | |
#セッションで作成したモデルを開始 | |
sess = tf.Session() | |
#初期化 | |
#tf.initialize_all_variables() -> tf.global_variables_initializer() | |
sess.run(tf.global_variables_initializer()) | |
#1000回の学習開始 | |
for i in range(1000): | |
#バッチ処理(trainから100個をランダムに取得) | |
batch_xs, batch_ys = mnist.train.next_batch(100) | |
#feet_dictではxとy_のplaceholderに値を置き換え | |
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) | |
#モデルを評価 | |
#argmax(y,1) : yの中で一番大きな値のインデックスを計算 | |
#tf.equal(,) : 値が一致していればTrue、していなければFalse | |
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) | |
#tf.cast() : bool型をfloat型にキャスト(True =1, False =0) | |
#tf.reduce_mean() : 平均値を計算 | |
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) | |
#精度計算(=MnistのTestを用いる) | |
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment