Skip to content

Instantly share code, notes, and snippets.

@hadikazemi
Created September 28, 2017 13:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hadikazemi/bb6bb7a53341f13ed9c48710b40fe69b to your computer and use it in GitHub Desktop.
Save hadikazemi/bb6bb7a53341f13ed9c48710b40fe69b to your computer and use it in GitHub Desktop.
Fixed Code
import numpy as np
import bson
import tensorflow as tf
from StringIO import StringIO
import skimage
from skimage import io
# Load bson data and dump to TFRecord File
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``.

    The value is stored as a one-element list, so the resulting feature
    always carries exactly one int64 entry.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature`` backed by a ``BytesList``.

    The value is stored as a one-element list, so the resulting feature
    always carries exactly one bytes entry.
    """
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
# Convert every picture found in the BSON dump into one example of a
# ZLIB-compressed TFRecord file ('label' = category id, 'image' = raw pixels).

# Imported here because this section is the only user; BytesIO is the
# binary-safe in-memory buffer and is available on both Python 2 and 3.
from io import BytesIO

# `opts` is reused further down when the TFRecord file is read back, so the
# reader and the writer always agree on the compression type.
opts = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
train_filename = 'train.tfrecords'
writer = tf.python_io.TFRecordWriter(train_filename, options=opts)
try:
    # The context manager releases the BSON file handle even if decoding fails.
    with open('/home/hadi/Downloads/train_example.bson', 'rb') as bson_file:
        data = bson.decode_file_iter(bson_file)
        for d in data:
            category_id = d['category_id']
            for pic in d['imgs']:  # loop through each picture
                # Decode the stored picture bytes into a pixel array, then
                # keep only the raw uint8 buffer. The array shape is NOT
                # stored, so the reader has to know it in advance.
                pixels = skimage.io.imread(BytesIO(pic['picture'])).astype(np.uint8)
                img_raw = pixels.tobytes()  # tostring() is a deprecated alias
                label = category_id
                feature = {'label': _int64_feature(label),
                           'image': _bytes_feature(img_raw)}
                # Create an example protocol buffer
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                # Serialize to string and write on the file
                writer.write(example.SerializeToString())
finally:
    # Always flush and close the record file, even when a record fails to convert.
    writer.close()
# Read the TFRecord file back through a TF1 input queue, draw a few shuffled
# batches and display the images of each batch in a matplotlib grid.
import matplotlib.pyplot as plt

data_path = 'train.tfrecords'
batch_size = 10   # images per batch; also the number of subplots drawn per figure
grid_cols = 5     # subplot columns; rows are derived so every image gets a cell
grid_rows = (batch_size + grid_cols - 1) // grid_cols

with tf.Session() as sess:
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    # Must use the same compression options the file was written with.
    reader = tf.TFRecordReader(options=opts)
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            # Each example stores exactly one int64 label, so parse it as a
            # scalar fixed-length feature (a VarLenFeature would produce a
            # SparseTensor instead of a dense label).
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    # The records hold only the raw pixel buffer, so the shape is restored
    # here; this assumes every stored image is 180x180 RGB.
    image = tf.reshape(image, [180, 180, 3])
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=30, num_threads=1,
        min_after_dequeue=10)
    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    # Create a coordinator and run all QueueRunner object
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for batch_index in range(5):
            img, lbl = sess.run([images, labels])
            img = img.astype(np.uint8)
            print('Batch %d and Batch shape is %s' % (batch_index + 1, img.shape))
            for j in range(batch_size):
                plt.subplot(grid_rows, grid_cols, j + 1)
                plt.imshow(img[j, ...])
            plt.show()
    finally:
        # Stop and join the queue-runner threads even if a batch fails, so
        # they are not left running when the session is torn down.
        coord.request_stop()
        coord.join(threads)
    # No explicit sess.close(): the `with` block closes the session on exit.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment