Skip to content

Instantly share code, notes, and snippets.

@leechanwoo
Last active December 5, 2017 08:46
Show Gist options
  • Save leechanwoo/007b4794b779ae9956ae04d97f694bd3 to your computer and use it in GitHub Desktop.
Save leechanwoo/007b4794b779ae9956ae04d97f694bd3 to your computer and use it in GitHub Desktop.
# Script 1: preview one PNG image and one line of label.csv using TF1 queue runners.
import os

import matplotlib.pyplot as plt
import tensorflow as tf

tf.reset_default_graph()

# Fixed: the original left this string literal unterminated (SyntaxError).
image_dir = './tfrecord_dataset/images_png/'
label_name = './tfrecord_dataset/label_csv/label.csv'

# Full paths of every entry in image_dir (assumes the directory holds only PNGs — TODO confirm).
image_names = [os.path.join(image_dir, n) for n in os.listdir(image_dir)]

# Filename queues; default shuffle=True with a fixed seed for reproducible previews.
img_name_queue = tf.train.string_input_producer(image_names, seed=7777)
label_name_queue = tf.train.string_input_producer([label_name], seed=7777)

img_reader = tf.WholeFileReader()
# The original constructed a TextLineReader but discarded it, so `label_csv`
# was never defined. Bind it and read one CSV line per step.
text_reader = tf.TextLineReader()

key, value = img_reader.read(img_name_queue)
txt_key, txt_value = text_reader.read(label_name_queue)

img_png = tf.image.decode_png(value)
# Average channels in float32: reduce_mean on integer input returns a truncated
# integer and may accumulate in the input dtype (uint8 sums can overflow).
img_png = tf.reduce_mean(tf.cast(img_png, tf.float32), axis=-1)

# Assumes each CSV line holds a single integer label — TODO confirm against label.csv.
label_csv = tf.decode_csv(txt_value, record_defaults=[[0]])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        _img, _label = sess.run([img_png, label_csv])
        print(_label)
        print(_img.shape)
        # Single-channel image: use a gray colormap instead of the default palette.
        plt.imshow(_img, cmap='gray')
        plt.show()
    finally:
        # Always stop and join the queue-runner threads, even if sess.run fails.
        coord.request_stop()
        coord.join(threads)
# Script 2: read face_train.tfrecord back with tf.data and visualize one batch.
import matplotlib.pyplot as plt
import tensorflow as tf

tf.reset_default_graph()

# Image geometry and batch size, shared by the parser, the reshape and the plot loop.
IMG_HEIGHT, IMG_WIDTH = 61, 49
BATCH_SIZE = 32


def parser(serialized_example):
    """Decode one serialized tf.train.Example into ``(age, img)`` tensors.

    Args:
        serialized_example: scalar string tensor holding one Example proto
            with int64 features 'age' (length 1) and 'img'
            (length IMG_HEIGHT * IMG_WIDTH).

    Returns:
        age: int32 tensor of shape [1].
        img: float32 tensor of shape [IMG_HEIGHT * IMG_WIDTH] (flattened pixels).
    """
    feature = {
        'age': tf.FixedLenFeature([1], tf.int64),
        'img': tf.FixedLenFeature([IMG_HEIGHT * IMG_WIDTH], tf.int64),
    }
    parsed_feature = tf.parse_single_example(serialized_example, feature)
    age = tf.cast(parsed_feature['age'], tf.int32)
    img = tf.cast(parsed_feature['img'], tf.float32)
    return age, img


dataset_dir = './cnn_dataset/face_train.tfrecord'
dataset = tf.contrib.data.TFRecordDataset(dataset_dir).map(parser)
# Shuffle individual examples BEFORE batching. The original batched first, which
# only permuted the order of whole batches and left each batch in file order.
dataset = dataset.shuffle(7777)
dataset = dataset.batch(BATCH_SIZE)
itr = dataset.make_one_shot_iterator()
age, img = itr.get_next()
img = tf.reshape(img, [-1, IMG_HEIGHT, IMG_WIDTH])

with tf.Session() as sess:
    _age, _img = sess.run([age, img])
    # Iterate the actual batch: it is smaller than BATCH_SIZE when the file
    # holds fewer records, and a fixed range(32) would raise IndexError.
    for one_age, one_img in zip(_age, _img):
        print(one_age)
        plt.imshow(one_img, cmap='gray')
        plt.show()
# Script 3: convert the PNG images plus label.csv into ./cnn_dataset/face_train.tfrecord.
import itertools
import os

import matplotlib.pyplot as plt
import tensorflow as tf

tf.reset_default_graph()

image_dir = './tfrecord_dataset/images_png/'
label_name = './tfrecord_dataset/label_csv/label.csv'

# Each sess.run dequeues one image and one CSV line, so the i-th sorted filename
# is paired with the i-th CSV line.
# NOTE(review): sorted() is lexicographic ("10.png" < "2.png"). Confirm that this
# matches the row order of label.csv (e.g. zero-padded names); otherwise labels
# are silently misaligned with images.
image_names = sorted(os.path.join(image_dir, n) for n in os.listdir(image_dir))

# shuffle=False preserves the pairing above. num_epochs=1 makes the queues raise
# OutOfRangeError after one pass, which is what ends the write loop below.
img_name_queue = tf.train.string_input_producer(image_names, num_epochs=1, shuffle=False)
label_name_queue = tf.train.string_input_producer([label_name], num_epochs=1, shuffle=False)

img_reader = tf.WholeFileReader()
text_reader = tf.TextLineReader()
img_key, img_value = img_reader.read(img_name_queue)
txt_key, txt_value = text_reader.read(label_name_queue)

img_png = tf.image.decode_png(img_value)
# Average channels in float32 (an integer reduce_mean truncates and may overflow
# uint8), then round back to int64 to match the 'img' int64_list feature.
img_png = tf.reduce_mean(tf.cast(img_png, tf.float32), axis=-1)
img_png = tf.cast(tf.round(img_png), tf.int64)
img_png = tf.reshape(img_png, [-1])  # flatten to a 1-D pixel vector
# Assumes each CSV line holds a single integer age — TODO confirm (no header row?).
txt_csv = tf.decode_csv(txt_value, record_defaults=[[0]])

face_train_dir = './cnn_dataset/face_train.tfrecord'
face_test_dir = './cnn_dataset/face_test.tfrecord'  # NOTE(review): unused; no test split is written yet.

with tf.Session() as sess:
    # num_epochs creates a local epoch counter that must be initialized.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        # The context manager closes the writer even if writing fails midway.
        with tf.python_io.TFRecordWriter(face_train_dir) as train_writer:
            for i in itertools.count():
                try:
                    _img, _age = sess.run([img_png, txt_csv])
                except tf.errors.OutOfRangeError:
                    # One of the input queues is exhausted: all records written.
                    print('end of record')
                    break
                example = tf.train.Example()
                example.features.feature['age'].int64_list.value.append(int(_age[0]))
                example.features.feature['img'].int64_list.value.extend(_img.tolist())
                train_writer.write(example.SerializeToString())
                print('{}th record is written'.format(i))
    finally:
        coord.request_stop()
        coord.join(threads)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment