Created April 10, 2017 00:15
Save golbin/4869992498fd2103d80856d95dabb8b8 to your computer and use it in GitHub Desktop.
MNIST with CNN + RNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python 3, TensorFlow 1.0
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from ./mnist/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)

#########
# Options
######
# MNIST images are 28x28 pixels. The RNN reads one image row per time step,
# so n_height is the number of RNN steps and n_width the input size per step.
n_height, n_width = 28, 28
n_output = 10          # digit classes 0-9
learning_rate = 0.001
#########
# CNN
######
def CNN2(input_X, keep_prob=0.8, is_training=True):
    """Build an alternative CNN from tf.contrib.layers (not called by this script).

    The network has two 3x3 conv + 2x2 max-pool stages, with dropout applied
    through the conv layers' ``normalizer_fn`` hook. These are followed by a
    batch-normalized 256-unit fully connected layer and an ``n_output``-unit
    output layer.

    Args:
        input_X: image batch; presumably (batch, 28, 28, 1) like ``CNN_X`` --
            TODO confirm.
        keep_prob: probability of keeping each unit in the conv dropout.
        is_training: python bool or bool tensor.
            - When true, dropout is active and batch norm uses batch statistics.
            - When false, dropout is the identity and batch norm uses its
              moving averages.
            The default (True) reproduces the original always-training behavior.

    Returns:
        Unscaled logits tensor with ``n_output`` units.
    """
    # tf.contrib.layers.dropout, unlike tf.nn.dropout, can be switched off at
    # evaluation time through ``is_training``.
    dropout_params = {'keep_prob': keep_prob, 'is_training': is_training}
    # updates_collections=None updates the batch-norm moving averages in place.
    # This script never runs tf.GraphKeys.UPDATE_OPS, so without it the
    # averages used when is_training is false would never be trained.
    batch_norm_params = {'is_training': is_training, 'updates_collections': None}

    L1 = tf.contrib.layers.conv2d(input_X, 32, [3, 3],
                                  normalizer_fn=tf.contrib.layers.dropout,
                                  normalizer_params=dropout_params)
    L2 = tf.contrib.layers.max_pool2d(L1, [2, 2])
    L3 = tf.contrib.layers.conv2d(L2, 64, [3, 3],
                                  normalizer_fn=tf.contrib.layers.dropout,
                                  normalizer_params=dropout_params)
    L4 = tf.contrib.layers.max_pool2d(L3, [2, 2])
    L5 = tf.contrib.layers.flatten(L4)
    L5 = tf.contrib.layers.fully_connected(L5, 256,
                                           normalizer_fn=tf.contrib.layers.batch_norm,
                                           normalizer_params=batch_norm_params)
    return tf.contrib.layers.fully_connected(L5, n_output)
def CNN(input_X, keep_prob, training):
    """Build the convolutional branch of the ensemble and return its logits.

    The network has three 3x3 conv (ReLU, 'same' padding) + 2x2 max-pool
    stages, then a dense ReLU layer of n_height * n_width units and a linear
    output layer.

    Args:
        input_X: image batch fed through ``CNN_X``, e.g. (batch, 28, 28, 1).
        keep_prob: probability of *keeping* a unit in dropout (float or float
            tensor). The script feeds 0.6 during training and 1 at test time.
        training: bool or bool tensor; dropout is only applied when true.

    Returns:
        Unscaled logits of shape (batch, n_output).
    """
    # tf.layers.dropout's 2nd and 3rd positional parameters are ``rate`` (the
    # fraction to DROP) and ``noise_shape``, not keep_prob and training, so
    # convert keep_prob and pass both by keyword.
    drop_rate = 1 - keep_prob

    net = tf.layers.conv2d(input_X, 32, [3, 3], padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, [2, 2], strides=2)

    net = tf.layers.conv2d(net, 64, [3, 3], padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, [2, 2], strides=2)
    net = tf.layers.dropout(net, rate=drop_rate, training=training)

    net = tf.layers.conv2d(net, 128, [3, 3], padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, [2, 2], strides=2)
    net = tf.contrib.layers.flatten(net)

    net = tf.layers.dense(net, n_height * n_width, activation=tf.nn.relu)
    net = tf.layers.dropout(net, rate=drop_rate, training=training)
    return tf.layers.dense(net, n_output)
#########
# RNN
######
def RNN(input_X, keep_prob, training):
    """Build the recurrent branch of the ensemble and return its logits.

    A two-layer, 128-unit LSTM with output dropout reads the image one row per
    time step. All step outputs are flattened and passed through a dense ReLU
    layer of n_height * n_width units and a linear output layer.

    Args:
        input_X: sequence batch fed through ``RNN_X``, shaped
            (batch, n_height time steps, n_width inputs per step).
        keep_prob: probability of keeping a unit in dropout (float or tensor).
        training: bool or bool tensor. The dense-layer dropout is only applied
            when true; the LSTM DropoutWrapper relies on keep_prob alone.

    Returns:
        Unscaled logits of shape (batch, n_output).
    """
    def _lstm_layer():
        # Build a separate cell object per layer. Reusing one cell object
        # (``[cell] * 2``) asks both layers to share one set of weights,
        # which later TF 1.x releases reject with a ValueError. The two
        # layers also have different input sizes.
        cell = tf.contrib.rnn.BasicLSTMCell(128)
        return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)

    cell = tf.contrib.rnn.MultiRNNCell([_lstm_layer() for _ in range(2)])
    outputs, _ = tf.nn.dynamic_rnn(cell, input_X, dtype=tf.float32)

    outputs = tf.contrib.layers.flatten(outputs)
    outputs = tf.layers.dense(outputs, n_height * n_width, activation=tf.nn.relu)
    # ``rate`` is the fraction to DROP. Pass ``training`` by keyword, because
    # the 3rd positional parameter of tf.layers.dropout is noise_shape.
    outputs = tf.layers.dropout(outputs, rate=1 - keep_prob, training=training)
    return tf.layers.dense(outputs, n_output)
global_step = tf.Variable(0, trainable=False, name="global_step")
is_training = tf.placeholder(tf.bool)
keep_prob = tf.placeholder(tf.float32)

# NHWC image input for the CNN: (batch, height, width, channels).
CNN_X = tf.placeholder(tf.float32, [None, n_height, n_width, 1])
# Sequence input for the RNN: (batch, time steps = image rows, inputs per step = columns).
RNN_X = tf.placeholder(tf.float32, [None, n_height, n_width])
Y = tf.placeholder(tf.float32, [None, n_output])

CNN_model = CNN(CNN_X, keep_prob, is_training)
CNN_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=CNN_model, labels=Y))
# Only this optimizer advances global_step, so it counts training batches
# instead of advancing twice per batch (once per optimizer).
CNN_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(CNN_cost, global_step=global_step)

RNN_model = RNN(RNN_X, keep_prob, is_training)
RNN_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=RNN_model, labels=Y))
RNN_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(RNN_cost)
#########
# Train the neural network models
######
sess = tf.Session()
# Saves/restores every global variable: model weights, Adam state and global_step.
saver = tf.train.Saver(tf.global_variables())

ckpt = tf.train.get_checkpoint_state("./model")
if not ckpt or not tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    # No usable checkpoint: start from freshly initialized variables.
    sess.run(tf.global_variables_initializer())
else:
    # Resume from the latest checkpoint recorded in ./model.
    saver.restore(sess, ckpt.model_checkpoint_path)
batch_size = 100
total_batch = mnist.train.num_examples // batch_size

# saver.save() does not create missing directories, so make sure the
# checkpoint directory exists before the first save.
checkpoint_dir = "./model"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "mnist.ckpt")

for epoch in range(30):
    # Running sum of the CNN and RNN batch costs for this epoch.
    total_cost = 0
    for _ in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Feed the same flat images in the layout each model expects.
        CNN_batch_xs = batch_xs.reshape(-1, n_height, n_width, 1)
        RNN_batch_xs = batch_xs.reshape(-1, n_height, n_width)

        _, CNN_cost_val = sess.run([CNN_optimizer, CNN_cost],
                                   feed_dict={CNN_X: CNN_batch_xs, Y: batch_ys,
                                              keep_prob: 0.6, is_training: True})
        _, RNN_cost_val = sess.run([RNN_optimizer, RNN_cost],
                                   feed_dict={RNN_X: RNN_batch_xs, Y: batch_ys,
                                              keep_prob: 0.6, is_training: True})
        total_cost += CNN_cost_val + RNN_cost_val

    # "Avg. cost" is the per-batch average of the combined CNN + RNN cost.
    print('Epoch:', '%04d' % epoch,
          'Avg. cost =', '{:.4f}'.format(total_cost / total_batch))

    # Checkpoint after every epoch, so an interrupted run can resume via the
    # checkpoint-restore step at start-up.
    saver.save(sess, checkpoint_path)

print('최적화 완료!')
#########
# Check results
######
def _mean_accuracy(correct):
    """Return the fraction of True entries in a boolean prediction tensor."""
    return tf.reduce_mean(tf.cast(correct, tf.float32))


true_labels = tf.argmax(Y, 1)

CNN_is_correct = tf.equal(tf.argmax(CNN_model, 1), true_labels)
CNN_accuracy = _mean_accuracy(CNN_is_correct)

RNN_is_correct = tf.equal(tf.argmax(RNN_model, 1), true_labels)
RNN_accuracy = _mean_accuracy(RNN_is_correct)

# Ensemble prediction: add the two models' logits, then take the arg-max.
ensemble_logits = CNN_model + RNN_model
is_correct = tf.equal(tf.argmax(ensemble_logits, 1), true_labels)
accuracy = _mean_accuracy(is_correct)

# Evaluate with dropout disabled (keep_prob 1, not training).
test_images = mnist.test.images
test_feed = {
    CNN_X: test_images.reshape(-1, n_height, n_width, 1),
    RNN_X: test_images.reshape(-1, n_height, n_width),
    Y: mnist.test.labels,
    keep_prob: 1,
    is_training: False,
}
acc_val, CNN_acc_val, RNN_acc_val = sess.run(
    [accuracy, CNN_accuracy, RNN_accuracy], feed_dict=test_feed)
print('정확도: %.4f (CNN: %.4f, RNN: %.4f)' % (acc_val, CNN_acc_val, RNN_acc_val))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.