r"""Convert an ImageNet like dataset into tfRecord files, provide a method get_dataset to read the created files.
It has similar functions as ImageFolder in Pytorch.
Modified from
https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/download_and_convert_flowers.py
https://github.com/tensorflow/models/blob/master/research/slim/datasets/flowers.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math
import os
import random
import sys

from PIL import Image
import tensorflow as tf

# use your location of the code folder 'tensorflow/models/tree/master/research/slim'
sys.path.append(r"/Users/hao/data/slim_models/models/research/slim")
from datasets import dataset_utils
from tensorflow.contrib import slim

# Seed for repeatability.
_RANDOM_SEED = 0


def get_dataset(dataset_dir, dataset_type):
    """Create a dataset from TFRecord files.

    dataset_type is used to specify train, test and validation data.
    For example, if dataset_type is 'train', files matching the pattern
    "train_*.tfrecord" in dataset_dir are treated as data sources for
    this dataset.

    An example of using it (assumes matplotlib.pyplot is imported as plt):

        with tf.Graph().as_default():
            dataset = get_dataset(dataset_dir, dataset_type)
            data_provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset, common_queue_capacity=32, common_queue_min=1)
            image, label = data_provider.get(['image', 'label'])
            with tf.Session() as sess:
                with slim.queues.QueueRunners(sess):
                    for i in range(4):
                        np_image, np_label = sess.run([image, label])
                        height, width, _ = np_image.shape
                        name = dataset.labels_to_names[np_label]
                        plt.figure()
                        plt.imshow(np_image)
                        plt.title('%s, %d x %d' % (name, height, width))
                        plt.axis('off')
                        plt.show()
    """
    file_pattern = os.path.join(dataset_dir, '{}_*.tfrecord'.format(dataset_type))

    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)

    # The sample count is written by convert() into a '<dataset_type>.num' file.
    with open(os.path.join(dataset_dir, dataset_type + '.num')) as f:
        num_samples = int(f.read().strip())

    return slim.dataset.Dataset(
        data_sources=file_pattern,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=None,
        # The label file may be absent; avoid calling len() on None.
        num_classes=len(labels_to_names) if labels_to_names else None,
        labels_to_names=labels_to_names)


def get_image_size(image_file):
    # PIL returns the size as (width, height).
    img = Image.open(image_file)
    return img.size


def get_image_type(image_file):
    # Use the last extension so a name such as 'img.01.jpg' yields 'jpg'.
    file_name = os.path.basename(image_file).lower()
    pos = file_name.rfind('.')
    return file_name[pos + 1:].encode(encoding="ascii")


def _get_filenames_and_classes(image_dir_root):
    """Returns a list of filenames and inferred class names.

    Args:
        image_dir_root: A directory containing a set of subdirectories
            representing class names. Each subdirectory should contain PNG
            or JPG encoded images.

    Returns:
        A list of image file paths and the sorted list of subdirectory
        names, which represent class names.
    """
    directories = []
    class_names = []
    for filename in os.listdir(image_dir_root):
        path = os.path.join(image_dir_root, filename)
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            photo_filenames.append(path)

    return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id, shard_num):
    output_filename = '%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, shard_num)
    return os.path.join(dataset_dir, output_filename)
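
# For example (values are illustrative), split_name='train' with shard_num=2
# yields shard files named:
#
#   train_00000-of-00002.tfrecord
#   train_00001-of-00002.tfrecord
#
# which the 'train_*.tfrecord' pattern used in get_dataset matches.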


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir, shard_num):
    """Converts the given filenames to a TFRecord dataset.

    Args:
        split_name: The name of the dataset split, e.g. 'train' or 'validation'.
        filenames: A list of absolute paths to png or jpg images.
        class_names_to_ids: A dictionary from class names (strings) to ids
            (integers).
        dataset_dir: The directory where the converted datasets are stored.
        shard_num: The number of TFRecord shards to write.
    """
    num_per_shard = int(math.ceil(len(filenames) / float(shard_num)))
    with tf.Graph().as_default():
        with tf.Session('') as sess:
            for shard_id in range(shard_num):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id, shard_num)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, len(filenames), shard_id))
                        sys.stdout.flush()

                        # Read the raw encoded bytes of the image file.
                        image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                        # PIL's Image.size is (width, height).
                        width, height = get_image_size(filenames[i])
                        image_type = get_image_type(filenames[i])

                        # The class name is the image's parent directory name.
                        class_name = os.path.basename(os.path.dirname(filenames[i]))
                        class_id = class_names_to_ids[class_name]

                        example = dataset_utils.image_to_tfexample(
                            image_data, image_type, height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()


def convert(image_dir, sub_dir, dataset_dir, shard_num, class_names_to_ids):
    """Converts one split (sub_dir) of image_dir into TFRecord shards.

    Args:
        image_dir: The root directory containing the split subdirectories.
        sub_dir: The split to convert, e.g. 'train', 'val' or 'test'.
        dataset_dir: The directory where the converted dataset is stored.
        shard_num: The number of TFRecord shards to write.
        class_names_to_ids: An existing class-name-to-id mapping, or None to
            build one from this split and write the label file.

    Returns:
        The class-name-to-id mapping used for the conversion.
    """
    filenames, class_names = _get_filenames_and_classes(os.path.join(image_dir, sub_dir))

    save_labels = False
    if not class_names_to_ids:
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))
        save_labels = True

    # Shuffle the files so that each shard gets a representative mix.
    random.seed(_RANDOM_SEED)
    random.shuffle(filenames)

    _convert_dataset(sub_dir, filenames, class_names_to_ids, dataset_dir, shard_num)

    if save_labels:
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

    # Record the sample count so get_dataset can report num_samples.
    with open(os.path.join(dataset_dir, sub_dir + ".num"), 'w') as f:
        f.write(str(len(filenames)))

    return class_names_to_ids
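
# A sketch of what gets written per split, assuming the stock slim
# dataset_utils: write_label_file stores 'labels.txt' with one '<id>:<name>'
# line per class (class names here are illustrative), e.g.
#
#   0:cat
#   1:dog
#
# and the '<sub_dir>.num' file holds the sample count read back by get_dataset.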


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir',
                        type=str,
                        required=True,
                        help='Directory for images.')
    parser.add_argument('--dataset_dir',
                        type=str,
                        required=True,
                        help='Directory for storing the converted TFRecord files and the label file.')
    parser.add_argument('--sub_dirs',
                        nargs='+',
                        required=True,
                        help='The sub directories of image_dir, used for specifying train, val and test data.')
    parser.add_argument('--shard_num',
                        type=int,
                        default=1,
                        help='The number of shards per split.')
    args, unparsed = parser.parse_known_args()
    if unparsed:
        parser.print_help()
        sys.exit(1)

    dataset_dir = args.dataset_dir
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    # Reuse the class mapping built from the first split for all later splits,
    # so label ids stay consistent across train/val/test.
    class_names_to_ids = None
    for sub_dir in args.sub_dirs:
        class_names_to_ids = convert(args.image_dir, sub_dir, dataset_dir, args.shard_num, class_names_to_ids)
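
# After a run with, say, --sub_dirs train val --shard_num 2 (values are
# illustrative), dataset_dir would contain roughly:
#
#   train_00000-of-00002.tfrecord   train_00001-of-00002.tfrecord   train.num
#   val_00000-of-00002.tfrecord     val_00001-of-00002.tfrecord     val.num
#   labels.txt
#
# get_dataset(dataset_dir, 'train') then reads the training split back.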