MNIST with CNN + RNN
# Python 3, TensorFlow 1.0
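# Trains a CNN and an RNN side by side on MNIST and, at test time,
# ensembles them by summing their logits before taking the argmax.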
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
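# Downloads MNIST to ./mnist/data/ on first run and returns one-hot labels.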
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)
#########
# Options
######
n_width = 28   # width of an MNIST image; the size of each RNN input
n_height = 28  # height of an MNIST image; the number of RNN time steps
n_output = 10  # ten classes, digits 0-9
learning_rate = 0.001
#########
# CNN
######
# Alternative CNN built from tf.contrib.layers (kept for reference; not used
# below). It applies dropout after each convolution via the normalizer_fn hook.
def CNN2(input_X):
    L1 = tf.contrib.layers.conv2d(input_X, 32, [3, 3],
                                  normalizer_fn=tf.nn.dropout,
                                  normalizer_params={'keep_prob': 0.8})
    L2 = tf.contrib.layers.max_pool2d(L1, [2, 2])
    L3 = tf.contrib.layers.conv2d(L2, 64, [3, 3],
                                  normalizer_fn=tf.nn.dropout,
                                  normalizer_params={'keep_prob': 0.8})
    L4 = tf.contrib.layers.max_pool2d(L3, [2, 2])
    L5 = tf.contrib.layers.flatten(L4)
    L5 = tf.contrib.layers.fully_connected(L5, 256,
                                           normalizer_fn=tf.contrib.layers.batch_norm)
    return tf.contrib.layers.fully_connected(L5, n_output)
def CNN(input_X, keep_prob, training):
    L1 = tf.layers.conv2d(input_X, 32, [3, 3], padding='same', activation=tf.nn.relu)
    L1 = tf.layers.max_pooling2d(L1, [2, 2], strides=2)
    L2 = tf.layers.conv2d(L1, 64, [3, 3], padding='same', activation=tf.nn.relu)
    L2 = tf.layers.max_pooling2d(L2, [2, 2], strides=2)
    # tf.layers.dropout takes a drop *rate*, not a keep probability, and
    # `training` must be passed by keyword (the third positional slot is noise_shape).
    L2 = tf.layers.dropout(L2, rate=1 - keep_prob, training=training)
    L3 = tf.layers.conv2d(L2, 128, [3, 3], padding='same', activation=tf.nn.relu)
    L3 = tf.layers.max_pooling2d(L3, [2, 2], strides=2)
    L3 = tf.contrib.layers.flatten(L3)
    L4 = tf.layers.dense(L3, n_height * n_width, activation=tf.nn.relu)
    L4 = tf.layers.dropout(L4, rate=1 - keep_prob, training=training)
    return tf.layers.dense(L4, n_output)  # raw logits; softmax is applied by the loss
#########
# RNN
######
def RNN(input_X, keep_prob, training):
    # Each cell instance owns its own weights, so build one per layer rather
    # than reusing a single cell object ([cell] * 2 would share one instance).
    def lstm_cell():
        cell = tf.contrib.rnn.BasicLSTMCell(128)
        return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)

    cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)])
    # Each 28-pixel image row is fed as one time step, giving 28 steps.
    outputs, states = tf.nn.dynamic_rnn(cell, input_X, dtype=tf.float32)
    outputs = tf.contrib.layers.flatten(outputs)
    outputs = tf.layers.dense(outputs, n_height * n_width, activation=tf.nn.relu)
    outputs = tf.layers.dropout(outputs, rate=1 - keep_prob, training=training)
    return tf.layers.dense(outputs, n_output)
global_step = tf.Variable(0, trainable=False, name="global_step")
is_training = tf.placeholder(tf.bool)
keep_prob = tf.placeholder(tf.float32)
CNN_X = tf.placeholder(tf.float32, [None, n_width, n_height, 1])  # NHWC image batch
RNN_X = tf.placeholder(tf.float32, [None, n_height, n_width])     # [batch, steps, inputs]
Y = tf.placeholder(tf.float32, [None, n_output])
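# Two independent classifiers trained against the same labels. Note that both
# optimizers advance the shared global_step, so it counts two updates per batch.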
CNN_model = CNN(CNN_X, keep_prob, is_training)
CNN_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=CNN_model, labels=Y))
CNN_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(CNN_cost, global_step=global_step)
RNN_model = RNN(RNN_X, keep_prob, is_training)
RNN_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=RNN_model, labels=Y))
RNN_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(RNN_cost, global_step=global_step)
#########
# Train the models
######
sess = tf.Session()
saver = tf.train.Saver(tf.global_variables())
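# Resume from the latest checkpoint in ./model if one exists;
# otherwise initialize all variables from scratch.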
ckpt = tf.train.get_checkpoint_state("./model")
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    saver.restore(sess, ckpt.model_checkpoint_path)
else:
    sess.run(tf.global_variables_initializer())
batch_size = 100
total_batch = int(mnist.train.num_examples/batch_size)
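# Each batch is fed to both networks: the CNN sees NHWC images
# ([batch, 28, 28, 1]) and the RNN sees 28 rows of 28 pixels as a sequence.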
for epoch in range(30):
    total_cost = 0

    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        CNN_batch_xs = batch_xs.reshape(-1, 28, 28, 1)
        RNN_batch_xs = batch_xs.reshape(-1, 28, 28)

        _, CNN_cost_val = sess.run([CNN_optimizer, CNN_cost],
                                   feed_dict={CNN_X: CNN_batch_xs, Y: batch_ys,
                                              keep_prob: 0.6, is_training: True})
        _, RNN_cost_val = sess.run([RNN_optimizer, RNN_cost],
                                   feed_dict={RNN_X: RNN_batch_xs, Y: batch_ys,
                                              keep_prob: 0.6, is_training: True})

        total_cost += CNN_cost_val + RNN_cost_val

    print('Epoch:', '%04d' % epoch,
          'Avg. cost =', '{:.4f}'.format(total_cost / total_batch))

os.makedirs("./model", exist_ok=True)  # Saver does not create the directory itself
checkpoint_path = os.path.join("./model", "mnist.ckpt")
saver.save(sess, checkpoint_path)
print('Training complete!')
#########
# Evaluate
######
CNN_is_correct = tf.equal(tf.argmax(CNN_model, 1), tf.argmax(Y, 1))
CNN_accuracy = tf.reduce_mean(tf.cast(CNN_is_correct, tf.float32))
RNN_is_correct = tf.equal(tf.argmax(RNN_model, 1), tf.argmax(Y, 1))
RNN_accuracy = tf.reduce_mean(tf.cast(RNN_is_correct, tf.float32))
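# Simple logit-level ensemble: sum the two models' logits and take the argmax.
# A digit misclassified by one network can still be recovered when the other
# network's logits are more confident.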
is_correct = tf.equal(tf.argmax(tf.add(CNN_model, RNN_model), 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
acc_val, CNN_acc_val, RNN_acc_val = sess.run(
    [accuracy, CNN_accuracy, RNN_accuracy],
    feed_dict={CNN_X: mnist.test.images.reshape(-1, 28, 28, 1),
               RNN_X: mnist.test.images.reshape(-1, 28, 28),
               Y: mnist.test.labels,
               keep_prob: 1,
               is_training: False})
print('Accuracy: %.4f (CNN: %.4f, RNN: %.4f)' % (acc_val, CNN_acc_val, RNN_acc_val))
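#########
# Inference example
######
# A minimal sketch: classify the first test image with the same logit-sum
# ensemble used for the accuracy check above.
import numpy as np

sample = mnist.test.images[:1]
logits = sess.run(CNN_model + RNN_model,
                  feed_dict={CNN_X: sample.reshape(-1, 28, 28, 1),
                             RNN_X: sample.reshape(-1, 28, 28),
                             keep_prob: 1, is_training: False})
print('Predicted: %d, Label: %d'
      % (np.argmax(logits, 1)[0], np.argmax(mnist.test.labels[0])))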