Skip to content

Instantly share code, notes, and snippets.

@leechanwoo
Created August 24, 2017 03:04
Show Gist options
Save leechanwoo/4d0ee695eb0daa581571392d9e9c23b6 to your computer and use it in GitHub Desktop.
CNN part1
import tensorflow as tf
import matplotlib.pyplot as plt
import os
%matplotlib inline
# --- Input pipeline ---------------------------------------------------------
# Reads PNG images from dataset/test_dataset_png/ and integer labels (one per
# line) from dataset/test_dataset_csv/label.csv, decodes them, and groups them
# into batches of 10 (image, label) pairs via tf.train.batch.
images = "dataset/test_dataset_png/"
image_dir = os.path.join(os.getcwd(), images)
# os.listdir() returns entries in arbitrary order and lists every file in the
# directory; keep only .png files and sort them so the image order is
# deterministic and can be matched against the rows of label.csv.
# NOTE(review): assumes the i-th sorted filename corresponds to the i-th CSV
# row, and that label.csv has exactly one line per image (no header) --
# confirm against how the dataset was generated.
imagenames = sorted(
    os.path.join(image_dir, f)
    for f in os.listdir(image_dir)
    if f.lower().endswith(".png")
)
label_file = "dataset/test_dataset_csv/label.csv"
labelname = [os.path.join(os.getcwd(), label_file)]
# string_input_producer shuffles by default; two independently shuffled
# queues would pair images with random labels, so keep both in file order.
imagename_queue = tf.train.string_input_producer(imagenames, shuffle=False)
labelname_queue = tf.train.string_input_producer(labelname, shuffle=False)
img_reader = tf.WholeFileReader()
label_reader = tf.TextLineReader()
_, image = img_reader.read(imagename_queue)    # raw bytes of one PNG file
_, label = label_reader.read(labelname_queue)  # one text line of label.csv
# channels=1 forces a single-channel decode so the reshape below also works
# for RGB/RGBA files; images are expected to be 61x49 pixels.
decoded_img = tf.image.decode_png(image, channels=1)
reshaped_img = tf.reshape(decoded_img, shape=[61, 49, 1])
reshaped_img = tf.cast(reshaped_img, tf.float32)
# One integer column per line; decode_csv returns a list with one tensor per
# column (here: a single-element list).
decoded_label = tf.decode_csv(label, record_defaults=[[0]])
# Each enqueue evaluates both reads together, so an image and a label line are
# consumed in step -- as long as nothing else reads from these readers.
x, y_ = tf.train.batch([reshaped_img, decoded_label], 10)
# --- Model ------------------------------------------------------------------
# Four 3x3 "SAME" conv layers (10 filters each), two fully connected layers and
# a 3-unit output layer. Hidden layers use ReLU: tf.layers defaults to no
# activation, and a stack of purely linear layers collapses into a single
# linear map.
conv1 = tf.layers.conv2d(x, filters=10, kernel_size=[3, 3], padding="SAME",
                         activation=tf.nn.relu)
conv2 = tf.layers.conv2d(conv1, filters=10, kernel_size=[3, 3], padding="SAME",
                         activation=tf.nn.relu)
conv3 = tf.layers.conv2d(conv2, filters=10, kernel_size=[3, 3], padding="SAME",
                         activation=tf.nn.relu)
conv4 = tf.layers.conv2d(conv3, filters=10, kernel_size=[3, 3], padding="SAME",
                         activation=tf.nn.relu)
# Stride-1 "SAME" convolutions keep the 61x49 spatial size, so each example in
# conv4 has 61*49*10 values. Derive the flattened width from the static shape
# so that inserting pooling layers later does not require editing this line.
flat = tf.reshape(conv4, shape=[-1, conv4.shape[1:].num_elements()])
# NOTE(review): 61*49*10 -> 5000 is ~150M weights in fc1; consider pooling
# before flattening if memory or training time is a concern.
fc1 = tf.layers.dense(flat, 5000, activation=tf.nn.relu)
fc2 = tf.layers.dense(fc1, 1000, activation=tf.nn.relu)
# No activation on the output layer: 3 raw output units (presumably logits for
# a softmax loss -- confirm against the intended training objective).
out = tf.layers.dense(fc2, 3)
# --- Run ----------------------------------------------------------------------
# Starts the input queue runners and prints 100 batches of labels.
with tf.Session() as sess:
    # Initialize the model's variables so `out` can also be evaluated here.
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(100):
            # Fetch labels from the batch queue instead of running
            # decoded_label directly: a direct run reads from the same
            # TextLineReader that the batch queue runner thread is using,
            # which steals lines and desynchronizes the image/label pairs.
            age = sess.run(y_)
            print(age)
    finally:
        # Always stop and join the runner threads, even if a run fails.
        coord.request_stop()
        coord.join(thread)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment