Last active
May 15, 2020 11:58
-
-
Save Kanazawanaoaki/53050022fc1a7cfaba87a9c5bb862889 to your computer and use it in GitHub Desktop.
tensorflow/models/research/scripts/create_tf_record.py-rewrite 20200515-191822-dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import hashlib | |
import io | |
import logging | |
import numpy as np | |
import os | |
import PIL.Image | |
import tensorflow as tf | |
flags = tf.app.flags | |
flags.DEFINE_string('data_dir', '', 'Root directory to raw dataset.') | |
flags.DEFINE_string('output_dir', '', 'Dir to output TFRecord') | |
FLAGS = flags.FLAGS | |
def int64_feature(value): | |
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) | |
def int64_list_feature(value): | |
return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) | |
def bytes_feature(value): | |
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) | |
def bytes_list_feature(value): | |
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) | |
def float_list_feature(value): | |
return tf.train.Feature(float_list=tf.train.FloatList(value=value)) | |
def get_tf_example( | |
img_path, class_label_path, instance_label_path, class_names): | |
# image | |
with tf.gfile.GFile(img_path, 'rb') as fid: | |
encoded_jpg = fid.read() | |
encoded_jpg_io = io.BytesIO(encoded_jpg) | |
image = PIL.Image.open(encoded_jpg_io) | |
if image.format != 'JPEG': | |
raise ValueError('Image format not JPEG') | |
key = hashlib.sha256(encoded_jpg).hexdigest() | |
width = int(image.width) | |
height = int(image.height) | |
filename = img_path.split('/')[-1] | |
class_label = np.load(class_label_path) | |
instance_label = np.load(instance_label_path) | |
xmin = [] | |
ymin = [] | |
xmax = [] | |
ymax = [] | |
classes = [] | |
classes_text = [] | |
R = np.max(instance_label) | |
print(instance_label_path) | |
print(R) | |
# print(instance_label) | |
for inst_lbl in range(R): | |
if inst_lbl == 0: | |
continue | |
inst_mask = instance_label == inst_lbl | |
# print(class_label_path) | |
print(class_label[inst_mask]) | |
cls_lbl = np.argmax(np.bincount(class_label[inst_mask])) | |
classes.append(cls_lbl) | |
classes_text.append(class_names[cls_lbl].encode('utf8')) | |
yind, xind = np.where(inst_mask) | |
xmin.append(float(xind.min() / float(width))) | |
ymin.append(float(yind.min() / float(height))) | |
xmax.append(float(xind.max() / float(width))) | |
ymax.append(float(yind.max() / float(height))) | |
example = tf.train.Example(features=tf.train.Features(feature={ | |
'image/height': int64_feature(height), | |
'image/width': int64_feature(width), | |
'image/filename': bytes_feature(filename.encode('utf8')), | |
'image/source_id': bytes_feature(filename.encode('utf8')), | |
'image/key/sha256': bytes_feature(key.encode('utf8')), | |
'image/encoded': bytes_feature(encoded_jpg), | |
'image/format': bytes_feature('jpeg'.encode('utf8')), | |
'image/object/bbox/xmin': float_list_feature(xmin), | |
'image/object/bbox/xmax': float_list_feature(xmax), | |
'image/object/bbox/ymin': float_list_feature(ymin), | |
'image/object/bbox/ymax': float_list_feature(ymax), | |
'image/object/class/text': bytes_list_feature(classes_text), | |
'image/object/class/label': int64_list_feature(classes), | |
})) | |
return example | |
def create_tf_record(root_dir, output_path): | |
imgs_dir = os.path.join(root_dir, 'JPEGImages') | |
class_labels_dir = os.path.join(root_dir, 'SegmentationClass') | |
instance_labels_dir = os.path.join(root_dir, 'SegmentationObject') | |
class_names_path = os.path.join(root_dir, 'class_names.txt') | |
with open(class_names_path, 'r') as f: | |
class_names = f.readlines() | |
class_names = [name.rstrip() for name in class_names] | |
img_paths = [] | |
class_label_paths = [] | |
instance_label_paths = [] | |
imgs_dir = os.path.join(root_dir, 'JPEGImages') | |
class_labels_dir = os.path.join(root_dir, 'SegmentationClass') | |
instance_labels_dir = os.path.join(root_dir, 'SegmentationObject') | |
for img_name in sorted(os.listdir(imgs_dir)): | |
img_path = os.path.join(imgs_dir, img_name) | |
basename = img_name.rstrip('.jpg') | |
class_label_path = os.path.join( | |
class_labels_dir, basename + '.npy') | |
instance_label_path = os.path.join( | |
instance_labels_dir, basename + '.npy') | |
img_paths.append(img_path) | |
class_label_paths.append(class_label_path) | |
instance_label_paths.append(instance_label_path) | |
writer = tf.python_io.TFRecordWriter(output_path) | |
print('Reading from 73B2 kitchen dataset.') | |
logging.info('Reading from 73B2 kitchen dataset.') | |
for i, (img_path, class_label_path, instance_label_path) in enumerate( | |
zip(img_paths, class_label_paths, instance_label_paths)): | |
if i % 100 == 0: | |
print('On image {} of {}'.format(i, len(img_paths))) | |
logging.info('On image {} of {}'.format(i, len(img_paths))) | |
# print(i) | |
# print(img_path) | |
tf_example = get_tf_example( | |
img_path, class_label_path, instance_label_path, class_names) | |
writer.write(tf_example.SerializeToString()) | |
writer.close() | |
def main(_): | |
data_dir = FLAGS.data_dir | |
for set_name in ['train', 'test']: | |
root_dir = os.path.join(data_dir, set_name) | |
output_path = os.path.join( | |
FLAGS.output_dir, 'kitchen_dataset_{}.record'.format(set_name)) | |
create_tf_record(root_dir, output_path) | |
if __name__ == '__main__': | |
tf.app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Reading from 73B2 kitchen dataset. | |
On image 0 of 34 | |
/tensorflow/models/research/20200515-191822-dataset/train/SegmentationObject/train_01.npy | |
1 | |
/tensorflow/models/research/20200515-191822-dataset/train/SegmentationObject/train_02.npy | |
2 | |
[5 5 5 ... 5 5 5] | |
/tensorflow/models/research/20200515-191822-dataset/train/SegmentationObject/train_03.npy | |
2 | |
[4 4 4 ... 4 4 4] | |
/tensorflow/models/research/20200515-191822-dataset/train/SegmentationObject/train_04.npy | |
10 | |
[2 2 2 ... 2 2 2] | |
[2 2 2 ... 2 2 2] | |
[2 2 2 ... 2 2 2] | |
[3 3 3 ... 3 3 3] | |
[3 3 3 ... 3 3 3] | |
[3 3 3 ... 3 3 3] | |
[] | |
Traceback (most recent call last): | |
File "create_tf_record.py", line 156, in <module> | |
tf.app.run() | |
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 125, in run | |
_sys.exit(main(argv)) | |
File "create_tf_record.py", line 152, in main | |
create_tf_record(root_dir, output_path) | |
File "create_tf_record.py", line 141, in create_tf_record | |
img_path, class_label_path, instance_label_path, class_names) | |
File "create_tf_record.py", line 77, in get_tf_example | |
cls_lbl = np.argmax(np.bincount(class_label[inst_mask])) | |
File "/usr/local/lib/python2.7/dist-packages/numpy/core/fromnumeric.py", line 1037, in argmax | |
return _wrapfunc(a, 'argmax', axis=axis, out=out) | |
File "/usr/local/lib/python2.7/dist-packages/numpy/core/fromnumeric.py", line 51, in _wrapfunc | |
return getattr(obj, method)(*args, **kwds) | |
ValueError: attempt to get argmax of an empty sequence |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment