Skip to content

Instantly share code, notes, and snippets.

@hiwonjoon
Created January 27, 2017 17:35
Show Gist options
  • Save hiwonjoon/6afbbf390af77293ea259a11c2cac079 to your computer and use it in GitHub Desktop.
Save hiwonjoon/6afbbf390af77293ea259a11c2cac079 to your computer and use it in GitHub Desktop.
Feature Extraction and Writing TFRecords Example
import numpy as np
import os
import tensorflow as tf
import vgg
import vgg_preprocessing
from pycocotools.coco import COCO
slim = tf.contrib.slim

# --- Configuration -----------------------------------------------------------
LOG_DIR = './log/fe'      # TensorBoard summary directory.
SUMMARY_PERIOD = 100      # Write summaries every SUMMARY_PERIOD batches.
BATCH_SIZE = 128
# Network input size: height and width both use VGG-19's default image size.
IMAGE_HEIGHT = vgg.vgg_19.default_image_size
IMAGE_WIDTH = vgg.vgg_19.default_image_size
# Machine-specific locations. They can be overridden from the environment so
# the script runs elsewhere without edits. The defaults keep the original values.
DATA_DIR = os.environ.get('COCO_DATA_DIR', '/home/nine/datasets/coco')
MODEL_PATH = os.environ.get('VGG19_CKPT', './vgg_19.ckpt')

tf.set_random_seed(0)
# Dataset Reader / Writer
# Load every COCO train2014 instance annotation. Keep only regions whose
# annotation 'area' exceeds 200.0. For each kept region, collect its image
# path, bounding box and category id.
# TODO : change file format to TFRecord for better queue performance.
coco = COCO('%s/annotations/instances_%s.json' % (DATA_DIR, 'train2014'))
all_ann_ids = coco.getAnnIds()
kept_ann_ids = []
filenames = []
bboxes = []
cats = []
for ann_id in all_ann_ids:
    anno = coco.loadAnns(ann_id)[0]
    # Filter first so the image lookup is skipped for rejected regions.
    if anno['area'] <= 200.0:
        continue
    img = coco.loadImgs(anno['image_id'])[0]
    kept_ann_ids.append(ann_id)
    filenames.append('%s/train2014/%s' % (DATA_DIR, img['file_name']))
    bboxes.append(anno['bbox'])
    cats.append(anno['category_id'])
anno_count = len(kept_ann_ids)
# Report the number kept out of the total (the original had these two swapped).
print("choose %d regions among %d regions" % (anno_count, len(all_ann_ids)))
# Boxes are truncated to int32 pixel values, as the original int32 buffer did.
# reshape keeps the (N, 4) shape even when nothing is kept.
bboxes = np.asarray(bboxes, np.int32).reshape(-1, 4)
cats = np.asarray(cats, np.int32)
# annIds must line up index-for-index with filenames/bboxes/cats, because the
# input producer slices all four lists together. Previously it still held
# every id, so the ids paired with images (and written out) were wrong.
annIds = kept_ann_ids
# Input pipeline: one (annotation id, image path, box, category) tuple per
# region. Each is visited exactly once, in order (num_epochs=1, shuffle=False).
annId, filename, bbox, cat = tf.train.slice_input_producer(
    [annIds, filenames, bboxes, cats], num_epochs=1, shuffle=False)
file_contents = tf.read_file(filename)
whole_image = tf.image.decode_jpeg(file_contents, channels=3)
# crop_to_bounding_box takes (offset_height, offset_width, target_height,
# target_width). The index order used here treats bbox as [x, y, w, h].
cropped_image = tf.image.crop_to_bounding_box(
    whole_image, bbox[1], bbox[0], bbox[3], bbox[2])
# TODO : ignore aspect preserving?
# cropped_image = tf.image.resize_images(cropped_image,[256,256])
preprocessed_image = vgg_preprocessing.preprocess_image(
    cropped_image, IMAGE_HEIGHT, IMAGE_WIDTH, is_training=False)

# Build image batch.
# The queue is finite (num_epochs=1). Without allow_smaller_final_batch the
# trailing (count % BATCH_SIZE) regions would never be dequeued, and their
# features would be silently missing from the output.
ids, images, labels = tf.train.batch(
    [annId, preprocessed_image, cat],
    batch_size=BATCH_SIZE,
    num_threads=2,
    capacity=5 * BATCH_SIZE,
    allow_smaller_final_batch=True)
# For debugging, batch the intermediate images instead:
# whole_images, cropped_images, images = tf.train.batch(
#     [tf.image.resize_images(whole_image,[224,224]),
#      tf.image.resize_images(cropped_image,[224,224]),
#      preprocessed_image], ...)
#tf.summary.image('whole_images',whole_images,max_outputs=10)
#tf.summary.image('cropped_images',cropped_images,max_outputs=10)
tf.summary.image('images', images, max_outputs=10)
# Define base model.
# VGG-19 is built in inference mode (is_training=False). Slim's VGG arg scope
# supplies the layer defaults.
with slim.arg_scope(vgg.vgg_arg_scope()):
    logits, end_points = vgg.vgg_19(images, is_training=False)
    # The fc7 activations are the per-region descriptor written to the TFRecord.
    fc7 = end_points['vgg_19/fc7']

# Not referenced elsewhere in this script. The Saver restores its default
# variable set.
variables_to_restore = slim.get_variables_to_restore()
# Checkpoint I/O. It is used only to restore the pretrained weights; this script
# never calls save, so max_to_keep has no practical effect.
saver = tf.train.Saver(max_to_keep=5)
# Merge every summary registered in the graph into one fetchable op.
summary_op = tf.summary.merge_all()

# Session, summary writer and queue coordinator.
sess = tf.Session()
summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
coord = tf.train.Coordinator()
# Initialize both global and local variables. Per the TF docs, a num_epochs
# limit on an input producer is tracked by a local variable.
for init_op in (tf.global_variables_initializer(),
                tf.local_variables_initializer()):
    sess.run(init_op)
# Overwrite the freshly initialized weights with the pretrained checkpoint.
saver.restore(sess, MODEL_PATH)
#assert( len(sess.run(tf.report_uninitialized_variables())) == 0 )
# Output file: one serialized tf.train.Example per region.
writer = tf.python_io.TFRecordWriter('./feature_extracted.tfrecords')
def _as_list(value):
    """Return value unchanged if it is a list, otherwise wrap it in one."""
    return value if isinstance(value, list) else [value]


def _int64_feature(value):
    """Build an int64 tf.train.Feature from an int or a list of ints."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=_as_list(value)))


def _bytes_feature(value):
    """Build a bytes tf.train.Feature from a byte string or a list of them."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=_as_list(value)))


def _float_feature(value):
    """Build a float tf.train.Feature from a float or a list of floats."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=_as_list(value)))


def _region_example(ann_id, feature):
    """Build the tf.train.Example for one region.

    ann_id: COCO annotation id (int) of the region.
    feature: 1-D numpy array of that region's fc7 activations.
    Returns a tf.train.Example with keys annId, vgg19, filename, bbox, cat.
    Metadata is looked up from the module-level `coco` index.
    """
    anno = coco.loadAnns(ann_id)[0]
    img = coco.loadImgs(anno['image_id'])[0]
    return tf.train.Example(features=tf.train.Features(feature={
        'annId': _int64_feature(ann_id),
        'vgg19': _float_feature(np.ndarray.tolist(feature)),
        'filename': _bytes_feature(img['file_name'].encode('ascii', 'ignore')),
        'bbox': _float_feature(anno['bbox']),
        'cat': _int64_feature(anno['category_id']),
    }))


# Start Queueing
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
step = 0
try:
    # The input is finite (num_epochs=1). Once it is exhausted, the batch
    # dequeue raises tf.errors.OutOfRangeError. The Coordinator treats that
    # as a clean stop.
    while not coord.should_stop():
        write_summary = step % SUMMARY_PERIOD == 0
        # Fetch everything in ONE run so ids, features and summary all come
        # from the same dequeued batch. Evaluate summary_op only when needed.
        fetches = [ids, fc7] + ([summary_op] if write_summary else [])
        results = sess.run(fetches)
        ids_val, fc7_val = results[0], results[1]
        if write_summary:
            summary_writer.add_summary(results[2], step)
        # Write in TF Record formats. fc7 is indexed [i, 0, 0, :] for each region.
        for i, id_val in enumerate(ids_val):
            example = _region_example(int(id_val), fc7_val[i, 0, 0, :])
            writer.write(example.SerializeToString())
        step += 1
        # Progress: batch count and mean non-zero fc7 entries per region.
        # Use float division so Python 2 does not truncate.
        print(step, np.count_nonzero(fc7_val) / float(np.shape(fc7_val)[0]))
except Exception as e:
    coord.request_stop(e)
finally:
    coord.request_stop()
    try:
        # join() re-raises an exception recorded through request_stop().
        # OutOfRangeError is the exception: it counts as a clean stop.
        coord.join(threads)
    finally:
        # Close unconditionally so buffered records and summaries are flushed
        # even when join() re-raises.
        writer.close()
        summary_writer.close()
        sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment